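# Gradio demo Space for PhysicsWallahAI/Aryabhata-1.0: streams step-by-step
# solutions to JEE Main 2025 math questions loaded from the Hugging Face Hub.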
import os

# Spaces-style runtime install: flash-attn has no prebuilt wheel on every image,
# so it is installed at startup, before the model is loaded.
os.system("pip install flash-attn --no-build-isolation")

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers import StopStringCriteria, StoppingCriteriaList
from datasets import load_dataset, concatenate_datasets
import torch
import threading
model_id = "PhysicsWallahAI/Aryabhata-1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="flash_attention_2",
)
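# Note (assumption, not part of the original app): flash_attention_2 needs a
# compatible CUDA GPU plus the flash-attn wheel installed above; if either is
# missing, attn_implementation="sdpa" is a reasonable drop-in fallback:
# model = AutoModelForCausalLM.from_pretrained(
#     model_id, torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation="sdpa"
# )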
def process_questions(example):
    # Append the lettered answer options (A., B., ...) to the question text.
    example["question_text"] = example["question"]
    options = "\n".join([f"{chr(65 + e)}. {op}" for e, op in enumerate(example["options"])])
    example["question_text"] += "\n" + options
    return example
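# For example (hypothetical record, illustrating the mapping above):
# {"question": "What is 2 + 2?", "options": ["3", "4"]}
# -> question_text == "What is 2 + 2?\nA. 3\nB. 4"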
# Both JEE Main 2025 sessions (January and April) are used as clickable examples.
dataset = concatenate_datasets([
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "jan", split="test"),
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "apr", split="test"),
])
examples = list(dataset.map(process_questions, remove_columns=dataset.column_names)["question_text"])
print(examples[0])  # sanity check in the Space logs
# Stop strings: legitimate end-of-turn markers plus malformed variants
# (e.g. "<im_start|>", " <im_start>") that the model occasionally emits.
stop_strings = ["<|im_end|>", "<|end|>", "<im_start|>", "```python\n", "<|im_start|>", "]}}]}}]", " <im_start>"]
def strip_bad_tokens(s, stop_strings):
    # Remove a single trailing stop string, if present, from the accumulated output.
    for suffix in stop_strings:
        if s.endswith(suffix):
            return s[:-len(suffix)]
    return s
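# e.g. strip_bad_tokens("x = 4<|im_end|>", stop_strings) -> "x = 4".
# Only one trailing suffix is stripped per call; since it is re-applied on every
# streamed chunk, a stop string split across chunks may briefly flash on screen.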
def generate_answer_stream(question):
    messages = [
        {'role': 'system', 'content': 'Think step-by-step; put only the final answer inside \\boxed{}.'},
        {'role': 'user', 'content': question},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stopping = StoppingCriteriaList([StopStringCriteria(tokenizer, stop_strings)])
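    # generate() blocks until decoding finishes, so it runs in a background thread
    # while TextIteratorStreamer hands decoded chunks back to this generator.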
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=stopping,
        ),
    )
    thread.start()
| output = "" | |
| for token in streamer: | |
| print(token) | |
| output += token | |
| output = strip_bad_tokens(output, stop_strings) | |
| yield output | |
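# gr.Interface runs a generator fn as a streaming endpoint: each yielded string
# replaces the contents of the output textbox.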
demo = gr.Interface(
    fn=generate_answer_stream,
    inputs=gr.Textbox(lines=4, label="Enter a Math Question"),
    outputs=gr.Textbox(label="Model's Response"),
    examples=examples,
    title="Aryabhata 1.0",
)
if __name__ == "__main__":
    demo.launch()
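# Note (assumption about the deployed Gradio version): on Gradio 3.x, streaming
# a generator requires demo.queue().launch(); Gradio 4.x queues by default.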