tcm03
commited on
Commit
·
1060235
1
Parent(s):
8d4eb6b
Fix bugs
Browse files- handler.py +6 -6
handler.py
CHANGED
@@ -24,10 +24,10 @@ def preprocess_text(text):
|
|
24 |
"""Tokenize text query."""
|
25 |
return tokenize([str(text)])[0].unsqueeze(0).to(device)
|
26 |
|
27 |
-
def get_fused_embedding(sketch_base64, text, model):
|
28 |
"""Fuse sketch and text features into a single embedding."""
|
29 |
with torch.no_grad():
|
30 |
-
sketch_tensor = preprocess_image(sketch_base64)
|
31 |
text_tensor = preprocess_text(text)
|
32 |
|
33 |
sketch_feature = model.encode_sketch(sketch_tensor)
|
@@ -39,9 +39,9 @@ def get_fused_embedding(sketch_base64, text, model):
|
|
39 |
fused_embedding = model.feature_fuse(sketch_feature, text_feature)
|
40 |
return fused_embedding.cpu().numpy().tolist()
|
41 |
|
42 |
-
def get_image_embedding(image_base64, model):
|
43 |
"""Convert base64 encoded image to tensor."""
|
44 |
-
image_tensor = preprocess_image(image_base64)
|
45 |
with torch.no_grad():
|
46 |
image_feature = model.encode_image(image_tensor)
|
47 |
image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
|
@@ -89,13 +89,13 @@ class EndpointHandler:
|
|
89 |
return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
|
90 |
|
91 |
# Generate Fused Embedding
|
92 |
-
fused_embedding = get_fused_embedding(sketch_base64, text_query)
|
93 |
return {"embedding": fused_embedding}
|
94 |
elif "image" in inputs:
|
95 |
image_base64 = inputs.get("image", "")
|
96 |
if not image_base64:
|
97 |
return {"error": "Image 'image' (base64) is required input."}
|
98 |
-
embedding = get_image_embedding(image_base64)
|
99 |
return {"embedding": embedding}
|
100 |
else:
|
101 |
return {"error": "Input 'sketch' or 'image' is required."}
|
|
|
24 |
"""Tokenize text query."""
|
25 |
return tokenize([str(text)])[0].unsqueeze(0).to(device)
|
26 |
|
27 |
+
def get_fused_embedding(sketch_base64, text, model, transformer):
|
28 |
"""Fuse sketch and text features into a single embedding."""
|
29 |
with torch.no_grad():
|
30 |
+
sketch_tensor = preprocess_image(sketch_base64, transformer)
|
31 |
text_tensor = preprocess_text(text)
|
32 |
|
33 |
sketch_feature = model.encode_sketch(sketch_tensor)
|
|
|
39 |
fused_embedding = model.feature_fuse(sketch_feature, text_feature)
|
40 |
return fused_embedding.cpu().numpy().tolist()
|
41 |
|
42 |
+
def get_image_embedding(image_base64, model, transformer):
|
43 |
"""Convert base64 encoded image to tensor."""
|
44 |
+
image_tensor = preprocess_image(image_base64, transformer)
|
45 |
with torch.no_grad():
|
46 |
image_feature = model.encode_image(image_tensor)
|
47 |
image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
|
|
|
89 |
return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
|
90 |
|
91 |
# Generate Fused Embedding
|
92 |
+
fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
|
93 |
return {"embedding": fused_embedding}
|
94 |
elif "image" in inputs:
|
95 |
image_base64 = inputs.get("image", "")
|
96 |
if not image_base64:
|
97 |
return {"error": "Image 'image' (base64) is required input."}
|
98 |
+
embedding = get_image_embedding(image_base64, self.model, self.transform)
|
99 |
return {"embedding": embedding}
|
100 |
else:
|
101 |
return {"error": "Input 'sketch' or 'image' is required."}
|