# import os
# os.system("pip install flash-attn --no-build-isolation")
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers import StopStringCriteria, StoppingCriteriaList
from datasets import load_dataset, concatenate_datasets
import torch
import threading
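
# Load the Aryabhata 1.0 tokenizer and model. The commented-out kwargs are the
# GPU + flash-attention configuration for when a GPU is available.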
model_id = "PhysicsWallahAI/Aryabhata-1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)  # , torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation="flash_attention_2")


def process_questions(example):
    """Append the lettered answer options (A., B., ...) to the question text."""
    example["question_text"] = example["question"]
    options = "\n".join([f"{chr(65 + e)}. {op}" for e, op in enumerate(example["options"])])
    example["question_text"] += "\n" + options
    return example
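

# Build the demo's example questions from both JEE Main 2025 Math sessions
# (January and April papers).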
dataset = concatenate_datasets([
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "jan", split="test"),
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "apr", split="test"),
])
examples = list(dataset.map(process_questions, remove_columns=dataset.column_names)["question_text"])
print(examples[0])

# Strings that should end generation. The model occasionally emits chat-template
# markers (including malformed variants), so these also serve as suffixes to
# strip from the displayed output.
stop_strings = ["<|im_end|>", "<|end|>", "<im_start|>", "```python\n", "<|im_start|>", "]}}]}}]", " <im_start>"]


def strip_bad_tokens(s, stop_strings):
    """Remove a trailing stop string from the streamed text, if one is present."""
    for suffix in stop_strings:
        if s.endswith(suffix):
            return s[:-len(suffix)]
    return s
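

# Streaming callback for Gradio: builds the chat prompt, runs generation in a
# background thread, and yields the cleaned answer text incrementally.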
def generate_answer_stream(question):
    messages = [
        {'role': 'system', 'content': 'Think step-by-step; put only the final answer inside \\boxed{}.'},
        {'role': 'user', 'content': question},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt")  # .to("cuda")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stopping = StoppingCriteriaList([StopStringCriteria(tokenizer, stop_strings)])
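    # Run generate() in a background thread; TextIteratorStreamer lets this
    # function yield partial output to the UI as tokens arrive.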
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=stopping,
        ),
    )
    thread.start()
    output = ""
    for token in streamer:
        output += token
        # Trim any stop string that made it into the streamed text.
        output = strip_bad_tokens(output, stop_strings)
        yield output
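

# Wire everything into a simple Gradio interface, with the JEE questions as
# clickable examples.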
demo = gr.Interface(
    fn=generate_answer_stream,
    inputs=gr.Textbox(lines=4, label="Enter a Math Question"),
    outputs=gr.Textbox(label="Model's Response"),
    examples=examples,
    title="Aryabhata 1.0",
    description="GPUs are disabled on this Space; we will host the model on a separate Space soon.",
)

if __name__ == "__main__":
    demo.launch()