VT3 / main.py
Ashrafb's picture
Update main.py
870a38f verified
raw
history blame
1.92 kB
import os
import shutil
from io import BytesIO

import cv2
import dlib
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from torchvision import transforms
# The ASGI application object for this service.
app = FastAPI()

# Shared VToonify model wrapper. It is None at import time and is created on
# demand by load_model() when the first upload request arrives.
model = None
def load_model():
global model
from vtoonify_model import Model
model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
model.load_model('cartoon1')
# Define endpoints
@app.post("/upload/")
async def process_image(
    file: UploadFile = File(...),
    top: int = Form(...),
    bottom: int = Form(...),
    left: int = Form(...),
    right: int = Form(...),
):
    """Toonify an uploaded image with the 'cartoon1' style and return a JPEG.

    Args:
        file: Uploaded image in any format OpenCV can decode.
        top, bottom, left, right: Integers forwarded unchanged to
            ``model.detect_and_align_image``. They are presumably crop/padding
            margins around the detected face (TODO: confirm in vtoonify_model).

    Returns:
        A ``StreamingResponse`` whose body is the JPEG-encoded result.

    Raises:
        HTTPException: 400 if the upload is empty or not a decodable image;
            500 if the processed image cannot be JPEG-encoded.
    """
    global model
    if model is None:
        load_model()

    # Decode the upload in memory. Going through a fixed filename in the working
    # directory let separate worker processes overwrite each other's input and
    # required a writable cwd.
    raw = await file.read()
    if not raw:
        # cv2.imdecode asserts on an empty buffer, so reject this case up front.
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    frame = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        # OpenCV returns None (it does not raise) for data it cannot decode.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image.")

    # OpenCV decodes to BGR channel order; the model is fed RGB.
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    aligned_face, instyle, _ = model.detect_and_align_image(frame_rgb, top, bottom, left, right)
    processed_image, _ = model.image_toonify(
        aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1'
    )

    # NOTE(review): cv2.imencode interprets its input as BGR, but processed_image
    # is encoded as returned. If image_toonify yields RGB (its input is RGB),
    # red and blue will be swapped in the JPEG. Confirm against vtoonify_model.
    ok, encoded = cv2.imencode('.jpg', processed_image)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode the processed image.")
    return StreamingResponse(BytesIO(encoded.tobytes()), media_type="image/jpeg")
# Front-end asset directory, resolved relative to the process working directory
# (the same way StaticFiles resolves it below).
STATIC_DIR = "AB"


# Define index route
@app.get("/")
def index():
    """Serve the front-end entry page ``<STATIC_DIR>/index.html``.

    This route must be registered *before* the static mount below. Starlette
    matches routes in registration order, and a mount at "/" matches every
    path, so a route added after the mount would never be reached.
    """
    return FileResponse(path=os.path.join(STATIC_DIR, "index.html"), media_type="text/html")


# Mount static files directory. Keep this as the last registration on `app`:
# the "/" mount catches every path not claimed by an earlier route.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")