Quoc Bao Bui commited on
Commit
a7eb3c4
·
1 Parent(s): f513867

Add handler

Browse files
Files changed (1) hide show
  1. handler.py +46 -0
handler.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from io import BytesIO
3
+ from typing import Dict, List, Any
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from diffusers import StableDiffusionPipeline
8
+
9
+
10
+ REPO_ID = "runwayml/stable-diffusion-v1-5"
11
+
12
+
13
+ # helper decoder
14
+ def decode_base64_image(image_string):
15
+ base64_image = base64.b64decode(image_string)
16
+ buffer = BytesIO(base64_image)
17
+ return Image.open(buffer)
18
+
19
+
20
+ class EndpointHandler:
21
+ def __init__(self, path=""):
22
+ self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16,
23
+ revision="fp16")
24
+ self.pipe = self.pipe.to("cuda")
25
+
26
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
27
+ """
28
+ Args:
29
+ data (:obj:):
30
+ includes the input data and the parameters for the inference.
31
+ Return:
32
+ A :obj:`dict`:. base64 encoded image
33
+ """
34
+ prompts = data.pop("prompts", None)
35
+ encoded_image = data.pop("image", None)
36
+ init_image = decode_base64_image(encoded_image)
37
+ init_image.thumbnail((768, 768))
38
+ image = self.pipe(prompts, init_image=init_image).images[0]
39
+
40
+ # encode image as base 64
41
+ buffered = BytesIO()
42
+ image.save(buffered, format="png")
43
+ img_str = base64.b64encode(buffered.getvalue())
44
+
45
+ # post process the prediction
46
+ return {"image": img_str.decode()}