"""Gradio demo that streams text generation from DeepSeek-R1."""

import gradio as gr
from threading import Thread

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Configuration
MODEL_NAME = "deepseek-ai/DeepSeek-R1"
DEFAULT_MAX_NEW_TOKENS = 512

# Load model and tokenizer.
# SECURITY NOTE: trust_remote_code=True executes Python code shipped inside
# the model repository. Only keep it for model repos you explicitly trust.
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
        # load_in_4bit=True  # Uncomment for quantization
    )
except Exception as e:
    # gr.Error is intended for event handlers inside a *running* app; at
    # import time no UI exists yet, so fail with a plain RuntimeError and
    # keep the original cause chained for debugging.
    raise RuntimeError(f"Error loading model: {e}") from e


def generate_text(prompt, max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                  temperature=0.7, top_p=0.9):
    """Stream a completion for *prompt*, yielding the cumulative text.

    Args:
        prompt: User prompt string.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling probability mass.

    Yields:
        The generated text so far (it grows with each new token), which
        Gradio renders as a live-updating textbox.
    """
    # Guard: nothing to generate for an empty/whitespace prompt.
    if not prompt or not prompt.strip():
        yield ""
        return

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # skip_special_tokens keeps end-of-sequence / role markers out of the
    # user-visible output; skip_prompt drops the echoed input.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Run generation in a worker thread so this generator can consume the
    # streamer concurrently.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

    # The streamer is exhausted only after generate() finishes; join to
    # reap the worker instead of leaving it dangling.
    thread.join()


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1 Demo")
    with gr.Row():
        input_text = gr.Textbox(
            label="Input Prompt",
            placeholder="Enter your prompt here...",
            lines=5,
        )
        output_text = gr.Textbox(
            label="Generated Response",
            interactive=False,
            lines=10,
        )
    with gr.Accordion("Advanced Settings", open=False):
        max_tokens = gr.Slider(
            minimum=64,
            maximum=2048,
            value=DEFAULT_MAX_NEW_TOKENS,
            label="Max New Tokens",
        )
        temperature = gr.Slider(
            minimum=0.1, maximum=1.5, value=0.7, label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.9, label="Top-p"
        )
    submit_btn = gr.Button("Generate")
    submit_btn.click(
        fn=generate_text,
        inputs=[input_text, max_tokens, temperature, top_p],
        outputs=output_text,
        api_name="generate",
    )

if __name__ == "__main__":
    # queue() is required so the generator handler streams to the client.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)