rbanfield commited on
Commit
8182cb1
·
1 Parent(s): 4dbb20c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -8
handler.py CHANGED
@@ -3,14 +3,13 @@ import base64
3
 
4
  from PIL import Image
5
  import torch
6
- from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection
7
 
8
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
 
10
  class EndpointHandler():
11
  def __init__(self, path=""):
12
- self.text_model = CLIPTextModel.from_pretrained("rbanfield/clip-vit-large-patch14").to(device)
13
- self.image_model = CLIPVisionModelWithProjection.from_pretrained("rbanfield/clip-vit-large-patch14").to(device)
14
  self.processor = CLIPProcessor.from_pretrained("rbanfield/clip-vit-large-patch14")
15
 
16
  def __call__(self, data):
@@ -19,13 +18,15 @@ class EndpointHandler():
19
  image_input = inputs["image"] if "image" in inputs else None
20
 
21
  if text_input:
22
- processor = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
 
23
  with torch.no_grad():
24
- return self.text_model(**processor).pooler_output.tolist()
25
  elif image_input:
26
  image = Image.open(BytesIO(base64.b64decode(image_input)))
27
- processor = self.processor(images=image, return_tensors="pt").to(device)
 
28
  with torch.no_grad():
29
- return self.image_model(**processor).image_embeds.tolist()
30
  else:
31
- return None
 
3
 
4
  from PIL import Image
5
  import torch
6
+ from transformers import CLIPProcessor, CLIPModel
7
 
8
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
 
10
  class EndpointHandler():
11
  def __init__(self, path=""):
12
+ self.model = CLIPModel.from_pretrained("rbanfield/clip-vit-large-patch14").to("cpu")
 
13
  self.processor = CLIPProcessor.from_pretrained("rbanfield/clip-vit-large-patch14")
14
 
15
  def __call__(self, data):
 
18
  image_input = inputs["image"] if "image" in inputs else None
19
 
20
  if text_input:
21
+ processor = self.processor(text=text_input, return_tensors="pt", padding=True)
22
+ processor.to("cpu")
23
  with torch.no_grad():
24
+ return self.model.get_text_features(**processor).tolist()
25
  elif image_input:
26
  image = Image.open(BytesIO(base64.b64decode(image_input)))
27
+ processor = self.processor(images=image, return_tensors="pt")
28
+ processor.to("cpu")
29
  with torch.no_grad():
30
+ return self.model.get_image_features(**processor).tolist()
31
  else:
32
+ return None