|
|
|
from io import BytesIO

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
|
|
|
# FastAPI application object; the /predict/ route is registered on it.
app = FastAPI()

# All hub entry points below are fetched from the same repo (branch "main")
# and run on CPU. The models are loaded once, at import time.
_HUB_REPO = "AK391/animegan2-pytorch:main"
_DEVICE = "cpu"

# Generator with pretrained=True, i.e. the hub entry point's default weights
# (NOTE(review): presumably face_paint_512_v2 - confirm against the repo's
# hubconf). Used by predict() when version == 'version2'.
model2 = torch.hub.load(
    _HUB_REPO, "generator", pretrained=True, device=_DEVICE, progress=False
)

# Generator with the named "face_paint_512_v1" checkpoint; used by predict()
# for every other version value.
model1 = torch.hub.load(
    _HUB_REPO, "generator", pretrained="face_paint_512_v1", device=_DEVICE
)

# Returns a callable used as face2paint(model, pil_image) -> pil_image.
# size=512 is presumably the working resolution, and side_by_side=False
# presumably returns only the stylized image (not a before/after pair);
# confirm against the hub entry point.
face2paint = torch.hub.load(
    _HUB_REPO, "face2paint", size=512, device=_DEVICE, side_by_side=False
)
|
|
|
def _stylize(contents: bytes, version: str) -> BytesIO:
    """Decode an uploaded image, stylize it, and PNG-encode the result.

    This is synchronous, CPU-bound work. ``predict`` runs it in a worker
    thread so the event loop stays free while it executes.

    Args:
        contents: Raw bytes of the uploaded file.
        version: ``'version2'`` selects ``model2``; any other value
            selects ``model1``.

    Returns:
        A ``BytesIO`` holding the PNG-encoded output, rewound to offset 0.

    Raises:
        HTTPException: 400 if ``contents`` cannot be decoded as an image.
    """
    try:
        # Image.open() is lazy; convert() forces the full decode, so corrupt
        # or truncated files fail here instead of deeper in the pipeline.
        # Converting to RGB also normalises RGBA / palette / greyscale / CMYK
        # uploads. The generator presumably expects 3-channel input, and
        # those modes would otherwise reach it with the wrong channel count.
        image = Image.open(BytesIO(contents)).convert("RGB")
    except OSError as exc:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from exc

    # Any value other than 'version2' falls back to model1, as before.
    model = model2 if version == 'version2' else model1
    out = face2paint(model, image)

    png_buffer = BytesIO()
    out.save(png_buffer, format='PNG')
    png_buffer.seek(0)
    return png_buffer


@app.post("/predict/")
async def predict(
    file: UploadFile = File(...),
    version: str = Form(...)
):
    """Stylize an uploaded image and stream the result back as PNG.

    Form fields:
        file: The image to convert.
        version: ``'version2'`` uses ``model2``; any other value uses
            ``model1``.

    Returns:
        A ``StreamingResponse`` with ``media_type="image/png"``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    # Decoding, model inference and PNG encoding are blocking CPU work.
    # Calling them directly inside this coroutine would stall the event loop,
    # and with it every other in-flight request, for the whole duration.
    png_buffer = await run_in_threadpool(_stylize, contents, version)
    return StreamingResponse(png_buffer, media_type="image/png")
|
|
|
|