# Aryabhata-Demo / app.py
# Hugging Face Space source file, uploaded by pw-ai-research
# (commit c835d4d, verified: "Create app.py"; 2.51 kB).
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers import StopStringCriteria, StoppingCriteriaList
from datasets import load_dataset, concatenate_datasets
import torch
import threading
# Hub identifier of the checkpoint served by this demo.
model_id = "PhysicsWallahAI/Aryabhatta-1.0"
# Tokenizer and model are both loaded from the same checkpoint, once, when the
# module is imported.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" leaves the placement of the weights to the library.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
def process_questions(example):
    """Add a ``question_text`` field combining the question and its options.

    Each option is rendered on its own line with a letter label
    (``A. ...``, ``B. ...``) and appended below the question.

    Args:
        example: Mapping with a ``"question"`` string and an ``"options"``
            sequence. ``"options"`` may be empty or ``None``.

    Returns:
        The same ``example`` object, mutated in place with the new
        ``"question_text"`` key.
    """
    question_text = example["question"]
    # Treat a missing value as "no options": without this, None would raise a
    # TypeError in enumerate() and an empty list would leave a dangling
    # newline after the question.
    options = example["options"] or []
    if options:
        labelled = "\n".join(
            f"{chr(ord('A') + index)}. {option}"
            for index, option in enumerate(options)
        )
        question_text += "\n" + labelled
    example["question_text"] = question_text
    return example
# Merge the "jan" and "apr" configurations of the JEE Main 2025 math test set
# into a single dataset.
dataset = concatenate_datasets([
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "jan", split="test"),
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "apr", split="test"),
])
# Keep only the formatted question text. It is materialised with list()
# because the result is handed to gr.Interface(examples=...), which expects a
# real list, while indexing a column may return a lazy sequence object in
# newer `datasets` releases.
examples = list(
    dataset.map(process_questions, remove_columns=dataset.column_names)["question_text"]
)
# Answer options are appended to each question's text by process_questions().
# Strings that terminate generation (passed to StopStringCriteria) and that
# strip_bad_tokens() trims from the end of the streamed text.
# NOTE(review): "<im_start|>" lacks the "|" after "<" — possibly deliberate to
# catch malformed output, possibly a typo of "<|im_start|>"; confirm.
stop_strings = ["<|im_end|>", "<|end|>", "<im_start|>", "```python\n", "<|im_start|>", "]}}]}}]"]
def strip_bad_tokens(s, stop_strings):
    """Return ``s`` with one trailing stop string removed, if present.

    The stop strings are tried in order and only the first one that matches
    the end of ``s`` is removed.

    Args:
        s: Text to clean.
        stop_strings: Iterable of suffixes to look for.

    Returns:
        ``s`` without the matched suffix, or ``s`` unchanged when none match.
    """
    for suffix in stop_strings:
        # An empty suffix matches every string, and s[:-0] is "" — skipping it
        # avoids wiping the whole text.
        if suffix and s.endswith(suffix):
            return s[:-len(suffix)]
    return s
def generate_answer_stream(question):
    """Stream the model's answer to ``question``.

    Generation runs on a background thread while this generator consumes the
    streamer, so text can be shown as it is produced.

    Args:
        question: The user's question text.

    Yields:
        The answer text accumulated so far, with any trailing stop string
        removed.
    """
    messages = [
        {'role': 'system', 'content': 'Think step-by-step; put only the final answer inside \\boxed{}.'},
        {'role': 'user', 'content': question}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # The model is loaded with device_map="auto", so it may not be on the CPU;
    # the input tensors have to be on the same device as the model.
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stopping = StoppingCriteriaList([StopStringCriteria(tokenizer, stop_strings)])
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=stopping,
        )
    )
    thread.start()
    output = ""
    for token in streamer:
        output += token
        # Trim a stop string that has just been emitted so it is not shown.
        output = strip_bad_tokens(output, stop_strings)
        yield output
    # The streamer is exhausted only once generation has ended, so this join
    # returns promptly and simply reaps the worker thread.
    thread.join()
# Gradio front end: one text box for the question, one for the streamed
# answer, with the formatted dataset questions offered as examples.
_question_box = gr.Textbox(lines=4, label="Enter a Math Question")
_response_box = gr.Textbox(label="Model's Response", lines=10)

demo = gr.Interface(
    fn=generate_answer_stream,
    inputs=_question_box,
    outputs=_response_box,
    examples=examples,
    title="Aryabhatta 1.0 Demo",
    description="",
)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()