connersdavis commited on
Commit
eb9d96a
·
1 Parent(s): be39b45

handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -21
handler.py CHANGED
@@ -1,28 +1,36 @@
1
- from typing import Dict, List, Any
2
- from transformers import pipeline
3
 
4
- class PreTrainedPipeline():
 
 
 
 
 
 
 
 
 
 
 
 
5
  def __init__(self, path=""):
6
- self.pipeline = pipeline("text-classification",model=path)
7
- self.holidays = holidays.US()
 
8
 
9
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
10
  """
11
- data args:
12
- inputs (:obj: `str`)
13
- date (:obj: `str`)
14
- Return:
15
- A :obj:`list` | `dict`: will be serialized and returned
16
  """
17
- # get inputs
18
- inputs = data.pop("inputs",data)
19
- date = data.pop("date", None)
20
-
21
- # check if date exists and if it is holiday
22
- if date is not None and date in self.holidays:
23
- return [{"label": "happy", "score": 1}]
24
 
 
 
 
25
 
26
- # run normal prediction
27
- prediction = self.pipeline(inputs)
28
- return prediction
 
 
 
1
 
2
+ from typing import Dict, List, Any
3
+ import torch
4
+ from torch import autocast
5
+ from diffusers import StableDiffusionPipeline
6
+ import base64
7
+ from io import BytesIO
8
+ # set device
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ if device.type != 'cuda':
12
+ raise ValueError("need to run on GPU")
13
+
14
+ class EndpointHandler():
15
  def __init__(self, path=""):
16
+ # load the optimized model
17
+ self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
18
+ self.pipe = self.pipe.to(device)
19
 
20
+
21
+ def __call__(self, data: Any) -> "PIL.Image":
22
  """
23
+ Args:
24
+ data (:obj:):
25
+ includes the input data and the parameters for the inference.
26
+ Return:
27
+ A :obj:`dict`:. base64 encoded image
28
  """
29
+ inputs = data.pop("inputs", data)
 
 
 
 
 
 
30
 
31
+ # run inference pipeline
32
+ with autocast(device.type):
33
+ image = self.pipe(inputs, guidance_scale=7.5)["sample"][0]
34
 
35
+ # encoding image as base 64 is done by the default toolkit
36
+ return image