# VT3 / main.py
# Ashrafb's picture
# Update main.py
# add003d verified
# raw
# history blame
# 2.28 kB
from __future__ import annotations
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse
import torch
import shutil
import cv2
import numpy as np
import dlib
from torchvision import transforms
import torch.nn.functional as F
from vtoonify_model import Model # Importing the Model class from vtoonify_model.py
import gradio as gr
import pathlib
import sys
sys.path.insert(0, 'vtoonify')
from util import load_psp_standalone, get_video_crop_parameter, tensor2cv2
import torch
import torch.nn as nn
import numpy as np
import dlib
import cv2
from model.vtoonify import VToonify
from model.bisenet.model import BiSeNet
import torch.nn.functional as F
from torchvision import transforms
from model.encoder.align_all_parallel import align_face
import gc
import huggingface_hub
import os
app = FastAPI()

# Shared model instance. It is bound to None here so that request handlers can
# test ``model is None`` without hitting a NameError when the startup hook
# below has not run yet.
model = None


@app.on_event("startup")
async def load_model():
    """Create the shared ``Model`` instance when the application starts.

    The device is ``'cuda'`` when ``torch.cuda.is_available()`` is true and
    ``'cpu'`` otherwise. The instance is stored in the module-level ``model``
    name, which the request handlers read.
    """
    global model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Model(device=device)
from fastapi.responses import StreamingResponse
from io import BytesIO
@app.post("/upload/")
async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
    """Toonify an uploaded image and return the result as a JPEG.

    Args:
        file: Uploaded image, sent as a multipart form field.
        top: Integer form field forwarded to ``model.detect_and_align_image``.
        bottom: Integer form field forwarded to ``model.detect_and_align_image``.
        left: Integer form field forwarded to ``model.detect_and_align_image``.
        right: Integer form field forwarded to ``model.detect_and_align_image``.

    Returns:
        A ``StreamingResponse`` with media type ``image/jpeg`` on success.
        On failure, a JSON body of the form ``{"error": "<message>"}`` with a
        non-2xx status code (503 when the model is not loaded, 422 when the
        alignment step yields nothing, 500 when no image can be produced).
    """
    import tempfile

    from fastapi.responses import JSONResponse

    # ``model`` is only guaranteed to be bound once the startup hook has run,
    # so look it up defensively rather than risking a NameError.
    toonifier = globals().get("model")
    if toonifier is None:
        return JSONResponse(status_code=503, content={"error": "Model not loaded."})

    # Store the upload in a unique temporary file. A fixed file name in the
    # working directory would be shared by every worker process and would
    # leave the last upload on disk indefinitely.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as buffer:
        shutil.copyfileobj(file.file, buffer)
        upload_path = buffer.name

    # NOTE(review): the calls below are synchronous and contain no ``await``,
    # so the event loop is blocked for the duration of the inference.
    try:
        aligned_face, instyle, message = toonifier.detect_and_align_image(upload_path, top, bottom, left, right)
        # Presumably the model signals a failed detection by returning None
        # values together with a message -- TODO confirm in vtoonify_model.
        if aligned_face is None or instyle is None:
            return JSONResponse(status_code=422, content={"error": str(message)})
        processed_image, message = toonifier.image_toonify(aligned_face, instyle, toonifier.exstyle, style_degree=0.5, style_type='cartoon1')
    finally:
        # The temporary file is only needed while the model reads it.
        os.remove(upload_path)

    if processed_image is None:
        return JSONResponse(status_code=500, content={"error": str(message)})

    # cv2.imencode returns a (success flag, buffer) pair; do not stream the
    # buffer when encoding failed.
    # NOTE(review): cv2 treats 3-channel arrays as BGR -- confirm the channel
    # order of the array returned by image_toonify.
    encoded_ok, encoded = cv2.imencode('.jpg', processed_image)
    if not encoded_ok:
        return JSONResponse(status_code=500, content={"error": "Failed to encode the processed image."})

    return StreamingResponse(BytesIO(encoded.tobytes()), media_type="image/jpeg")
@app.get("/")
def index() -> FileResponse:
    """Return the landing page for requests to the root path.

    The file is read from the same ``AB`` directory that is mounted as static
    content below, so both entry points serve the same ``index.html``.
    """
    return FileResponse(path="AB/index.html", media_type="text/html")


# A mount at "/" matches every request path, so it must be registered after
# all explicit routes; registering it first would make the ``index`` route
# above unreachable. ``html=True`` lets the mount serve ``index.html`` files
# for directory requests.
app.mount("/", StaticFiles(directory="AB", html=True), name="static")