from typing import Dict, List, Any
from PIL import Image
import torch
import base64
import os
from io import BytesIO
import json
from pathlib import Path

# The `clip` package imported below is resolved through this directory, so the
# sys.path entry must be added before those imports run.
CODE_PATH = Path('code/')
import sys
sys.path.append(str(CODE_PATH))

from clip.model import CLIP
from clip.clip import _transform, tokenize

# All tensors and the model are placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    """Inference handler that turns an (image, text) pair into one fused embedding.

    The model is built and loaded once in ``__init__``; every call to the
    instance runs a single forward pass for one image and one text query.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the pipeline by loading the model.

        Args:
            path (str): Path to the directory containing model weights and config.
                The config is read from ``code/training/model_configs/ViT-B-16.json``
                and the checkpoint from ``model/tsbir_model_final.pt`` below it.

        Raises:
            OSError: If the config or checkpoint file cannot be opened.
            KeyError: If the checkpoint has no ``"state_dict"`` entry.
        """
        model_config_file = os.path.join(path, "code/training/model_configs/ViT-B-16.json")
        with open(model_config_file, "r", encoding="utf-8") as f:
            model_info = json.load(f)

        model_file = os.path.join(path, "model/tsbir_model_final.pt")
        self.model = CLIP(**model_info)

        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should ever be placed at this path.
        checkpoint = torch.load(model_file, map_location=device)
        sd = checkpoint["state_dict"]

        # Strip the "module." prefix only when every key carries it.
        # Slicing keys that lack the prefix would corrupt their names, and
        # strict=False below would then drop those weights without any error.
        prefix = "module."
        if sd and all(key.startswith(prefix) for key in sd):
            sd = {key[len(prefix):]: value for key, value in sd.items()}

        # strict=False: keys that do not match the model are ignored rather
        # than raising.
        self.model.load_state_dict(sd, strict=False)
        self.model = self.model.to(device).eval()

        # Preprocessing (evaluation-mode transform at the visual input resolution)
        self.transform = _transform(self.model.visual.input_resolution, is_train=False)

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Process the request and return the fused embedding.

        Args:
            data (dict): Either ``{"inputs": {...}}`` or the inputs dict itself.
                The inputs must include 'image' (base64) and 'text' (str).
                The payload is not modified.

        Returns:
            dict: ``{"fused_embedding": ...}`` holding the fused tensor converted
            with ``.tolist()`` (presumably one row per input, i.e. a single row
            here; confirm against ``feature_fuse``), or ``{"error": "<message>"}``
            when an input is missing or the image cannot be decoded.
        """
        missing_inputs = {"error": "Both 'image' (base64) and 'text' are required inputs."}

        # Parse inputs without mutating the caller's payload.
        inputs = data.get("inputs", data) if isinstance(data, dict) else None
        if not isinstance(inputs, dict):
            return missing_inputs

        image_base64 = inputs.get("image", "")
        text_query = inputs.get("text", "")
        if not image_base64 or not text_query:
            return missing_inputs

        # Preprocess the image. Malformed base64 raises ValueError/TypeError and
        # unreadable image bytes raise OSError; report both as an error payload
        # in the same style as the missing-input case.
        try:
            image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
        except (ValueError, TypeError, OSError):
            return {"error": "'image' could not be decoded as a base64-encoded image."}
        image_tensor = self.transform(image).unsqueeze(0).to(device)

        # Preprocess the text
        text_tensor = tokenize([str(text_query)])[0].unsqueeze(0).to(device)

        # Keep the whole inference path under no_grad: if feature_fuse involves
        # trainable parameters, running it outside would produce a tensor that
        # requires grad, and .numpy() would then raise.
        with torch.no_grad():
            # Generate features
            sketch_feature = self.model.encode_sketch(image_tensor)
            text_feature = self.model.encode_text(text_tensor)

            # Normalize features
            sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
            text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)

            # Fuse features
            fused_embedding = self.model.feature_fuse(sketch_feature, text_feature)

        return {"fused_embedding": fused_embedding.cpu().numpy().tolist()}