File size: 1,463 Bytes
4bc1e6a
 
 
 
 
 
 
 
 
485cae6
4bc1e6a
0ebbe50
 
4bc1e6a
 
 
 
 
 
 
 
485cae6
4bc1e6a
 
 
485cae6
 
 
 
 
4bc1e6a
485cae6
4bc1e6a
 
485cae6
 
 
 
 
4bc1e6a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os

from typing import Union

from tclogger import logger
from transformers import AutoModel
from numpy.linalg import norm

from configs.envs import ENVS
from configs.constants import AVAILABLE_MODELS

# Export Hugging Face settings to the environment only when they are
# configured. os.environ accepts str values only, so assigning an unset value
# (e.g. None) would raise TypeError at import time; guarding HF_TOKEN the same
# way as HF_ENDPOINT also avoids clobbering a token already set in the shell.
if ENVS["HF_ENDPOINT"]:
    os.environ["HF_ENDPOINT"] = ENVS["HF_ENDPOINT"]
if ENVS["HF_TOKEN"]:
    os.environ["HF_TOKEN"] = ENVS["HF_TOKEN"]


def cosine_similarity(a, b):
    """Return the cosine similarity between ``a`` and ``b``.

    Accepts single vectors or stacks of row vectors (NumPy arrays):

    - ``a`` (d,),   ``b`` (d,)   -> scalar
    - ``a`` (d,),   ``b`` (n, d) -> (n,)
    - ``a`` (m, d), ``b`` (n, d) -> (m, n) pairwise similarity matrix

    Each vector is L2-normalised along its last axis before the dot product.
    Normalising per row (rather than dividing by the norms of the whole
    arrays) keeps 2-D inputs correct; a whole-array norm of a matrix is its
    Frobenius norm, not the norm of each embedding.

    An all-zero vector yields ``nan`` entries (division by zero).
    """
    a_unit = a / norm(a, axis=-1, keepdims=True)
    b_unit = b / norm(b, axis=-1, keepdims=True)
    return a_unit @ b_unit.T


class JinaAIEmbedder:
    """Text embedder backed by a model listed in ``AVAILABLE_MODELS``.

    The model is loaded with ``transformers.AutoModel.from_pretrained`` using
    ``trust_remote_code=True``, so the model repository's own code (including
    the ``encode`` method used below) is executed.
    """

    def __init__(self, model_name: str = AVAILABLE_MODELS[0]):
        """Create the embedder and load the model immediately.

        Args:
            model_name: Hub name of the model. Names not in
                ``AVAILABLE_MODELS`` fall back to ``AVAILABLE_MODELS[0]``.
        """
        self.model_name = model_name
        self.load_model()

    @staticmethod
    def _resolve_model_name(model_name: str) -> str:
        """Return ``model_name`` if supported, else the default model name."""
        if model_name in AVAILABLE_MODELS:
            return model_name
        return AVAILABLE_MODELS[0]

    def check_model_name(self) -> bool:
        """Replace an unsupported ``self.model_name`` with the default.

        Returns:
            Always ``True``: an unsupported name is corrected, never rejected.
        """
        self.model_name = self._resolve_model_name(self.model_name)
        return True

    def load_model(self):
        """(Re)load ``self.model`` for the validated ``self.model_name``."""
        self.check_model_name()
        self.model = AutoModel.from_pretrained(
            self.model_name, trust_remote_code=True
        )

    def switch_model(self, model_name: str):
        """Switch to ``model_name``, loading it only if it differs from the current.

        The requested name is resolved (unsupported -> default) *before* the
        comparison, so asking for an unknown model while the default is already
        loaded does not trigger a redundant reload. If loading fails,
        ``self.model_name`` is restored so it keeps matching ``self.model``.
        """
        target = self._resolve_model_name(model_name)
        if target == self.model_name:
            return
        previous = self.model_name
        self.model_name = target
        try:
            self.load_model()
        except Exception:
            # self.model still holds the previous model; keep the name in sync.
            self.model_name = previous
            raise

    def encode(self, text: Union[str, list[str]]):
        """Embed a single string or a list of strings.

        A bare string is wrapped in a one-element list, so the model's
        ``encode`` always receives a list. The return value is whatever that
        method produces (presumably an array with one embedding per input
        text; TODO confirm against the model's remote code).
        """
        texts = [text] if isinstance(text, str) else text
        return self.model.encode(texts)


if __name__ == "__main__":
    # Manual smoke test: embed an English question and its Chinese translation.
    demo_embedder = JinaAIEmbedder()
    demo_texts = ["How is the weather today?", "今天天气怎么样?"]
    # encode() also accepts a single string instead of a list.
    demo_embeddings = demo_embedder.encode(demo_texts)
    logger.success(demo_embeddings)
    # To compare the pair, pass demo_embeddings[0] and demo_embeddings[1]
    # to cosine_similarity().