# app.py — Gradio demo app for DeepSeek-R1 text generation (Hugging Face Space).
# NOTE(review): the original upload carried file-viewer residue here
# ("raw / history blame / 2.61 kB"); replaced with a comment so the file parses.
import gradio as gr
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer
)
from threading import Thread
# Configuration
MODEL_NAME = "deepseek-ai/DeepSeek-R1"  # Verify exact model ID on Hugging Face Hub
DEFAULT_MAX_NEW_TOKENS = 512

# Load model and tokenizer once at import time; a failure aborts the app
# with a user-visible Gradio error.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",   # spread layers across available devices
        torch_dtype="auto",  # take dtype from the checkpoint config
        # load_in_4bit=True  # Uncomment for 4-bit quantization
    )
except Exception as e:
    # Chain the original exception so the real cause survives in the traceback.
    raise gr.Error(f"Error loading model: {str(e)}") from e
def generate_text(prompt, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, temperature=0.7, top_p=0.9):
    """Stream model output for *prompt*, yielding the cumulative text so far.

    Args:
        prompt: Input text to continue.
        max_new_tokens: Upper bound on tokens to generate.
        temperature: Sampling temperature (> 0).
        top_p: Nucleus-sampling probability mass.

    Yields:
        str: the generated text accumulated so far, once per streamed chunk.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Streamer for real-time output. Skip the echoed prompt, and drop special
    # tokens (e.g. end-of-sequence markers) so they never reach the UI.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Run generation in a background thread so this generator can consume
    # the streamer concurrently.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate chunks in a list ("".join avoids quadratic str +=) and
    # yield the full text after each new chunk for progressive display.
    pieces = []
    for new_text in streamer:
        pieces.append(new_text)
        yield "".join(pieces)

    # Ensure the worker thread has fully finished before returning.
    thread.join()
# Gradio interface: prompt + response textboxes, collapsible sampling
# controls, and a button wired to the streaming generator.
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1 Demo")

    with gr.Row():
        input_text = gr.Textbox(
            label="Input Prompt",
            placeholder="Enter your prompt here...",
            lines=5,
        )
        output_text = gr.Textbox(
            label="Generated Response",
            interactive=False,
            lines=10,
        )

    # Sampling knobs, hidden behind a collapsed accordion by default.
    with gr.Accordion("Advanced Settings", open=False):
        max_tokens = gr.Slider(
            minimum=64,
            maximum=2048,
            value=DEFAULT_MAX_NEW_TOKENS,
            label="Max New Tokens",
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.5,
            value=0.7,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            label="Top-p",
        )

    submit_btn = gr.Button("Generate")

    # generate_text is a generator, so the output textbox updates live
    # as chunks arrive; also exposed over the API as "generate".
    submit_btn.click(
        fn=generate_text,
        inputs=[input_text, max_tokens, temperature, top_p],
        outputs=output_text,
        api_name="generate",
    )

if __name__ == "__main__":
    # queue() is required for streaming updates; bind on all interfaces, port 7860.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)