# import os
# os.system("pip install flash-attn --no-build-isolation")
import threading

import gradio as gr
import torch  # used by the commented-out GPU loading path below
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers import StopStringCriteria, StoppingCriteriaList
from datasets import load_dataset, concatenate_datasets

model_id = "PhysicsWallahAI/Aryabhata-1.0"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)  # , torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation="flash_attention_2"


def process_questions(example):
    """Append the multiple-choice options (A., B., ...) to the question text."""
    example["question_text"] = example["question"]
    options = "\n".join(f"{chr(65 + e)}. {op}" for e, op in enumerate(example["options"]))
    example["question_text"] += "\n" + options
    return example


# JEE Main 2025 Mathematics questions (January and April sessions) as demo examples.
dataset = concatenate_datasets([
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "jan", split="test"),
    load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "apr", split="test"),
])
examples = list(dataset.map(process_questions, remove_columns=dataset.column_names)["question_text"])
print(examples[0])  # sanity check: show one formatted question with its options

# Strings that should end generation: chat-template markers plus a degenerate
# repetition pattern the model occasionally emits.
stop_strings = ["<|im_end|>", "<|end|>", "```python\n", "<|im_start|>", "]}}]}}]"]


def strip_bad_tokens(s, stop_strings):
    """Remove a single trailing stop string from `s`, if present."""
    for suffix in stop_strings:
        if s.endswith(suffix):
            return s[: -len(suffix)]
    return s


def generate_answer_stream(question):
    messages = [
        {"role": "system", "content": "Think step-by-step; put only the final answer inside \\boxed{}."},
        {"role": "user", "content": question},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt")  # .to("cuda")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stopping = StoppingCriteriaList([StopStringCriteria(tokenizer, stop_strings)])

    # Run generation on a background thread so tokens can be streamed to the UI.
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=stopping,
        ),
    )
    thread.start()

    output = ""
    for token in streamer:
        print(token)  # debug: log each streamed chunk
        output += token
        output = strip_bad_tokens(output, stop_strings)
        yield output


demo = gr.Interface(
    fn=generate_answer_stream,
    inputs=gr.Textbox(lines=4, label="Enter a Math Question"),
    outputs=gr.Textbox(label="Model's Response"),
    examples=examples,
    title="Aryabhata 1.0",
    description="GPUs are disabled on this Space; we will be hosting the model on a separate Space soon.",
)

if __name__ == "__main__":
    demo.launch()
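

# ---------------------------------------------------------------------------
# Hypothetical helper, not part of the original app: the system prompt asks
# the model to put its final answer inside \boxed{...}, so a caller consuming
# this Space's output may want to pull that answer out of the streamed text.
# A minimal sketch (the name `extract_boxed_answer` and the brace-matching
# approach are assumptions, not part of the Aryabhata API):
# ---------------------------------------------------------------------------
def extract_boxed_answer(text):
    """Return the contents of the last \\boxed{...} in `text`, or None."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    depth = 0
    # Walk forward from the opening brace, tracking nesting so answers like
    # \boxed{\frac{1}{2}} are recovered whole.
    for i in range(start + len("\\boxed{") - 1, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[start + len("\\boxed{"):i]
    return None  # unbalanced braces: the answer was likely truncated


# Example: extract_boxed_answer(r"... so the answer is \boxed{\frac{1}{2}}.")
# returns r"\frac{1}{2}".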