File size: 1,107 Bytes
baba3f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from fastapi import FastAPI, File, UploadFile,Form
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
import torch
from io import BytesIO
app = FastAPI()

# Every model and helper below comes from the same torch.hub repo and runs on CPU.
HUB_REPO = "AK391/animegan2-pytorch:main"
DEVICE = "cpu"

# Model served for version == "version2". ``pretrained=True`` loads the hub
# entrypoint's default checkpoint (presumably face_paint_512_v2 — confirm
# against the repo's hubconf). ``.eval()`` puts the module in inference mode.
model2 = torch.hub.load(
    HUB_REPO,
    "generator",
    pretrained=True,
    device=DEVICE,
    progress=False,
).eval()

# Model served for every other ``version`` value: the named
# "face_paint_512_v1" checkpoint. ``progress=False`` matches model2 so no
# download progress bar is written to the server logs.
model1 = torch.hub.load(
    HUB_REPO,
    "generator",
    pretrained="face_paint_512_v1",
    device=DEVICE,
    progress=False,
).eval()

# Hub helper invoked as ``face2paint(model, pil_image)``; it returns an object
# with a PIL-style ``.save``. ``size=512`` presumably sets the working/output
# resolution and ``side_by_side=False`` returns only the stylised image —
# confirm against the hub repo.
face2paint = torch.hub.load(
    HUB_REPO,
    "face2paint",
    size=512,
    device=DEVICE,
    side_by_side=False,
)
@app.post("/predict/")
async def predict(
    file: UploadFile = File(...),
    version: str = Form(...)
):
    """Convert an uploaded photo to AnimeGANv2 style and return it as a PNG.

    Args:
        file: Uploaded image; any format Pillow can decode.
        version: ``"version2"`` selects ``model2``; every other value falls
            back to ``model1`` (kept as-is for backward compatibility).

    Returns:
        StreamingResponse carrying the stylised image as ``image/png``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # Local imports keep this endpoint self-contained; both names come from
    # the fastapi package the module already depends on.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    contents = await file.read()
    try:
        # ``Image.open`` is lazy; ``convert`` forces a full decode, so corrupt
        # or truncated uploads fail here rather than inside the model call.
        # Converting to RGB also normalises RGBA / palette / grayscale input
        # to the 3-channel image the generator presumably expects.
        image = Image.open(BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-image errors are OSError subclasses.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    model = model2 if version == 'version2' else model1

    # Inference is synchronous and CPU-bound. Calling it directly inside an
    # ``async def`` would block the event loop for every other request.
    out = await run_in_threadpool(face2paint, model, image)

    img_byte_arr = BytesIO()
    out.save(img_byte_arr, format='PNG')
    img_byte_arr.seek(0)
    return StreamingResponse(img_byte_arr, media_type="image/png")
|