xiaowenbin commited on
Commit
aa65a99
·
verified ·
1 Parent(s): 50a041e

Upload mteb_eval_openai.py

Browse files
Files changed (1) hide show
  1. mteb_eval_openai.py +157 -0
mteb_eval_openai.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import hashlib
5
+ import numpy as np
6
+ import requests
7
+
8
+ import logging
9
+ import functools
10
+ import tiktoken
11
+ from mteb import MTEB
12
+ from sentence_transformers import SentenceTransformer
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger("main")
15
+
16
+ all_task_list = ['Classification', 'Clustering', 'Reranking', 'Retrieval', 'STS', 'PairClassification']
17
+ if len(sys.argv) > 1:
18
+ task_list = [t for t in sys.argv[1].split(',') if t in all_task_list]
19
+ else:
20
+ task_list = all_task_list
21
+
22
+ OPENAI_BASE_URL = os.environ.get('OPENAI_BASE_URL', '')
23
+ OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '')
24
+ EMB_CACHE_DIR = os.environ.get('EMB_CACHE_DIR', '.cache/embs')
25
+ os.makedirs(EMB_CACHE_DIR, exist_ok=True)
26
+
27
+ def uuid_for_text(text):
28
+ return hashlib.md5(text.encode('utf8')).hexdigest()
29
+
30
+ def count_openai_tokens(text, model="text-embedding-3-large"):
31
+ encoding = tiktoken.get_encoding("cl100k_base")
32
+ #encoding = tiktoken.encoding_for_model(model)
33
+ input_ids = encoding.encode(text)
34
+ return len(input_ids)
35
+
36
+ def request_openai_emb(texts, model="text-embedding-3-large",
37
+ base_url='https://api.openai.com', prefix_url='/v1/embeddings',
38
+ timeout=4, retry=3, interval=2, caching=True):
39
+ if isinstance(texts, str):
40
+ texts = [texts]
41
+
42
+ data = []
43
+ if caching:
44
+ for text in texts:
45
+ emb_file = f"{EMB_CACHE_DIR}/{uuid_for_text(text)}"
46
+ if os.path.isfile(emb_file) and os.path.getsize(emb_file) > 0:
47
+ data.append(np.loadtxt(emb_file))
48
+ if len(texts) == len(data):
49
+ return data
50
+
51
+ url = f"{OPENAI_BASE_URL}{prefix_url}" if OPENAI_BASE_URL else f"{base_url}{prefix_url}"
52
+ headers = {
53
+ "Authorization": f"Bearer {OPENAI_API_KEY}",
54
+ "Content-Type": "application/json"
55
+ }
56
+ payload = {"input": texts, "model": model}
57
+
58
+ while retry > 0 and len(data) == 0:
59
+ try:
60
+ r = requests.post(url, headers=headers, json=payload,
61
+ timeout=timeout)
62
+ res = r.json()
63
+ for x in res["data"]:
64
+ data.append(np.array(x["embedding"]))
65
+ except Exception as e:
66
+ print(f"request openai, retry {retry}, error: {e}", file=sys.stderr)
67
+ time.sleep(interval)
68
+ retry -= 1
69
+
70
+ if len(data) != len(texts):
71
+ data = []
72
+
73
+ if caching and len(data) > 0:
74
+ for text, emb in zip(texts, data):
75
+ emb_file = f"{EMB_CACHE_DIR}/{uuid_for_text(text)}"
76
+ np.savetxt(emb_file, emb)
77
+
78
+ return data
79
+
80
+
81
+ class OpenaiEmbModel:
82
+
83
+ def __init__(self, model_name, model_dim, *args, **kwargs):
84
+ super().__init__(*args, **kwargs)
85
+ self.model_name = model_name
86
+ self.model_dim = model_dim
87
+
88
+ def encode(self, sentences, batch_size=32, **kwargs):
89
+ i = 0
90
+ max_tokens = 8000
91
+ batch_tokens = 0
92
+ batch = []
93
+ batch_list = []
94
+ while i < len(sentences):
95
+ num_tokens = count_openai_tokens(sentences[i],
96
+ model=self.model_name)
97
+ if batch_tokens+num_tokens >= max_tokens:
98
+ if batch:
99
+ batch_list.append(batch)
100
+ batch = []
101
+ batch_tokens = 0
102
+ else:
103
+ batch_list.append([sentences[i][:2048]])
104
+ else:
105
+ batch.append(sentences[i])
106
+ batch_tokens += num_tokens
107
+ i += 1
108
+ if batch:
109
+ batch_list.append(batch)
110
+
111
+ #batch_size = min(64, batch_size)
112
+ #
113
+ #for i in range(0, len(sentences), batch_size):
114
+ # batch_texts = sentences[i:i+batch_size]
115
+ # batch_list.append(batch_texts)
116
+
117
+ embs = []
118
+ waiting = 0
119
+ for batch_texts in batch_list:
120
+ batch_embs = request_openai_emb(batch_texts, model=self.model_name,
121
+ caching=True, timeout=120, retry=3, interval=60)
122
+
123
+ #assert len(batch_texts) == len(batch_embs), "The batch of texts and embs DONT match!"
124
+
125
+ if len(batch_texts) == len(batch_embs):
126
+ embs.extend(batch_embs)
127
+ waiting = waiting // 2
128
+ else:
129
+ print(f"The batch of texts and embs DONT match! {len(batch_texts)}:{len(batch_embs)}", file=sys.stderr)
130
+ embs.extend(np.array([[0.0 for j in range(self.model_dim)] for i in range(len(batch_texts))]))
131
+ waiting = 120 if waiting <= 0 else waiting+120
132
+
133
+ if waiting > 3600:
134
+ print(f"Frequently failed, break down!", file=sys.stderr)
135
+ break
136
+ if waiting > 0:
137
+ time.sleep(waiting)
138
+
139
+ return embs
140
+
141
+
142
+ model_name = "text-embedding-3-large"
143
+ model_dim = 3072
144
+ model = OpenaiEmbModel(model_name, model_dim)
145
+
146
+ ######
147
+ # test
148
+ #####
149
+ #embs = model.encode(['全国', '北京'])
150
+ #print(embs)
151
+ #exit()
152
+
153
+ # languages
154
+ task_langs=["zh", "zh-CN"]
155
+
156
+ evaluation = MTEB(task_types=task_list, task_langs=task_langs)
157
+ evaluation.run(model, output_folder=f"results/zh/{model_name.split('/')[-1]}")