pw-ai-research committed on
Commit db1f4a2 · verified · 1 Parent(s): fdedf62

Update app.py

Files changed (1)
app.py +45 -8
app.py CHANGED
@@ -1,4 +1,3 @@
-# import spaces
 import gradio as gr
 import transformers
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
@@ -6,12 +5,11 @@ from transformers import StopStringCriteria, StoppingCriteriaList
 
 from datasets import load_dataset, concatenate_datasets
 import torch
-from vllm import LLM, SamplingParams
-
-llm = LLM(model="PhysicsWallahAI/Aryabhata-1.0")
-sampling_params = SamplingParams(temperature=0.0, max_tokens=4*1024, stop=["<|im_end|>", "<|end|>", "<im_start|>", "```python\n", "<|im_start|>", "]}}]}}]"])
-
+import threading
 
+model_id = "PhysicsWallahAI/Aryabhata-1.0"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
 
 def process_questions(example):
     example["question_text"] = example["question"]
@@ -27,14 +25,53 @@ dataset = concatenate_datasets([
 examples = dataset.map(process_questions, remove_columns=dataset.column_names)["question_text"]
 
 
+
+# add options
+
+stop_strings = ["<|im_end|>", "<|end|>", "<im_start|>", "```python\n", "<|im_start|>", "]}}]}}]"]
+
+
+def strip_bad_tokens(s, stop_strings):
+    for suffix in stop_strings:
+        if s.endswith(suffix):
+            return s[:-len(suffix)]
+    return s
+
 def generate_answer_stream(question):
     messages = [
         {'role': 'system', 'content': 'Think step-by-step; put only the final answer inside \\boxed{}.'},
         {'role': 'user', 'content': question}
     ]
 
-    results = llm.chat(messages, sampling_params)
-    return results[0].outputs[0].text.strip()
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+
+    inputs = tokenizer([text], return_tensors="pt")
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    stopping = StoppingCriteriaList([StopStringCriteria(tokenizer, stop_strings)])
+
+
+    thread = threading.Thread(
+        target=model.generate,
+        kwargs=dict(
+            **inputs,
+            streamer=streamer,
+            max_new_tokens=4096,
+            stopping_criteria=stopping,
+        )
+    )
+    thread.start()
+
+    output = ""
+    for token in streamer:
+        print(token)
+        output += token
+        output = strip_bad_tokens(output, stop_strings)
+        yield output
 
 demo = gr.Interface(
     fn=generate_answer_stream,
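
Note on the new generation path: this commit swaps vLLM's offline `LLM.chat` call, which only returns a finished answer, for `transformers` generation that streams tokens to Gradio as they are produced. Below is a minimal, self-contained sketch of that pattern; the model id is a placeholder (any small chat-tuned causal LM works, the Space itself loads PhysicsWallahAI/Aryabhata-1.0), and the structure mirrors the thread-plus-`TextIteratorStreamer` code in `app.py`.

```python
# Minimal sketch of the thread + TextIteratorStreamer pattern, assuming a
# placeholder chat model; the Space loads PhysicsWallahAI/Aryabhata-1.0 instead.
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder chat-tuned causal LM
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

# Build the prompt with the model's chat template, as app.py does.
messages = [{"role": "user", "content": "What is 12 * 12?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

# generate() blocks until completion, so it runs in a worker thread while the
# caller iterates over the streamer, which yields decoded text piece by piece.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=256),
)
thread.start()

answer = ""
for piece in streamer:
    answer += piece  # app.py yields each partial string to gr.Interface here
print(answer)
thread.join()
```

Two details the commit handles around this pattern: `StopStringCriteria` only halts generation after a stop string has already been produced (and streamed), so `strip_bad_tokens` trims one trailing stop string from each partial output before it is yielded; and because `generate_answer_stream` is now a generator, `gr.Interface` treats it as a streaming handler and re-renders the output box on every `yield`.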