# DeepSeek-R1 Gradio demo (snapshot of a Hugging Face Space, rev 0b7abd6).
import gradio as gr
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer
)
from threading import Thread
# --- Configuration ---
MODEL_NAME = "deepseek-ai/DeepSeek-R1"  # Verify exact model ID on Hugging Face Hub
DEFAULT_MAX_NEW_TOKENS = 512

# --- Load model and tokenizer once at startup ---
# device_map="auto" shards the weights across available devices;
# torch_dtype="auto" uses the precision stored in the checkpoint.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype="auto",
        # load_in_4bit=True  # Uncomment for 4-bit quantization
    )
except Exception as e:
    # Chain the original exception (`from e`) so the real failure cause
    # is preserved in the traceback instead of being hidden behind gr.Error.
    raise gr.Error(f"Error loading model: {str(e)}") from e
def generate_text(prompt, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, temperature=0.7, top_p=0.9):
    """Stream a completion for *prompt*, yielding the cumulative text.

    Args:
        prompt: User prompt to complete.
        max_new_tokens: Cap on generated tokens. Gradio sliders deliver
            floats, so the value is coerced to int before generation.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        str: All text generated so far (grows with each streamed chunk),
        which lets Gradio update the output box incrementally.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Streamer for real-time output. skip_special_tokens keeps markers such
    # as end-of-sequence tokens out of the user-visible text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),  # slider values arrive as float
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Run generation in a background thread so this generator can consume
    # the streamer concurrently.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate and yield after each new chunk.
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

    # The streamer is exhausted only when generation ends; join to make sure
    # the worker thread is fully cleaned up before returning.
    thread.join()
# --- Gradio interface ---
# Declarative UI: a prompt box, a read-only output box, sampling controls in
# a collapsed accordion, and a button wired to generate_text.
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1 Demo")
    with gr.Row():
        # Prompt input (left) and streamed model output (right).
        input_text = gr.Textbox(
            label="Input Prompt",
            placeholder="Enter your prompt here...",
            lines=5
        )
        output_text = gr.Textbox(
            label="Generated Response",
            interactive=False,  # display-only; users cannot edit the output
            lines=10
        )
    with gr.Accordion("Advanced Settings", open=False):
        # These three sliders map positionally onto generate_text's
        # max_new_tokens, temperature and top_p parameters.
        max_tokens = gr.Slider(
            minimum=64,
            maximum=2048,
            value=DEFAULT_MAX_NEW_TOKENS,
            label="Max New Tokens"
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.5,
            value=0.7,
            label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            label="Top-p"
        )
    submit_btn = gr.Button("Generate")
    # generate_text is a generator, so Gradio streams each yielded value into
    # output_text as it arrives; api_name exposes the handler as an API route.
    submit_btn.click(
        fn=generate_text,
        inputs=[input_text, max_tokens, temperature, top_p],
        outputs=output_text,
        api_name="generate"
    )

# Entry point: queue() enables request queuing/streaming; bind to all
# interfaces on port 7860 (the conventional Hugging Face Spaces port).
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)