import torch
import spaces
import gradio as gr
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList
)

MODEL_ID = "Daemontatox/PathFinderAI-S1"
# MODEL_ID = "Daemontatox/Research_PathfinderAI"

# You can modify the default system instructions here.
DEFAULT_SYSTEM_PROMPT = """
Respond in the following format:
[reasoning]
[your reasoning]
[/reasoning]
[answer]
[your answer]
[/answer]
Put your final answer within $\\boxed{}$.
"""

CSS = """
.gr-chatbot { min-height: 500px; border-radius: 15px; }
.special-tag { color: #2ecc71; font-weight: 600; }
footer { display: none !important; }
"""


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop when the EOS token is generated (uses the globally initialized tokenizer).
        return input_ids[0][-1] == tokenizer.eos_token_id


def initialize_model():
    # Enable 4-bit quantization for faster inference and lower memory usage.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    # Note: calling .to("cuda") on a 4-bit quantized model raises an error;
    # device_map="cuda" already places the weights on the GPU.
    model.eval()  # evaluation mode disables dropout and speeds up inference

    return model, tokenizer


def format_response(text):
    # Insert line breaks around the key reasoning tags so the streamed output is easier to read.
    replacements = [
        ("[Understand]", "\n[Understand]\n"),
        ("[Reason]", "\n[Reason]\n"),
        ("[/Reason]", "\n[/Reason]\n"),
        ("[Answer]", "\n[Answer]\n"),
        ("[/Answer]", "\n[/Answer]\n"),
    ]
    for old, new in replacements:
        text = text.replace(old, new)
    return text


@spaces.GPU(duration=120)
def generate_response(message, chat_history, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty):
    # Build the conversation history.
    conversation = [{"role": "system", "content": system_prompt}]
    for user_msg, bot_msg in chat_history:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": bot_msg})
    conversation.append({"role": "user", "content": message})

    # Tokenize the conversation. (This assumes the tokenizer provides a chat template.)
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Set up the streamer to yield new tokens as they are generated;
    # skip_prompt=True keeps the prompt itself out of the streamed output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Prepare generation parameters, including the extra customization options.
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": int(max_tokens),
        "do_sample": True,  # required for temperature / top_p / top_k to take effect
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "stopping_criteria": StoppingCriteriaList([StopOnTokens()])
    }

    # Run the generation in a background thread, inside an inference-mode block for speed.
    def generate_inference():
        with torch.inference_mode():
            model.generate(**generate_kwargs)

    Thread(target=generate_inference, daemon=True).start()

    # Stream the output tokens.
    partial_message = ""
    new_history = chat_history + [(message, "")]
    for new_token in streamer:
        partial_message += new_token
        formatted = format_response(partial_message)
        new_history[-1] = (message, formatted + "▌")
        yield new_history

    # Final update without the cursor.
    new_history[-1] = (message, format_response(partial_message))
    yield new_history


# Initialize the model and tokenizer globally.
model, tokenizer = initialize_model()

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""

# 🧠 AI Reasoning Assistant
Ask me hard questions and see the reasoning unfold.

""") chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot") msg = gr.Textbox(label="Your Question", placeholder="Type your question...") with gr.Accordion("⚙️ Settings", open=False): system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions") temperature = gr.Slider(0, 1, value=0.6, label="Creativity (Temperature)") max_tokens = gr.Slider(128, 8192, 4096, label="Max Response Length") top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top P (Nucleus Sampling)") top_k = gr.Slider(0, 100, value=50, label="Top K") repetition_penalty = gr.Slider(0.5, 2.0, value=1.1, label="Repetition Penalty") clear = gr.Button("Clear History") # Link the input textbox with the generation function. msg.submit( generate_response, [msg, chatbot, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty], chatbot, show_progress=True ) clear.click(lambda: None, None, chatbot, queue=False) if __name__ == "__main__": demo.queue().launch()