pw-ai-research committed
Commit fdedf62 · verified · 1 Parent(s): 901a060

Update app.py

Files changed (1)
  1. app.py +7 -45
app.py CHANGED
@@ -6,11 +6,12 @@ from transformers import StopStringCriteria, StoppingCriteriaList
 
 from datasets import load_dataset, concatenate_datasets
 import torch
- import threading
+ from vllm import LLM, SamplingParams
+ 
+ llm = LLM(model="PhysicsWallahAI/Aryabhata-1.0")
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=4*1024, stop=["<|im_end|>", "<|end|>", "<im_start|>", "```python\n", "<|im_start|>", "]}}]}}]"])
+ 
 
- model_id = "PhysicsWallahAI/Aryabhata-1.0"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
 
 def process_questions(example):
     example["question_text"] = example["question"]
@@ -26,53 +27,14 @@ dataset = concatenate_datasets([
 examples = dataset.map(process_questions, remove_columns=dataset.column_names)["question_text"]
 
 
- 
- # add options
- 
- stop_strings = ["<|im_end|>", "<|end|>", "<im_start|>", "```python\n", "<|im_start|>", "]}}]}}]"]
- 
- 
- def strip_bad_tokens(s, stop_strings):
-     for suffix in stop_strings:
-         if s.endswith(suffix):
-             return s[:-len(suffix)]
-     return s
- 
 def generate_answer_stream(question):
     messages = [
         {'role': 'system', 'content': 'Think step-by-step; put only the final answer inside \\boxed{}.'},
         {'role': 'user', 'content': question}
     ]
 
-     text = tokenizer.apply_chat_template(
-         messages,
-         tokenize=False,
-         add_generation_prompt=True
-     )
- 
-     inputs = tokenizer([text], return_tensors="pt")
- 
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-     stopping = StoppingCriteriaList([StopStringCriteria(tokenizer, stop_strings)])
- 
- 
-     thread = threading.Thread(
-         target=model.generate,
-         kwargs=dict(
-             **inputs,
-             streamer=streamer,
-             max_new_tokens=4096,
-             stopping_criteria=stopping,
-         )
-     )
-     thread.start()
- 
-     output = ""
-     for token in streamer:
-         print(token)
-         output += token
-         output = strip_bad_tokens(output, stop_strings)
-         yield output
+     results = llm.chat(messages, sampling_params)
+     return results[0].outputs[0].text.strip()
 
 demo = gr.Interface(
     fn=generate_answer_stream,
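
The updated path drops the threaded TextIteratorStreamer loop entirely: llm.chat(...) applies the model's chat template and blocks until generation finishes, so generate_answer_stream now returns one complete string instead of yielding partial outputs. vLLM also excludes a matched stop string from the returned text by default (include_stop_str_in_output=False), which is why the old strip_bad_tokens suffix cleanup is gone. A minimal standalone sketch of the new path, reusing the model ID, system prompt, and stop strings from the diff; the sample question is illustrative, not part of the commit:

from vllm import LLM, SamplingParams

# Model and decoding settings as in the commit above.
llm = LLM(model="PhysicsWallahAI/Aryabhata-1.0")
sampling_params = SamplingParams(
    temperature=0.0,  # greedy decoding
    max_tokens=4 * 1024,
    stop=["<|im_end|>", "<|end|>", "<im_start|>", "```python\n", "<|im_start|>", "]}}]}}]"],
)

# llm.chat applies the tokenizer's chat template, appends the generation
# prompt, and runs one blocking generation pass.
messages = [
    {'role': 'system', 'content': 'Think step-by-step; put only the final answer inside \\boxed{}.'},
    {'role': 'user', 'content': 'If 2x + 3 = 11, what is x?'},  # illustrative question
]
results = llm.chat(messages, sampling_params)
print(results[0].outputs[0].text.strip())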
 
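
The diff is truncated before the remaining arguments of gr.Interface. For orientation only, a sketch of how the now non-streaming function could be wired up; the inputs, outputs, and examples arguments below are assumptions, not shown in the commit:

import gradio as gr

# Hypothetical completion of the truncated gr.Interface call; component
# choices here are illustrative, not part of this commit.
demo = gr.Interface(
    fn=generate_answer_stream,
    inputs=gr.Textbox(label="Question"),
    outputs=gr.Textbox(label="Answer"),
    examples=[[q] for q in examples],  # question texts built earlier in app.py
)

if __name__ == "__main__":
    demo.launch()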