Makhinur commited on
Commit
baba3f4
·
verified ·
1 Parent(s): a33c030

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +40 -0
main.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from fastapi import FastAPI, File, UploadFile,Form
3
+ from fastapi.responses import FileResponse, StreamingResponse
4
+ from fastapi.staticfiles import StaticFiles
5
+ from PIL import Image
6
+ import torch
7
+ from io import BytesIO
8
+
9
+ app = FastAPI()
10
+
11
+ model2 = torch.hub.load(
12
+ "AK391/animegan2-pytorch:main",
13
+ "generator",
14
+ pretrained=True,
15
+ device="cpu",
16
+ progress=False
17
+ )
18
+ model1 = torch.hub.load("AK391/animegan2-pytorch:main",
19
+ "generator", pretrained="face_paint_512_v1", device="cpu")
20
+ face2paint = torch.hub.load(
21
+ 'AK391/animegan2-pytorch:main', 'face2paint',
22
+ size=512, device="cpu", side_by_side=False
23
+ )
24
+
25
+ @app.post("/predict/")
26
+ async def predict(
27
+ file: UploadFile = File(...),
28
+ version: str = Form(...)
29
+ ):
30
+ contents = await file.read()
31
+ image = Image.open(BytesIO(contents))
32
+ if version == 'version2':
33
+ out = face2paint(model2, image)
34
+ else:
35
+ out = face2paint(model1, image)
36
+ img_byte_arr = BytesIO()
37
+ out.save(img_byte_arr, format='PNG')
38
+ img_byte_arr.seek(0)
39
+ return StreamingResponse(img_byte_arr, media_type="image/png")
40
+