VT3 / main.py
Ashrafb's picture
Update main.py
a805da0 verified
raw
history blame
1.27 kB
from __future__ import annotations

import argparse
import pathlib
import shutil
import tempfile
from io import BytesIO

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image

from vtoonify_model import Model
# ASGI application object; the upload route, index route and static mount
# are attached to it further down in this module.
app = FastAPI()

# Run the VToonify model on the GPU when PyTorch reports one, else on the CPU.
if torch.cuda.is_available():
    _DEVICE = 'cuda'
else:
    _DEVICE = 'cpu'
model = Model(device=_DEVICE)
@app.post("/upload/")
async def process_image(file: UploadFile = File(...)):
    """Cartoonify an uploaded face photo and return the result as a JPEG.

    The upload is saved to a per-request temporary directory. It is aligned
    with ``model.detect_and_align_image``, run through the 'cartoon1' style
    via ``model.image_toonify``, and the image that call returns is encoded
    and streamed back.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A ``StreamingResponse`` whose body is the toonified image as JPEG.
    """
    # NOTE(review): this handler is `async def`, but the model calls below are
    # synchronous and block the event loop for the whole request. A plain
    # `def` would let FastAPI run it in a threadpool, but only once Model is
    # confirmed safe to call from several threads at once.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Unique per-request path. The previous fixed "uploaded_image.jpg" in
        # the working directory was shared by every concurrent request.
        upload_path = pathlib.Path(tmp_dir) / "uploaded_image.jpg"
        with open(upload_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Load the style model (assuming 'cartoon1' is always used).
        # NOTE(review): this reloads on every request. Caching exstyle would
        # avoid that if load_model has no other per-call effect; confirm.
        exstyle, _ = model.load_model('cartoon1')

        aligned_face, _, _ = model.detect_and_align_image(
            str(upload_path), padding_params=[200, 200, 200, 200]
        )
        result_face, _ = model.image_toonify(
            aligned_face, exstyle, 'cartoon1', style_degree=0.5
        )

    # Previously result_face was discarded and a fixed "result_image.jpg"
    # (not written by this handler) was returned. That gave stale output, or
    # a failure if the file was missing. Encode the actual result instead.
    # Assumes result_face is a uint8 HxWxC numpy array; TODO confirm against
    # vtoonify_model.Model.image_toonify.
    out = BytesIO()
    Image.fromarray(result_face).convert("RGB").save(out, format="JPEG")
    out.seek(0)
    return StreamingResponse(out, media_type="image/jpeg")
# Front-end assets directory. It is resolved against the working directory,
# exactly as the static mount below resolves it. Until now that mount is what
# actually served "/", so index.html comes from the same place.
_STATIC_DIR = pathlib.Path("AB")


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page ``index.html``.

    This route must be registered before the catch-all static mount.
    Starlette dispatches to the first matching route in registration order,
    and a mount at "/" matches every path, so a route added after it is
    never reached.
    """
    return FileResponse(path=str(_STATIC_DIR / "index.html"), media_type="text/html")


# Catch-all static file server, registered last so explicit routes win.
# html=True makes it serve index.html for directory URLs.
app.mount("/", StaticFiles(directory=str(_STATIC_DIR), html=True), name="static")