pw-ai-research commited on
Commit
c835d4d
·
verified ·
1 Parent(s): 02b0589

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import transformers
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
4
+ from transformers import StopStringCriteria, StoppingCriteriaList
5
+
6
+ from datasets import load_dataset, concatenate_datasets
7
+ import torch
8
+ import threading
9
+
# Checkpoint to serve. device_map="auto" lets accelerate place the
# weights on whatever devices are available (GPU if present, else CPU).
model_id = "PhysicsWallahAI/Aryabhatta-1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
def process_questions(example):
    """Render one multiple-choice record as plain question text.

    Adds a ``question_text`` field to *example*: the question statement
    followed by its options labelled ``A.``, ``B.``, ``C.``, ...
    Mutates and returns the same dict (the ``datasets.map`` convention).
    """
    labelled = [
        f"{chr(ord('A') + idx)}. {choice}"
        for idx, choice in enumerate(example["options"])
    ]
    example["question_text"] = example["question"] + "\n" + "\n".join(labelled)
    return example
# Example questions for the Gradio UI: the JEE Main 2025 math benchmark
# (January and April sessions), rendered to plain text with options.
dataset = concatenate_datasets(
    [
        load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "jan", split="test"),
        load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "apr", split="test"),
    ]
)
examples = dataset.map(process_questions, remove_columns=dataset.column_names)["question_text"]

# Strings that end generation and are scrubbed from the streamed output
# (chat-template end markers, a code-fence the model sometimes emits, and
# one observed garbage tail).
stop_strings = ["<|im_end|>", "<|end|>", "<im_start|>", "```python\n", "<|im_start|>", "]}}]}}]"]
32
+
def strip_bad_tokens(s, stop_strings):
    """Remove trailing stop strings from *s*.

    The original version removed at most one suffix per call, so text
    ending in several stacked stop strings (e.g. "...<|end|><|im_end|>")
    kept all but the last one. Strip repeatedly until no listed suffix
    remains at the end of the string.

    Args:
        s: Text to clean (typically the partially streamed model output).
        stop_strings: Suffixes to remove; empty entries are ignored so the
            loop cannot spin forever.

    Returns:
        *s* with every trailing stop string removed.
    """
    stripped = True
    while stripped:
        stripped = False
        for suffix in stop_strings:
            if suffix and s.endswith(suffix):
                s = s[: -len(suffix)]
                stripped = True
    return s
38
+
def generate_answer_stream(question):
    """Stream the model's answer to *question*, token by token.

    Runs ``model.generate`` on a background thread with a
    ``TextIteratorStreamer`` and yields the progressively growing answer
    text (with trailing stop strings scrubbed) so Gradio renders it live.

    Args:
        question: The math question typed (or selected) by the user.

    Yields:
        The accumulated answer text after each new token.
    """
    messages = [
        {'role': 'system', 'content': 'Think step-by-step; put only the final answer inside \\boxed{}.'},
        {'role': 'user', 'content': question},
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Move the encoded prompt to the model's device: with device_map="auto"
    # the weights may sit on GPU while the tokenizer returns CPU tensors,
    # which would make generate() fail with a device mismatch.
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stopping = StoppingCriteriaList([StopStringCriteria(tokenizer, stop_strings)])

    # generate() blocks, so run it on a worker thread and consume the
    # streamer from this generator.
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=stopping,
        ),
    )
    thread.start()

    output = ""
    for token in streamer:
        output += token
        output = strip_bad_tokens(output, stop_strings)
        yield output

    # The streamer is exhausted only once generation ends; join so the
    # worker thread is not left dangling.
    thread.join()
73
+
# Gradio UI: one text box in, streaming text out, with the benchmark
# questions offered as clickable examples.
demo = gr.Interface(
    fn=generate_answer_stream,
    inputs=gr.Textbox(lines=4, label="Enter a Math Question"),
    outputs=gr.Textbox(label="Model's Response", lines=10),
    examples=examples,
    title="Aryabhatta 1.0 Demo",
    description="",
)

if __name__ == "__main__":
    demo.launch()