grantpitt's picture
copilot renamed unforch
492b3d9
import base64
import io
import os
import time
from typing import Dict, List, Any

import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageOps
from transformers import CLIPModel, CLIPProcessor
# convert a base64 string (or base64 data URL) to image
def base64_to_img(base64_string):
    """Decode a base64-encoded image into a PIL image.

    Accepts either a bare base64 payload or a data URL such as
    ``"data:image/png;base64,<payload>"``. The data-URL header is stripped
    before decoding. ``base64.b64decode`` does not validate by default: it
    discards only characters outside the base64 alphabet. The header's
    letters and ``/`` would therefore be kept as payload and corrupt the
    image bytes.

    Args:
        base64_string: base64 text (``str`` or ``bytes``) of an encoded image.

    Returns:
        The ``PIL.Image.Image`` opened from the decoded bytes.
    """
    if isinstance(base64_string, str) and base64_string.startswith("data:"):
        # data URL: the payload is everything after the first comma
        _, _, base64_string = base64_string.partition(",")
    imgdata = base64.b64decode(base64_string)
    return Image.open(io.BytesIO(imgdata))
# Letterbox an image onto a square 224x224 canvas.
def pad_to_square(img, pad_color=0):
    """Resize *img* to fit a 224x224 square, padding the leftover area.

    Args:
        img: the PIL image to square off.
        pad_color: fill colour for the padded border (default 0).

    Returns:
        A new 224x224 image produced by ``ImageOps.pad``.
    """
    target_size = (224, 224)
    squared = ImageOps.pad(img, target_size, color=pad_color)
    return squared
class EndpointHandler:
    """Inference handler that matches a posted image to songs via CLIP.

    The image is embedded with CLIP's image tower and scored against
    precomputed lyric embeddings. Metadata rows for the best-scoring
    songs are returned.
    """

    def __init__(self, path=""):
        """
        Initialize the model and load the lyric index.

        Args:
            path: directory holding ``clip_lyric_normal_embeddings.npy`` and
                ``en_song_info.csv``. The default ``""`` resolves both files
                relative to the current working directory.
        """
        # TODO: can we run the base model on a smaller CPU?
        hf_model_path = "openai/clip-vit-large-patch14"
        self.model = CLIPModel.from_pretrained(hf_model_path)
        self.processor = CLIPProcessor.from_pretrained(hf_model_path)
        # load the lyrics embeddings
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays;
        # only safe because this file is presumably shipped with the endpoint.
        self.lyric_embeddings = np.load(
            os.path.join(path, "clip_lyric_normal_embeddings.npy"), allow_pickle=True
        )
        # load the song info; row i is assumed to describe lyric_embeddings[i]
        self.song_info = pd.read_csv(os.path.join(path, "en_song_info.csv"))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Return the songs whose lyrics best match the posted image.

        Args:
            data: request payload. ``data["inputs"]`` must be a base64-encoded
                image string (a ``data:`` URL is also accepted by the decoder).

        Returns:
            A list of dicts, one per matched song (a row of ``en_song_info.csv``).
            There are up to 10 entries, ordered from the weakest to the
            strongest match.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` entry.
        """
        base64_str = data["inputs"]
        start = time.perf_counter()
        image = pad_to_square(base64_to_img(base64_str))
        decoded = time.perf_counter()
        print(f"Image processing time: {decoded - start}")
        image_embedding = self.get_embedding(image)
        embedded = time.perf_counter()
        # report the embedding stage alone, not the time since the request began
        print(f"Embedding time: {embedded - decoded}")
        top_10_songs = self.get_top_songs(image_embedding)
        print(f"Total time: {time.perf_counter() - start}")
        return top_10_songs.to_dict(orient="records")

    def get_embedding(self, image):
        """
        Embed one image with CLIP's image tower.

        Args:
            image: the PIL image to embed.

        Returns:
            The image's float64 numpy embedding, scaled to unit L2 norm.
        """
        # `padding` is a text-tokenizer option; it has no effect on images.
        inputs = self.processor(images=image, return_tensors="pt")
        # inference only: don't build an autograd graph
        with torch.no_grad():
            output = self.model.get_image_features(**inputs)
        # first (and only) item of the batch.
        # np.float was removed in NumPy 1.24; float64 is what it used to mean.
        np_embedding = output[0].cpu().numpy().astype(np.float64)
        np_embedding /= np.linalg.norm(np_embedding)
        return np_embedding

    def get_top_songs(self, image_embedding, k=10):
        """
        Select the songs whose lyric embeddings score highest against an image.

        Args:
            image_embedding: embedding vector, as returned by ``get_embedding``.
            k: number of songs to return (default 10). ``k <= 0`` returns none.

        Returns:
            A ``DataFrame`` holding the ``k`` best rows of ``song_info``, or
            every row if fewer exist. Rows are ordered from lowest to highest
            score, so the best match is the LAST row.
        """
        # The dot product equals cosine similarity when both sides are
        # unit-norm. The image side is normalized in get_embedding; the
        # lyric side presumably is too (file name says "normal"), so verify.
        scores = self.lyric_embeddings @ image_embedding
        order = np.argsort(scores)  # ascending by score
        # Clamp at 0 so that k=0 selects nothing; a bare order[-0:] would
        # return every song.
        top_k = order[max(order.size - k, 0):]
        # NOTE(review): callers wanting best-first order should reverse this.
        return self.song_info.iloc[top_k]