# Aryabhata-Demo / app.py
# import spaces
import gradio as gr
from datasets import load_dataset, concatenate_datasets
from vllm import LLM, SamplingParams
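
# Load the model once at process start-up; vLLM then serves every request
# from this single in-memory instance.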
llm = LLM(model="PhysicsWallahAI/Aryabhata-1.0")
# Greedy decoding (temperature=0) with up to 4k generated tokens. Generation
# stops on the chat turn markers, on the start of a fenced Python block, and
# on a degenerate bracket pattern.
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=4 * 1024,
    stop=["<|im_end|>", "<|end|>", "<|im_start|>", "```python\n", "]}}]}}]"],
)
def process_questions(example):
    """Render one multiple-choice question as a single prompt string."""
    example["question_text"] = example["question"]
    # Label the options A., B., C., ... and append them below the stem.
    options = "\n".join([f"{chr(65+e)}. {op}" for e, op in enumerate(example["options"])])
    example["question_text"] += "\n" + options
    # Gradio examples are lists of input values (one entry per input component).
    example["question_text"] = [example["question_text"]]
    return example
# Pull the January and April sessions of the JEE Main 2025 maths set; the
# formatted questions can back the (currently disabled) UI examples below.
dataset = concatenate_datasets([
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "jan", split="test"),
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "apr", split="test"),
])
examples = dataset.map(process_questions, remove_columns=dataset.column_names)["question_text"]
def generate_answer(question):
    """Run one chat turn through vLLM and return the full answer text."""
    messages = [
        {'role': 'system', 'content': 'Think step-by-step; put only the final answer inside \\boxed{}.'},
        {'role': 'user', 'content': question},
    ]
    # llm.chat applies the model's chat template, generates, and returns one
    # RequestOutput per conversation.
    results = llm.chat(messages, sampling_params)
    return results[0].outputs[0].text.strip()
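

# --- Hedged sketch: streaming variant --------------------------------------
# The original file imported transformers' streaming utilities, so here is a
# minimal sketch of what a token-streaming handler could look like with that
# stack. It is NOT wired into the UI below; Gradio streams output when `fn`
# is a generator, so using it would mean `fn=generate_answer_stream` (plus
# `demo.queue()` on older Gradio versions). Parameters are assumptions.
from threading import Thread

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteriaList,
    StopStringCriteria,
    TextIteratorStreamer,
)


def generate_answer_stream(question):
    """Yield progressively longer answer prefixes for live display."""
    # Loaded per call only to keep the sketch self-contained; a real version
    # would load the tokenizer and model once at start-up.
    tokenizer = AutoTokenizer.from_pretrained("PhysicsWallahAI/Aryabhata-1.0")
    model = AutoModelForCausalLM.from_pretrained(
        "PhysicsWallahAI/Aryabhata-1.0", torch_dtype="auto", device_map="auto"
    )
    messages = [
        {'role': 'system', 'content': 'Think step-by-step; put only the final answer inside \\boxed{}.'},
        {'role': 'user', 'content': question},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stopping = StoppingCriteriaList(
        [StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "<|end|>"])]
    )
    # generate() blocks until finished, so run it on a worker thread and
    # drain the streamer from this generator.
    Thread(
        target=model.generate,
        kwargs=dict(**inputs, max_new_tokens=4 * 1024, do_sample=False,
                    streamer=streamer, stopping_criteria=stopping),
    ).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text.strip()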
demo = gr.Interface(
    fn=generate_answer,
    inputs=gr.Textbox(lines=4, label="Enter a Math Question"),
    outputs=gr.Textbox(label="Model's Response", lines=10),
    # examples=examples,  # uncomment to surface the JEE questions as clickable examples
    title="Aryabhata 1.0",
    description="",
)
if __name__ == "__main__":
    demo.launch()