Commit 411e7e6 (parent: 0856ae9)
update summary generation for new models

src/backend/model_operations.py  CHANGED  (+58 -22)
--- a/src/backend/model_operations.py
+++ b/src/backend/model_operations.py
@@ -215,16 +215,34 @@ class SummaryGenerator:
                     {"role": "user", "content": user_prompt}] if 'gpt' in self.model_id
                     else [{"role": "user", "content": system_prompt + '\n' + user_prompt}],
                 temperature=0.0 if 'gpt' in self.model_id.lower() else 1.0, # fixed at 1 for o1 models
-                max_completion_tokens=250 if 'gpt' in self.model_id.lower() else None, # not compatible with o1 series models
+                # max_completion_tokens=250 if 'gpt' in self.model_id.lower() else None, # not compatible with o1 series models
             )
             # print(response)
             result = response.choices[0].message.content
             print(result)
             return result

+        elif 'grok' in self.model_id.lower(): # xai
+            XAI_API_KEY = os.getenv("XAI_API_KEY")
+            client = OpenAI(
+                api_key=XAI_API_KEY,
+                base_url="https://api.x.ai/v1",
+            )
+
+            completion = client.chat.completions.create(
+                model=self.model_id.split('/')[-1],
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt},
+                ],
+                temperature=0.0
+            )
+            result = completion.choices[0].message.content
+            print(result)
+            return result
+
         elif 'gemini' in self.model_id.lower():
             vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
-            gemini_model_id_map = {'gemini-1.5-pro-exp-0827':'gemini-pro-experimental', 'gemini-1.5-flash-exp-0827': 'gemini-flash-experimental'}
             model = GenerativeModel(
                 self.model_id.lower().split('google/')[-1],
                 system_instruction = [system_prompt]
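For reference, a standalone sketch of the xAI call this hunk adds, assuming XAI_API_KEY is set in the environment and that the endpoint stays OpenAI-compatible as the diff assumes; the model name and prompts below are placeholders (the class derives the model name from self.model_id):

import os
from openai import OpenAI

client = OpenAI(api_key=os.getenv("XAI_API_KEY"), base_url="https://api.x.ai/v1")
completion = client.chat.completions.create(
    model="grok-beta",  # placeholder model name, not taken from the diff
    messages=[
        {"role": "system", "content": "You are a concise summarizer."},
        {"role": "user", "content": "Summarize: the quick brown fox jumps over the lazy dog."},
    ],
    temperature=0.0,
)
print(completion.choices[0].message.content)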
@@ -289,21 +307,23 @@ class SummaryGenerator:
             return response

         elif 'claude' in self.model_id.lower(): # using anthropic api
+            print('using Anthropic API')
             client = anthropic.Anthropic()
             message = client.messages.create(
                 model=self.model_id.split('/')[-1],
-                max_tokens=
+                max_tokens=1024,
                 temperature=0,
                 system=system_prompt,
                 messages=[
                     {
                         "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": user_prompt
-                            }
-                        ]
+                        # "content": [
+                        #     {
+                        #         "type": "text",
+                        #         "text": user_prompt
+                        #     }
+                        # ]
+                        "content": user_prompt
                     }
                 ]
             )
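For reference, a minimal sketch of the simplified Anthropic call, assuming ANTHROPIC_API_KEY is set. The Messages API accepts "content" either as a plain string or as a list of typed blocks; this commit switches to the string form and pins max_tokens=1024. The model id below is a placeholder:

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
message = client.messages.create(
    model="claude-3-5-sonnet-20241022",  # placeholder; the class uses self.model_id.split('/')[-1]
    max_tokens=1024,
    temperature=0,
    system="You are a concise summarizer.",
    messages=[{"role": "user", "content": "Summarize: the quick brown fox jumps over the lazy dog."}],
)
print(message.content[0].text)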
@@ -311,15 +331,17 @@ class SummaryGenerator:
             print(result)
             return result

-        elif 'command-r' in self.model_id.lower():
-            co = cohere.
+        elif 'command-r' in self.model_id.lower() or 'aya-expanse' in self.model_id.lower():
+            co = cohere.ClientV2(os.getenv('COHERE_API_TOKEN'))
             response = co.chat(
-
-
+                model=self.model_id.split('/')[-1],
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt}
                 ],
-
+                temperature=0,
             )
-            result = response.text
+            result = response.message.content[0].text
             print(result)
             return result
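For reference, a minimal sketch of the Cohere v2 chat call this hunk switches to, assuming COHERE_API_TOKEN is set. Unlike the old call, which read response.text, ClientV2 takes role-based messages and the reply text lives under response.message.content[0].text, which is how the diff reads it. The model id below is a placeholder:

import os
import cohere

co = cohere.ClientV2(os.getenv("COHERE_API_TOKEN"))
response = co.chat(
    model="command-r-plus",  # placeholder; the class uses self.model_id.split('/')[-1]
    messages=[
        {"role": "system", "content": "You are a concise summarizer."},
        {"role": "user", "content": "Summarize: the quick brown fox jumps over the lazy dog."},
    ],
    temperature=0,
)
print(response.message.content[0].text)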
@@ -375,7 +397,10 @@ class SummaryGenerator:
                     trust_remote_code=True
                 )
             else:
-
+                if 'ragamuffin' in self.model_id.lower():
+                    self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('/home/miaoran', self.model_id))
+                else:
+                    self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
             print("Tokenizer loaded")
             if 'jamba' in self.model_id.lower():
                 self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id,
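For reference, a small sketch of the tokenizer fallback kept in this hunk: OpenELM checkpoints do not ship their own tokenizer, so the Llama-2 tokenizer is loaded instead (a gated repo, so access is assumed). The 'ragamuffin' branch points at a private local path and is omitted here; the checkpoint name below is just an example:

from transformers import AutoTokenizer

model_id = "apple/OpenELM-270M"  # example checkpoint; SummaryGenerator uses self.model_id
tokenizer_id = "meta-llama/Llama-2-7b-hf" if "openelm" in model_id.lower() else model_id
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
print("Tokenizer loaded:", tokenizer_id)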
@@ -390,8 +415,14 @@ class SummaryGenerator:
                 )
                 self.processor = AutoProcessor.from_pretrained(self.model_id)

+            # elif 'ragamuffin' in self.model_id.lower():
+            #     print('Using ragamuffin')
+            #     self.local_model = AutoModelForCausalLM.from_pretrained(os.path.join('/home/miaoran', self.model_id),
+            #         torch_dtype=torch.bfloat16, # forcing bfloat16 for now
+            #         attn_implementation="flash_attention_2")
+
             else:
-                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto"
+                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto")#torch_dtype="auto"
                 # print(self.local_model.device)
             print("Local model loaded")
@@ -419,7 +450,7 @@ class SummaryGenerator:
                     # gemma-1.1, mistral-7b does not accept system role
                     {"role": "user", "content": system_prompt + '\n' + user_prompt}
                 ]
-                prompt = self.tokenizer.apply_chat_template(messages,add_generation_prompt=True, tokenize=False)
+                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

             elif 'phi-2' in self.model_id.lower():
                 prompt = system_prompt + '\n' + user_prompt
@@ -451,20 +482,25 @@ class SummaryGenerator:
             # print(prompt)
             # print('-'*50)
             input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-
-
-
+            if 'granite' in self.model_id.lower():
+                self.local_model.eval()
+                outputs = self.local_model.generate(**input_ids, max_new_tokens=250)
+            else:
+                with torch.no_grad():
+                    outputs = self.local_model.generate(**input_ids, do_sample=True, max_new_tokens=250, temperature=0.01)#, pad_token_id=self.tokenizer.eos_token_id
+            if 'glm' in self.model_id.lower() or 'ragamuffin' in self.model_id.lower() or 'granite' in self.model_id.lower():
                 outputs = outputs[:, input_ids['input_ids'].shape[1]:]
             elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
                 outputs = [
                     out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids.input_ids, outputs)
                 ]

-
             if 'qwen2-vl' in self.model_id.lower():
                 result = self.processor.batch_decode(
                     outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
                 )[0]
+            # elif 'granite' in self.model_id.lower():
+            #     result = self.tokenizer.batch_decode(outputs)[0]
             else:
                 result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
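For reference, a self-contained sketch of the generate-and-decode pattern in this hunk, using a tiny public model as a stand-in for the local model: generate under torch.no_grad(), then strip the prompt tokens before decoding, which is what the glm/ragamuffin/granite branch does with outputs[:, input_ids['input_ids'].shape[1]:].

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sshleifer/tiny-gpt2"  # stand-in model purely for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Summarize: the cat sat on the mat.", return_tensors="pt")
with torch.no_grad():
    # Mirrors the non-granite branch: sampling with a near-zero temperature.
    outputs = model.generate(**inputs, do_sample=True, max_new_tokens=20, temperature=0.01)

# Causal LMs echo the prompt, so slice it off before decoding (glm/ragamuffin/granite branch).
completion_ids = outputs[:, inputs["input_ids"].shape[1]:]
print(tokenizer.decode(completion_ids[0], skip_special_tokens=True))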