File size: 1,813 Bytes
7c33baa
6b79e42
7c33baa
 
b43f043
 
ea1444b
7c33baa
b43f043
7c33baa
b43f043
 
 
 
99f1eb3
 
b43f043
 
 
 
 
 
c0cfc30
61b6cd2
b43f043
 
 
 
7c33baa
 
 
 
 
b43f043
 
7c33baa
 
 
 
 
 
 
99f1eb3
b43f043
99f1eb3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from __future__ import annotations
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse
import torch
import shutil
import cv2
import numpy as np
import dlib
from torchvision import transforms
import torch.nn.functional as F
from vtoonify_model import Model  # Importing the Model class from vtoonify_model.py

# ASGI application object; the route handlers and the static mount in this
# file are registered on it.
app = FastAPI()
# Shared model instance. Stays None until the startup hook assigns it, which
# is why the upload handler checks for None before using it.
model = None

@app.on_event("startup")
async def load_model():
    """Create the shared ``Model`` once when the application starts.

    The instance is stored in the module-level ``model`` variable. It is
    built for ``'cuda'`` when ``torch.cuda.is_available()`` reports a usable
    GPU and for ``'cpu'`` otherwise.
    """
    global model
    if torch.cuda.is_available():
        target_device = 'cuda'
    else:
        target_device = 'cpu'
    model = Model(device=target_device)

@app.post("/upload/")
async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
    """Toonify an uploaded image and return the result as a JPEG download.

    Args:
        file: Uploaded image, received as a multipart form file.
        top: Integer form field passed unchanged to
            ``model.detect_and_align_image``.
        bottom: Same as ``top``.
        left: Same as ``top``.
        right: Same as ``top``.

    Returns:
        ``{"error": "Model not loaded."}`` while the module-level ``model`` is
        still ``None``; otherwise an ``image/jpeg`` response delivered as the
        attachment ``result_image.jpg``.

    Raises:
        RuntimeError: If the model returns an array that OpenCV cannot encode
            as JPEG.
    """
    import os
    import tempfile

    from fastapi import Response

    if model is None:
        return {"error": "Model not loaded."}

    # Every request works in its own scratch directory. A fixed file name in
    # the working directory would let concurrent uploads overwrite each
    # other's input, and the directory is removed again when the block exits.
    with tempfile.TemporaryDirectory(prefix="toonify_") as work_dir:
        upload_path = os.path.join(work_dir, "uploaded_image.jpg")
        with open(upload_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # NOTE(review): both calls also return a status message that is not
        # inspected here (same as before) — confirm whether failures such as
        # "no face found" should be reported to the client.
        aligned_face, instyle, message = model.detect_and_align_image(upload_path, top, bottom, left, right)
        processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1')

    if isinstance(processed_image, np.ndarray):
        # A raw pixel array is not a JPEG file; it has to be encoded before
        # being served with an image/jpeg media type.
        pixels = processed_image
        if pixels.ndim == 3 and pixels.shape[2] == 3:
            # cv2.imencode treats 3-channel input as BGR.
            # NOTE(review): assumes the model returns RGB — TODO confirm
            # against vtoonify_model.Model.image_toonify.
            pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
        encoded_ok, jpeg = cv2.imencode(".jpg", pixels)
        if not encoded_ok:
            raise RuntimeError("Failed to encode the processed image as JPEG.")
        payload = jpeg.tobytes()
    else:
        # Already a bytes-like object: pass it through unchanged, as the
        # previous file write did.
        payload = bytes(memoryview(processed_image))

    # Return the processed image
    return Response(content=payload, media_type="image/jpeg", headers={"Content-Disposition": "attachment; filename=result_image.jpg"})


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end landing page for ``GET /``.

    Returns:
        The ``index.html`` page as ``text/html``. The packaged location
        ``/app/AB/index.html`` is preferred; when that file does not exist the
        page is read from the relative ``AB`` directory, which is the same
        directory the static mount serves.
    """
    import os

    packaged_page = "/app/AB/index.html"
    if os.path.isfile(packaged_page):
        page_path = packaged_page
    else:
        page_path = os.path.join("AB", "index.html")
    return FileResponse(path=page_path, media_type="text/html")


# Registered last on purpose: a mount at "/" matches every path, and routes
# are tried in registration order, so any route added after it is unreachable.
app.mount("/", StaticFiles(directory="AB", html=True), name="static")