File size: 1,107 Bytes
baba3f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

from io import BytesIO

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image

app = FastAPI()

# Every hub entry point below comes from the same repository and targets CPU.
_HUB_REPO = "AK391/animegan2-pytorch:main"
_DEVICE = "cpu"

# "version2" model: pretrained=True asks the repo for its default weights
# (presumably a newer face_paint_512 checkpoint -- confirm against the repo's
# hubconf). Download progress output is suppressed for this load only.
model2 = torch.hub.load(
    _HUB_REPO,
    "generator",
    pretrained=True,
    device=_DEVICE,
    progress=False,
)

# "version1" model: explicitly requests the "face_paint_512_v1" weights.
model1 = torch.hub.load(
    _HUB_REPO,
    "generator",
    pretrained="face_paint_512_v1",
    device=_DEVICE,
)

# Callable applied as face2paint(model, image); configured for 512 px output
# without the side-by-side (input | output) layout.
face2paint = torch.hub.load(
    _HUB_REPO,
    "face2paint",
    size=512,
    device=_DEVICE,
    side_by_side=False,
)

@app.post("/predict/")
async def predict(
    file: UploadFile = File(...),
    version: str = Form(...)
):
    """Stylize an uploaded image with AnimeGAN2 and stream it back as PNG.

    Args:
        file: Multipart upload containing the source image.
        version: ``"version2"`` selects ``model2``; any other value selects
            ``model1`` (unknown values fall back to ``model1`` silently).

    Returns:
        A ``StreamingResponse`` whose body is the stylized image encoded as PNG.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        # Image.open is lazy; convert() forces a full decode, so corrupt or
        # truncated uploads fail here instead of deep inside inference.
        # Converting to RGB also normalizes RGBA / grayscale / palette uploads;
        # presumably the generator expects 3-channel input (confirm against
        # the hub repo's face2paint).
        image = Image.open(BytesIO(contents)).convert("RGB")
    except OSError as err:  # also covers PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    model = model2 if version == "version2" else model1
    # face2paint runs synchronous model inference. Calling it directly in an
    # async endpoint would block the event loop for every other request.
    out = await run_in_threadpool(face2paint, model, image)

    img_byte_arr = BytesIO()
    out.save(img_byte_arr, format="PNG")
    img_byte_arr.seek(0)
    return StreamingResponse(img_byte_arr, media_type="image/png")