"""FastAPI service that stylises an uploaded image with AnimeGAN2.

Two generator checkpoints and the ``face2paint`` helper are loaded once at
import time from the ``AK391/animegan2-pytorch`` torch.hub repository; the
``/predict/`` endpoint picks one generator per request.
"""

from io import BytesIO

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image

app = FastAPI()

# Generator loaded with the hub entry point's default pretrained weights
# (presumably the "v2" face-paint checkpoint -- confirm against the hub repo).
model2 = torch.hub.load(
    "AK391/animegan2-pytorch:main",
    "generator",
    pretrained=True,
    device="cpu",
    progress=False,
)

# Generator loaded with the explicitly named "face_paint_512_v1" weights.
model1 = torch.hub.load(
    "AK391/animegan2-pytorch:main",
    "generator",
    pretrained="face_paint_512_v1",
    device="cpu",
)

# Callable used below as face2paint(model, image); configured for 512px
# output without a side-by-side comparison.
face2paint = torch.hub.load(
    "AK391/animegan2-pytorch:main",
    "face2paint",
    size=512,
    device="cpu",
    side_by_side=False,
)


@app.post("/predict/")
async def predict(
    file: UploadFile = File(...),
    version: str = Form(...),
):
    """Stylise the uploaded image and stream the result back as a PNG.

    Args:
        file: The uploaded image file (multipart form field).
        version: ``'version2'`` selects ``model2``; any other value selects
            ``model1``.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing the
        stylised image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()

    try:
        # Image.open is lazy; convert() forces a full decode, so truncated or
        # corrupt uploads fail here rather than deep inside the model call.
        # Normalising to RGB also keeps RGBA / greyscale / palette uploads
        # from reaching the generator with an unexpected channel count
        # (assumes the generator expects 3-channel input -- TODO confirm).
        image = Image.open(BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    model = model2 if version == "version2" else model1
    # NOTE: this call is synchronous, so it occupies the event loop for the
    # duration of the inference.
    out = face2paint(model, image)

    img_byte_arr = BytesIO()
    out.save(img_byte_arr, format="PNG")
    img_byte_arr.seek(0)  # rewind so the response streams from the start
    return StreamingResponse(img_byte_arr, media_type="image/png")