File size: 2,149 Bytes
af896ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch
from PIL import Image
import base64
from io import BytesIO
from transformers import AutoTokenizer
import sys
sys.path.append("code")
from clip.model import CLIP
# Module-level setup: choose a device, load the pretrained model once at
# import time and switch it to evaluation mode (model.eval()).
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
model = CLIP.from_pretrained("tcm03/tsbir")
model = model.to(device)
model.eval()

# Non-training image transform (is_train=False), built from the visual
# tower's `input_resolution`; `tokenize` is used by preprocess_text.
from clip.clip import _transform, tokenize
transformer = _transform(model.visual.input_resolution, is_train=False)
def preprocess_image(image_base64):
    """Convert a base64-encoded image into a batched tensor on ``device``.

    Args:
        image_base64: Base64 image data (str or bytes). A data-URL header such
            as ``"data:image/png;base64,"`` is accepted and stripped.

    Returns:
        ``transformer`` applied to the RGB-converted image, with a leading
        batch dimension of size 1 added via ``unsqueeze(0)``, moved to
        ``device``.

    Raises:
        Whatever ``base64.b64decode`` / ``PIL.Image.open`` raise for data that
        is not valid base64 or not a decodable image.
    """
    # Accept data URLs (e.g. from a browser canvas). b64decode with its default
    # validate=False drops ':', ';' and ',' but keeps the header's alphanumerics
    # and '/', which would corrupt the decoded bytes. Plain base64 never
    # contains ':', so ordinary input is unaffected.
    if isinstance(image_base64, str) and image_base64.startswith("data:"):
        _header, sep, payload = image_base64.partition(",")
        if sep:
            image_base64 = payload
    raw_bytes = base64.b64decode(image_base64)
    # Close the decoder handle; convert() returns a new, fully loaded image,
    # so it stays usable after the `with` block exits.
    with Image.open(BytesIO(raw_bytes)) as img:
        image = img.convert("RGB")
    return transformer(image).unsqueeze(0).to(device)
def preprocess_text(text):
    """Tokenize one text query into a single-row batch on ``device``.

    The query is coerced with ``str()`` first, so non-string values are
    accepted. The first row of ``tokenize``'s output is taken and given a
    leading batch dimension again with ``unsqueeze(0)``.
    """
    token_batch = tokenize([str(text)])
    first_row = token_batch[0]
    return first_row.unsqueeze(0).to(device)
def get_fused_embedding(image_base64, text):
    """Encode a sketch/text pair and fuse them into one embedding.

    Both encoder outputs are divided by their L2 norm along the last
    dimension (no epsilon) before being passed to ``model.feature_fuse``.
    Gradient tracking is disabled for the whole computation.

    Args:
        image_base64: Base64-encoded sketch image (see ``preprocess_image``).
        text: Text query (see ``preprocess_text``).

    Returns:
        The fused embedding converted via ``.cpu().numpy().tolist()``;
        presumably a nested list of shape (1, dim) — unverified, depends on
        ``model.feature_fuse``.
    """
    def to_unit(features):
        # Same division as before: features / ||features||_2 over the last axis.
        return features / features.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        sketch_input = preprocess_image(image_base64)
        text_input = preprocess_text(text)

        raw_sketch = model.encode_sketch(sketch_input)
        raw_text = model.encode_text(text_input)

        fused = model.feature_fuse(to_unit(raw_sketch), to_unit(raw_text))
        return fused.cpu().numpy().tolist()
# Hugging Face Inference API Entry Point
def infer(inputs):
    """Inference API entry point.

    Args:
        inputs: Mapping-like payload with
            - ``'image'``: base64-encoded sketch image.
            - ``'text'``: text query.
            A payload wrapped as ``{"inputs": {"image": ..., "text": ...}}``
            is also accepted when the top-level fields are absent.

    Returns:
        ``{"fused_embedding": [...]}`` on success, or ``{"error": "..."}``
        when the payload is not mapping-like or a required field is missing
        or empty. Exceptions raised while decoding/encoding (e.g. invalid
        base64) are not caught here and propagate to the caller.
    """
    # Previously a non-mapping payload (None, str, list) crashed with
    # AttributeError on `.get`; report it in the same error shape instead.
    if not callable(getattr(inputs, "get", None)):
        return {"error": "Expected a JSON object with 'image' (base64) and 'text' fields."}

    image_base64 = inputs.get("image", "")
    text_query = inputs.get("text", "")

    # Hugging Face endpoints conventionally wrap the request body as
    # {"inputs": ...}; fall back to it only when neither field is top-level.
    nested = inputs.get("inputs")
    if not image_base64 and not text_query and callable(getattr(nested, "get", None)):
        image_base64 = nested.get("image", "")
        text_query = nested.get("text", "")

    if not image_base64 or not text_query:
        return {"error": "Both 'image' (base64) and 'text' are required inputs."}

    # Generate fused embedding.
    fused_embedding = get_fused_embedding(image_base64, text_query)
    return {"fused_embedding": fused_embedding}
|