import base64
import io
import os
import time
from typing import Any, Dict, List

import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageOps
from transformers import CLIPModel, CLIPProcessor


def base64_to_img(base64_string):
    """Decode a base64-encoded image into a PIL image.

    A leading ``data:<mime>;base64,`` URI prefix is stripped first. Otherwise
    its letters (which are valid base64 characters) would be decoded as image
    bytes and corrupt the payload.
    """
    if (
        isinstance(base64_string, str)
        and base64_string.startswith("data:")
        and "," in base64_string
    ):
        base64_string = base64_string.split(",", 1)[1]
    imgdata = base64.b64decode(base64_string)
    return Image.open(io.BytesIO(imgdata))


def pad_to_square(img, pad_color=0, size=224):
    """Fit *img* into a ``size`` x ``size`` square, padding the leftover area.

    ``ImageOps.pad`` keeps the aspect ratio and centres the image, filling
    the remaining area with *pad_color*. Presumably this is done so the CLIP
    processor's own resize/crop does not cut off image content.
    """
    return ImageOps.pad(img, (size, size), color=pad_color)


class EndpointHandler:
    """Rank songs by similarity between a CLIP image embedding and lyric embeddings."""

    def __init__(self, path=""):
        """Load the CLIP model/processor and the precomputed lyric data.

        Args:
            path: Directory containing ``clip_lyric_normal_embeddings.npy``
                and ``en_song_info.csv``.
        """
        # TODO: can we run the base model on a smaller CPU?
        hf_model_path = "openai/clip-vit-large-patch14"
        self.model = CLIPModel.from_pretrained(hf_model_path)
        self.processor = CLIPProcessor.from_pretrained(hf_model_path)

        # Lyric embeddings, one row per song. The file name suggests they are
        # already L2-normalised (so dot products are cosine similarities);
        # TODO confirm.
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects. It is presumably acceptable only because the file comes
        # from the handler's own `path`. Drop it if the array has a plain
        # numeric dtype.
        self.lyric_embeddings = np.load(
            os.path.join(path, "clip_lyric_normal_embeddings.npy"), allow_pickle=True
        )
        # Song metadata. Assumes row i describes the song whose embedding is
        # row i of `lyric_embeddings`; TODO confirm.
        self.song_info = pd.read_csv(os.path.join(path, "en_song_info.csv"))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return the songs whose lyrics best match the posted image.

        Args:
            data: Request payload. ``data["inputs"]`` must be a base64-encoded
                image, optionally carrying a ``data:`` URI prefix.

        Returns:
            One dict per matched song (the ``en_song_info.csv`` columns),
            ordered best match first.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` key.
        """
        base64_str = data["inputs"]

        start = time.perf_counter()
        image = pad_to_square(base64_to_img(base64_str))
        after_image = time.perf_counter()
        print(f"Image processing time: {after_image - start}")

        image_embedding = self.get_embedding(image)
        after_embedding = time.perf_counter()
        # Report the embedding stage on its own, not the time since `start`.
        print(f"Embedding time: {after_embedding - after_image}")

        top_10_songs = self.get_top_songs(image_embedding)
        print(f"Total time: {time.perf_counter() - start}")
        return top_10_songs.to_dict(orient="records")

    def get_embedding(self, image):
        """Return the unit-length CLIP image embedding of *image* as float64."""
        # `padding` is a text-tokenizer option, so it is not passed for images.
        inputs = self.processor(images=image, return_tensors="pt")
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            output = self.model.get_image_features(**inputs)
        image_embedding = output[0]  # single image -> first (only) row
        # np.float was removed in NumPy 1.24; it aliased float64.
        np_embedding = image_embedding.cpu().numpy().astype(np.float64)
        np_embedding /= np.linalg.norm(np_embedding)
        return np_embedding

    def get_top_songs(self, image_embedding, k=10):
        """Return the ``k`` songs whose lyric embeddings score highest.

        Args:
            image_embedding: 1-D image embedding, as produced by
                ``get_embedding``.
            k: Number of songs to return (default 10).

        Returns:
            The matching ``song_info`` rows, ordered best match first.
        """
        scores = self.lyric_embeddings @ image_embedding
        # argsort is ascending; reverse it so the highest score comes first.
        top_k = np.argsort(scores)[::-1][:k]
        return self.song_info.iloc[top_k]