import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from threading import Thread

# Configuration
MODEL_NAME = "deepseek-ai/DeepSeek-R1"  # Verify the exact model ID on the Hugging Face Hub
DEFAULT_MAX_NEW_TOKENS = 512
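# Note: the full DeepSeek-R1 is a ~671B-parameter MoE checkpoint and will not
# fit on typical Space hardware; a distilled variant such as
# "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" is a more realistic default here.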

# Load model and tokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype="auto",
        # load_in_4bit=True  # Uncomment for 4-bit quantization (requires bitsandbytes)
    )
except Exception as e:
    # gr.Error is only rendered inside event handlers; at import time, fail
    # fast with a plain exception instead.
    raise RuntimeError(f"Error loading model: {e}") from e
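
# Optional: a minimal 4-bit loading sketch, assuming `bitsandbytes` is
# installed and a CUDA GPU is available. It uses transformers'
# BitsAndBytesConfig; the compute dtype below is a reasonable guess, not a
# DeepSeek recommendation.
#
# from transformers import BitsAndBytesConfig
# quant_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype="float16",
# )
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_NAME, device_map="auto", quantization_config=quant_config
# )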

def generate_text(prompt, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, temperature=0.7, top_p=0.9):
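    # Sketch: DeepSeek-R1 is a chat/reasoning model, so wrapping the raw
    # prompt with the tokenizer's chat template (if the checkpoint ships one)
    # usually works better than feeding plain text:
    #
    # messages = [{"role": "user", "content": prompt}]
    # prompt = tokenizer.apply_chat_template(
    #     messages, tokenize=False, add_generation_prompt=True
    # )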
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Stream tokens as they are generated; skip_special_tokens keeps control
    # tokens out of the visible output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # avoids the missing-pad-token warning
    )

    # Run generate() in a background thread so this generator can consume the
    # streamer concurrently
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated text so the UI updates incrementally
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1 Demo")
    with gr.Row():
        input_text = gr.Textbox(
            label="Input Prompt",
            placeholder="Enter your prompt here...",
            lines=5,
        )
        output_text = gr.Textbox(
            label="Generated Response",
            interactive=False,
            lines=10,
        )
    with gr.Accordion("Advanced Settings", open=False):
        max_tokens = gr.Slider(
            minimum=64,
            maximum=2048,
            value=DEFAULT_MAX_NEW_TOKENS,
            label="Max New Tokens",
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.5,
            value=0.7,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            label="Top-p",
        )
    submit_btn = gr.Button("Generate")
    submit_btn.click(
        fn=generate_text,
        inputs=[input_text, max_tokens, temperature, top_p],
        outputs=output_text,
        api_name="generate",
    )

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
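
# Hedged usage sketch (not part of the app): with the Space running, the
# endpoint exposed via api_name="generate" can be called programmatically
# through gradio_client. The URL below is an assumption; substitute your
# Space's address.
#
# from gradio_client import Client
# client = Client("http://localhost:7860")
# result = client.predict(
#     "Explain quantum entanglement simply.",  # input_text
#     512,   # max_tokens
#     0.7,   # temperature
#     0.9,   # top_p
#     api_name="/generate",
# )
# print(result)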