# VT3 / main.py
# Author: Ashrafb. Last commit: "Update main.py" (6767de4, verified). Size: 1.24 kB.
import base64

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from vtoonify_model import Model
app = FastAPI()

# Build the VToonify wrapper once at import time so all requests share it;
# run on the GPU whenever PyTorch reports CUDA is available.
_device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model(device=_device)

# load_model returns a pair: the first element is the `exstyle` value the
# upload handler passes to model.image_toonify; the second is bound to the
# module-level name `message` (presumably a status string — not checked here).
exstyle, message = model.load_model("cartoon1")
class ImageRequest(BaseModel):
    """Pydantic model wrapping a single uploaded image file.

    NOTE(review): FastAPI treats a BaseModel-typed endpoint parameter as a
    JSON request body, and a JSON body cannot carry an UploadFile. File
    uploads arrive as multipart/form-data and must be declared with
    ``File(...)`` directly on the path operation. Confirm whether anything
    still depends on this model.
    """

    # The uploaded image; the `...` default passed to File() marks it required.
    image_file: UploadFile = File(...)
@app.post("/upload/")
async def toonify_image(image_file: UploadFile = File(...)):
    """Toonify an uploaded photo with the pre-loaded "cartoon1" style.

    The image must be sent as multipart/form-data in a field named
    ``image_file``.

    Returns:
        A JSON object with ``toonified_image`` (the result encoded as a
        base64 PNG string) and ``message`` (the second value returned by
        ``model.image_toonify``).

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 422 if the model returns no image; 500 if the result
            cannot be PNG-encoded.
    """
    data = await image_file.read()
    if not data:
        # cv2.imdecode asserts (raises cv2.error) on an empty buffer rather
        # than returning None, so reject empty uploads up front.
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # imdecode returns None for bytes it cannot parse as an image.
        raise HTTPException(status_code=400, detail="Uploaded file is not a decodable image.")

    # NOTE(review): model inference runs directly inside this async handler
    # and blocks the event loop while it runs; consider a thread pool if
    # request concurrency matters.
    # Hardcoded crop arguments (presumably top/bottom/left/right — confirm).
    aligned_face, instyle, _ = model.detect_and_align_image(img, 200, 200, 200, 200)
    toonified_img, message = model.image_toonify(
        aligned_face, instyle, exstyle, style_degree=0.5, style_type="cartoon1"
    )
    if toonified_img is None:
        raise HTTPException(status_code=422, detail=message)

    # A raw ndarray is not JSON-serializable, so return the result as a
    # base64-encoded PNG under the same response keys.
    # NOTE(review): cv2.imencode interprets the array as BGR; if image_toonify
    # returns RGB, convert with cv2.cvtColor first — confirm against vtoonify_model.
    ok, png = cv2.imencode(".png", toonified_img)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode the toonified image.")
    return {
        "toonified_image": base64.b64encode(png.tobytes()).decode("ascii"),
        "message": message,
    }
# Front-end asset directory. It is a relative path, so it resolves against the
# process working directory; the route and the static mount share it.
STATIC_DIR = "AB"


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page, ``<STATIC_DIR>/index.html``."""
    return FileResponse(path=f"{STATIC_DIR}/index.html", media_type="text/html")


# Mount the static front end LAST. A Mount at "/" matches every request path,
# and the router dispatches to the first full match in registration order, so
# any route registered after this line would never be reached.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")