# import os
# os.system("pip install flash-attn --no-build-isolation")

import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers import StopStringCriteria, StoppingCriteriaList

from datasets import load_dataset, concatenate_datasets
import torch
import threading

model_id = "PhysicsWallahAI/Aryabhata-1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)#, torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation="flash_attention_2")

def process_questions(example):
    # Build "question_text" by appending the multiple-choice options, labelled A, B, C, ...
    options = "\n".join([f"{chr(65 + e)}. {op}" for e, op in enumerate(example["options"])])
    example["question_text"] = example["question"] + "\n" + options
    return example

# Example questions for the UI: the January and April sessions of JEE Main 2025 (Math).
dataset = concatenate_datasets([
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "jan", split="test"),
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "apr", split="test"),
])
examples = list(dataset.map(process_questions, remove_columns=dataset.column_names)["question_text"])
print(examples[0])


# Strings at which generation should stop; any trailing occurrence is also stripped from the displayed output.
stop_strings = ["<|im_end|>", "<|end|>", "<im_start|>", "```python\n", "<|im_start|>", "]}}]}}]", " <im_start>"]


def strip_bad_tokens(s, stop_strings):
    # Drop a single trailing stop string so it never shows up in the displayed answer,
    # e.g. strip_bad_tokens("x = 2<|im_end|>", stop_strings) returns "x = 2".
    for suffix in stop_strings:
        if s.endswith(suffix):
            return s[:-len(suffix)]
    return s

def generate_answer_stream(question):
    # Chat-style prompt: a fixed system instruction plus the user's question.
    messages = [
        {'role': 'system', 'content': 'Think step-by-step; put only the final answer inside \\boxed{}.'},
        {'role': 'user', 'content': question}
    ]

    # Render the conversation into a single prompt string using the model's chat template.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer([text], return_tensors="pt")  # add .to("cuda") when the model is on GPU

    # Stream decoded tokens as they are produced and stop as soon as any stop string appears.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stopping = StoppingCriteriaList([StopStringCriteria(tokenizer, stop_strings)])

    
    # Run generation in a background thread so that tokens can be yielded while it is still running.
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=stopping,
        )
    )
    thread.start()
    
    output = ""
    for token in streamer:
        print(token)
        output += token
        output = strip_bad_tokens(output, stop_strings)
        yield output
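
# Illustrative only, not executed by the app: the streaming generator can be consumed
# directly for a quick smoke test outside Gradio (the question below is a made-up example).
# for partial in generate_answer_stream("If x + 2 = 5, find x."):
#     answer = partial  # holds everything generated so far; the last value is the full answer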

demo = gr.Interface(
    fn=generate_answer_stream,
    inputs=gr.Textbox(lines=4, label="Enter a Math Question"),
    outputs=gr.Textbox(label="Model's Response"),
    examples=examples,
    title="Aryabhata 1.0",
    description="We are disabling GPUs on this space, we will hosting the model on a separate space soon",
)
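# fn is a generator, so Gradio streams each yielded partial answer into the output textbox;
# on older Gradio versions you may need demo.queue() before launch() for streaming to work.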

if __name__ == "__main__":
    demo.launch()