tcm03 commited on
Commit
1060235
·
1 Parent(s): 8d4eb6b
Files changed (1) hide show
  1. handler.py +6 -6
handler.py CHANGED
@@ -24,10 +24,10 @@ def preprocess_text(text):
24
  """Tokenize text query."""
25
  return tokenize([str(text)])[0].unsqueeze(0).to(device)
26
 
27
- def get_fused_embedding(sketch_base64, text, model):
28
  """Fuse sketch and text features into a single embedding."""
29
  with torch.no_grad():
30
- sketch_tensor = preprocess_image(sketch_base64)
31
  text_tensor = preprocess_text(text)
32
 
33
  sketch_feature = model.encode_sketch(sketch_tensor)
@@ -39,9 +39,9 @@ def get_fused_embedding(sketch_base64, text, model):
39
  fused_embedding = model.feature_fuse(sketch_feature, text_feature)
40
  return fused_embedding.cpu().numpy().tolist()
41
 
42
- def get_image_embedding(image_base64, model):
43
  """Convert base64 encoded image to tensor."""
44
- image_tensor = preprocess_image(image_base64)
45
  with torch.no_grad():
46
  image_feature = model.encode_image(image_tensor)
47
  image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
@@ -89,13 +89,13 @@ class EndpointHandler:
89
  return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
90
 
91
  # Generate Fused Embedding
92
- fused_embedding = get_fused_embedding(sketch_base64, text_query)
93
  return {"embedding": fused_embedding}
94
  elif "image" in inputs:
95
  image_base64 = inputs.get("image", "")
96
  if not image_base64:
97
  return {"error": "Image 'image' (base64) is required input."}
98
- embedding = get_image_embedding(image_base64)
99
  return {"embedding": embedding}
100
  else:
101
  return {"error": "Input 'sketch' or 'image' is required."}
 
24
  """Tokenize text query."""
25
  return tokenize([str(text)])[0].unsqueeze(0).to(device)
26
 
27
+ def get_fused_embedding(sketch_base64, text, model, transformer):
28
  """Fuse sketch and text features into a single embedding."""
29
  with torch.no_grad():
30
+ sketch_tensor = preprocess_image(sketch_base64, transformer)
31
  text_tensor = preprocess_text(text)
32
 
33
  sketch_feature = model.encode_sketch(sketch_tensor)
 
39
  fused_embedding = model.feature_fuse(sketch_feature, text_feature)
40
  return fused_embedding.cpu().numpy().tolist()
41
 
42
+ def get_image_embedding(image_base64, model, transformer):
43
  """Convert base64 encoded image to tensor."""
44
+ image_tensor = preprocess_image(image_base64, transformer)
45
  with torch.no_grad():
46
  image_feature = model.encode_image(image_tensor)
47
  image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
 
89
  return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
90
 
91
  # Generate Fused Embedding
92
+ fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
93
  return {"embedding": fused_embedding}
94
  elif "image" in inputs:
95
  image_base64 = inputs.get("image", "")
96
  if not image_base64:
97
  return {"error": "Image 'image' (base64) is required input."}
98
+ embedding = get_image_embedding(image_base64, self.model, self.transform)
99
  return {"embedding": embedding}
100
  else:
101
  return {"error": "Input 'sketch' or 'image' is required."}