Pn / main.py
Makhinur's picture
Create main.py
baba3f4 verified
raw
history blame
1.11 kB
from fastapi import FastAPI, File, UploadFile,Form
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
import torch
from io import BytesIO
app = FastAPI()

# All models/helpers below are pulled through torch.hub from this one
# GitHub repo and branch.
_HUB_REPO = "AK391/animegan2-pytorch:main"

# Generator selected by the "version2" form value in /predict/; loaded with
# the hub entry point's default weights (pretrained=True) and no download
# progress bar.
model2 = torch.hub.load(
    _HUB_REPO,
    "generator",
    pretrained=True,
    device="cpu",
    progress=False,
)

# Generator loaded with the named "face_paint_512_v1" weights.
model1 = torch.hub.load(
    _HUB_REPO,
    "generator",
    pretrained="face_paint_512_v1",
    device="cpu",
)

# Callable returned by the hub's "face2paint" entry point; invoked as
# face2paint(model, pil_image) and returns an image with a .save() method.
# size=512 is presumably the working resolution — confirm against the repo.
face2paint = torch.hub.load(
    _HUB_REPO,
    "face2paint",
    size=512,
    device="cpu",
    side_by_side=False,
)
def _stylize_to_png(contents: bytes, version: str) -> BytesIO:
    """Decode an uploaded image, stylize it, and return it PNG-encoded.

    This is blocking, CPU-bound work (image decode, model inference, PNG
    encode), so ``predict`` runs it in a worker thread instead of on the
    event loop.

    Args:
        contents: Raw bytes of the uploaded file.
        version: ``"version2"`` selects ``model2``; any other value selects
            ``model1``.

    Returns:
        A ``BytesIO`` rewound to offset 0 that holds the PNG bytes.

    Raises:
        HTTPException: 400 when ``contents`` cannot be decoded as an image.
    """
    from fastapi import HTTPException

    try:
        # Image.open is lazy; convert() forces the real decode, so a corrupt
        # or truncated upload fails here (as a client error) instead of deep
        # inside the model. Converting to RGB also normalizes RGBA, palette
        # and grayscale uploads, since the generator presumably expects
        # 3-channel input.
        image = Image.open(BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    model = model2 if version == "version2" else model1
    out = face2paint(model, image)

    png = BytesIO()
    out.save(png, format="PNG")
    png.seek(0)  # StreamingResponse reads from the current position.
    return png


@app.post("/predict/")
async def predict(
    file: UploadFile = File(...),
    version: str = Form(...)
):
    """Stylize an uploaded image and stream it back as ``image/png``.

    Form fields:
        file: The image to stylize (any format Pillow can decode).
        version: ``"version2"`` uses ``model2``; any other value falls back
            to ``model1``.

    Raises:
        HTTPException: 400 if the upload is not a decodable image.
    """
    from fastapi.concurrency import run_in_threadpool

    contents = await file.read()
    # Inference is CPU-bound; doing it inline in an ``async def`` would block
    # the event loop and stall every other request until it finished.
    png = await run_in_threadpool(_stylize_to_png, contents, version)
    return StreamingResponse(png, media_type="image/png")