import torch
from PIL import Image
import base64
from io import BytesIO
import json
import sys
sys.path.append("code")
from clip.model import CLIP
from clip.clip import _transform, tokenize
# Prefer GPU when available; the model and all input tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fine-tuned checkpoint for text+sketch based image retrieval (TSBIR) — presumably
# trained with the training code under code/; confirm provenance before swapping.
MODEL_PATH = "model/tsbir_model_final.pt"
# ViT-B/16 architecture hyperparameters, unpacked directly into CLIP(**model_info).
CONFIG_PATH = "code/training/model_configs/ViT-B-16.json"
def load_model():
    """Construct and cache the CLIP model and image transform as module globals.

    Idempotent: the first call builds ``model`` (from CONFIG_PATH/MODEL_PATH)
    and ``transformer`` (the eval-mode image preprocessing pipeline sized to
    the model's input resolution); subsequent calls return immediately.
    """
    global model, transformer
    if "model" in globals():
        return  # already loaded — nothing to do

    with open(CONFIG_PATH, 'r') as f:
        model_info = json.load(f)
    model = CLIP(**model_info)

    checkpoint = torch.load(MODEL_PATH, map_location=device)
    sd = checkpoint["state_dict"]
    # Checkpoints saved from nn.DataParallel prefix every key with "module." —
    # strip the prefix so keys match a plain CLIP module. Check for the full
    # "module." prefix (not just "module") to avoid mangling unrelated keys.
    if next(iter(sd)).startswith('module.'):
        sd = {k[len('module.'):]: v for k, v in sd.items()}
    # strict=False tolerates keys missing from / extra to this CLIP build.
    model.load_state_dict(sd, strict=False)
    model = model.to(device).eval()

    transformer = _transform(model.visual.input_resolution, is_train=False)
    print("Model loaded successfully.")
# Preprocessing Functions
def preprocess_image(image_base64):
    """Decode a base64-encoded sketch into a batched image tensor on `device`.

    Uses the module-global `transformer` set up by load_model().
    """
    raw_bytes = base64.b64decode(image_base64)
    pil_image = Image.open(BytesIO(raw_bytes)).convert("RGB")
    tensor = transformer(pil_image)
    return tensor.unsqueeze(0).to(device)
def preprocess_text(text):
    """Tokenize a text query into a batched token tensor on `device`."""
    tokens = tokenize([str(text)])
    return tokens[0].unsqueeze(0).to(device)
def get_fused_embedding(image_base64, text):
    """Encode a sketch/text pair and return their fused embedding as nested lists.

    Both modality features are L2-normalized before fusion; the result is
    moved to CPU and converted via numpy for JSON-friendly output.
    """
    with torch.no_grad():
        sketch_emb = model.encode_sketch(preprocess_image(image_base64))
        text_emb = model.encode_text(preprocess_text(text))

        # Normalize each modality to unit length before fusing.
        sketch_emb = sketch_emb / sketch_emb.norm(dim=-1, keepdim=True)
        text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)

        fused = model.feature_fuse(sketch_emb, text_emb)
    return fused.cpu().numpy().tolist()
# Hugging Face Inference API Entry Point
def infer(inputs):
    """
    Inference API entry point.

    Inputs:
      - 'image': Base64 encoded sketch image.
      - 'text': Text query.

    Returns:
      {"fused_embedding": [...]} on success, or {"error": ...} when either
      input is missing or empty.
    """
    image_base64 = inputs.get("image", "")
    text_query = inputs.get("text", "")
    # Validate BEFORE loading the model: a malformed request should fail fast
    # instead of triggering the expensive one-time checkpoint load.
    if not image_base64 or not text_query:
        return {"error": "Both 'image' (base64) and 'text' are required inputs."}
    load_model()  # Ensure the model is loaded once
    # Generate Fused Embedding
    fused_embedding = get_fused_embedding(image_base64, text_query)
    return {"fused_embedding": fused_embedding}