shivrajkarewar committed on
Commit
0b7abd6
·
verified ·
1 Parent(s): 175d3a8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import (
3
+ AutoModelForCausalLM,
4
+ AutoTokenizer,
5
+ TextIteratorStreamer
6
+ )
7
+ from threading import Thread
8
+
9
# --- Configuration ---------------------------------------------------------
# Hub identifier of the checkpoint to serve.
MODEL_NAME = "deepseek-ai/DeepSeek-R1"  # Verify exact model ID on Hugging Face Hub
# Default generation budget; also used as the initial slider value in the UI.
DEFAULT_MAX_NEW_TOKENS = 512

# --- Load model and tokenizer ----------------------------------------------
# Loading happens once at import time. Any failure is re-raised as a
# ``gr.Error`` carrying the original message, so the app does not start with
# a half-initialised model.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype="auto",
        # load_in_4bit=True  # Uncomment for 4-bit quantization
    )
except Exception as load_error:
    raise gr.Error(f"Error loading model: {str(load_error)}")
24
+
25
def generate_text(prompt, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, temperature=0.7, top_p=0.9):
    """Stream a sampled completion for *prompt*.

    Generation runs in a background thread while this generator yields the
    accumulated response text each time the streamer produces a new chunk,
    so a UI bound to it can update progressively.

    Args:
        prompt: Text to feed to the model.
        max_new_tokens: Upper bound on the number of generated tokens. UI
            sliders may deliver this as a float, so it is coerced to ``int``.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Yields:
        The response generated so far (prompt excluded, special tokens
        stripped), growing with every chunk.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Streamer for real-time output. ``skip_special_tokens`` is forwarded to
    # the tokenizer's decode call so markers such as the end-of-sequence
    # token do not show up in the text presented to the user.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        # ``generate`` expects an integer token budget.
        max_new_tokens=int(max_new_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Run generation in a worker thread so this generator can consume the
    # streamer concurrently.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated text after every streamed chunk.
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

    # The streamer is exhausted, so generation has finished; reap the worker.
    thread.join()
49
+
50
# Gradio interface: a prompt/response pair, a collapsible panel with the
# sampling controls, and a button that streams ``generate_text`` into the
# response box.
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1 Demo")

    # Prompt on the left, streamed response on the right.
    with gr.Row():
        input_text = gr.Textbox(
            lines=5,
            label="Input Prompt",
            placeholder="Enter your prompt here...",
        )
        output_text = gr.Textbox(
            lines=10,
            label="Generated Response",
            interactive=False,
        )

    # Sampling parameters, collapsed by default.
    with gr.Accordion("Advanced Settings", open=False):
        max_tokens = gr.Slider(
            label="Max New Tokens",
            minimum=64,
            maximum=2048,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.5,
            value=0.7,
        )
        top_p = gr.Slider(
            label="Top-p",
            minimum=0.1,
            maximum=1.0,
            value=0.9,
        )

    submit_btn = gr.Button("Generate")
    # Input order must match the positional parameters of ``generate_text``.
    submit_btn.click(
        fn=generate_text,
        inputs=[input_text, max_tokens, temperature, top_p],
        outputs=output_text,
        api_name="generate",
    )
93
+
94
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860.
    queued_demo = demo.queue()
    queued_demo.launch(server_name="0.0.0.0", server_port=7860)