from typing import Dict, List, Any
from transformers import CLIPModel, CLIPProcessor
from PIL import Image, ImageOps
import base64
import io
import numpy as np
import pandas as pd
import time
import torch
import os


# convert a base64 string to a PIL image
def base64_to_img(base64_string):
    imgdata = base64.b64decode(base64_string)
    return Image.open(io.BytesIO(imgdata))


# resize the image to fit within 224x224 (CLIP's input size) and pad it to a
# square, preserving the aspect ratio
def pad_to_square(img, pad_color=0):
    return ImageOps.pad(img, (224, 224), color=pad_color)

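
# --- Illustrative sketch, not used by the handler at runtime. ---
# One plausible way a file like clip_lyric_normal_embeddings.npy could be built,
# assuming the lyric embeddings come from the same CLIP checkpoint's text encoder
# and are L2-normalized; the actual preprocessing script is not part of this file.
def build_lyric_embeddings(lyrics, model, processor):
    inputs = processor(text=lyrics, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
    embeddings = text_features.cpu().numpy().astype(np.float64)
    # normalize each row so the dot product in get_top_songs is a cosine similarity
    embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings
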

class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the model
        """
        # can we run the base model on a smaller CPU?
        hf_model_path = "openai/clip-vit-large-patch14"
        self.model = CLIPModel.from_pretrained(hf_model_path)
        self.processor = CLIPProcessor.from_pretrained(hf_model_path)

        # load the lyrics embeddings
        self.lyric_embeddings = np.load(
            os.path.join(path, "clip_lyric_normal_embeddings.npy"), allow_pickle=True
        )
        # load the song info
        self.song_info = pd.read_csv(os.path.join(path, "en_song_info.csv"))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): a base64-encoded image
        Return:
            A :obj:`list` of :obj:`dict`: the top-10 matching songs; will be
            serialized and returned by the endpoint
        """
        base64_str = data["inputs"]
        start = time.time()
        # decode the payload and resize/pad it to CLIP's 224x224 input
        image = pad_to_square(base64_to_img(base64_str))
        print(f"Image processing time: {time.time() - start:.3f}s")
        image_embedding = self.get_embedding(image)
        print(f"Elapsed after embedding: {time.time() - start:.3f}s")
        top_10_songs = self.get_top_songs(image_embedding)
        print(f"Total time: {time.time() - start:.3f}s")
        return top_10_songs.to_dict(orient="records")

    def get_embedding(self, image):
        """Embed the image with CLIP and L2-normalize the result."""
        inputs = self.processor(images=image, return_tensors="pt", padding=True)
        with torch.no_grad():
            output = self.model.get_image_features(**inputs)
        # take the single image's embedding, move it to CPU, and normalize it so
        # the dot product with the lyric embeddings is a cosine similarity
        image_embedding = output[0]
        np_embedding = image_embedding.cpu().numpy().astype(np.float64)
        np_embedding /= np.linalg.norm(np_embedding)
        return np_embedding

    def get_top_songs(self, image_embedding):
        """Return the 10 songs whose lyric embeddings best match the image."""
        # both sides are L2-normalized, so the dot product is the cosine similarity
        scores = self.lyric_embeddings @ image_embedding
        # indices of the 10 highest scores, best match first
        top_10 = np.argsort(scores)[-10:][::-1]
        top_10_songs = self.song_info.iloc[top_10]
        return top_10_songs
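

# --- Illustrative usage sketch, not part of the deployed handler. ---
# A minimal local smoke test, assuming the repository files
# (clip_lyric_normal_embeddings.npy, en_song_info.csv) live in the current
# directory and "query.jpg" is a hypothetical image to match against the lyrics.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    with open("query.jpg", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}
    for song in handler(payload):
        print(song)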