Daemontatox committed
Commit c701791 · verified · 1 Parent(s): c8e2710

Update app.py

Files changed (1)
  1. app.py +81 -171
app.py CHANGED
@@ -1,6 +1,3 @@
-import os
-import re
-import time
 import torch
 import spaces
 import gradio as gr
@@ -14,44 +11,26 @@ from transformers import (
     StoppingCriteriaList
 )
 
-# Configuration Constants
 MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
 
-# Enhanced System Prompt
 DEFAULT_SYSTEM_PROMPT = """You are an Expert Reasoning Assistant. Follow these steps:
 [Understand]: Analyze key elements and clarify objectives
 [Plan]: Outline step-by-step methodology
 [Reason]: Execute plan with detailed analysis
 [Verify]: Check logic and evidence
-[Conclude]: Present structured conclusion
+[Conclude]: Present structured conclusion"""
 
-Use these section headers and maintain technical accuracy with clear explanations."""
-
-# UI Configuration
-TITLE = """
-<h1 align="center" style="color: #2d3436; margin-bottom: 0">🧠 AI Reasoning Assistant</h1>
-<p align="center" style="color: #636e72; margin-top: 0">DeepSeek-R1-Distill-Qwen-14B</p>
-"""
 CSS = """
-.gr-chatbot { min-height: 500px !important; border-radius: 15px !important; }
-.message-wrap pre { background: #f8f9fa !important; padding: 15px !important; }
-.thinking-tag { color: #2ecc71; font-weight: 600; }
-.plan-tag { color: #e67e22; font-weight: 600; }
-.conclude-tag { color: #3498db; font-weight: 600; }
-.control-panel { background: #f8f9fa !important; padding: 20px !important; }
-footer { visibility: hidden !important; }
+.gr-chatbot { min-height: 500px; border-radius: 15px; }
+.special-tag { color: #2ecc71; font-weight: 600; }
+footer { display: none !important; }
 """
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [0]  # Add custom stop tokens here
-        return input_ids[0][-1] in stop_ids
+        return input_ids[0][-1] == tokenizer.eos_token_id
 
 def initialize_model():
-    """Initialize model with safety checks"""
-    if not torch.cuda.is_available():
-        raise RuntimeError("CUDA is required for this application")
-
     quantization_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_compute_dtype=torch.bfloat16,
@@ -73,150 +52,81 @@ def initialize_model():
     return model, tokenizer
 
 def format_response(text):
-    """Enhanced formatting with syntax highlighting for reasoning steps"""
-    formatted = text.replace("[Understand]", '\n<strong class="thinking-tag">[Understand]</strong>\n')
-    formatted = formatted.replace("[Plan]", '\n<strong class="plan-tag">[Plan]</strong>\n')
-    formatted = formatted.replace("[Conclude]", '\n<strong class="conclude-tag">[Conclude]</strong>\n')
-    return formatted
-
-@spaces.GPU(duration=120)
-def chat_response(
-    message: str,
-    history: list,
-    system_prompt: str,
-    temperature: float = 0.3,
-    max_new_tokens: int = 2048,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    penalty: float = 1.2,
-):
-    """Improved streaming generator with error handling"""
-    try:
-        conversation = [{"role": "system", "content": system_prompt}]
-        for user, assistant in history:
-            conversation.extend([
-                {"role": "user", "content": user},
-                {"role": "assistant", "content": assistant}
-            ])
-        conversation.append({"role": "user", "content": message})
-
-        input_ids = tokenizer.apply_chat_template(
-            conversation,
-            add_generation_prompt=True,
-            return_tensors="pt"
-        ).to(model.device)
-
-        streamer = TextIteratorStreamer(
-            tokenizer,
-            timeout=30,
-            skip_prompt=True,
-            skip_special_tokens=True
-        )
-
-        generate_kwargs = dict(
-            input_ids=input_ids,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=penalty,
-            streamer=streamer,
-            stopping_criteria=StoppingCriteriaList([StopOnTokens()])
-        )
-
-        buffer = []
-        thread = Thread(target=model.generate, kwargs=generate_kwargs)
-        thread.start()
-
-        for new_text in streamer:
-            buffer.append(new_text)
-            partial_result = "".join(buffer)
-
-            # Check for complete sections
-            if any(tag in partial_result for tag in ["[Understand]", "[Plan]", "[Conclude]"]):
-                yield format_response(partial_result)
-            else:
-                yield format_response(partial_result + " ")
-
-        # Final formatting pass
-        yield format_response("".join(buffer))
-
-    except Exception as e:
-        yield f"⚠️ Error generating response: {str(e)}"
-
-def create_examples():
-    """Enhanced examples with diverse use cases"""
-    return [
-        ["Explain quantum entanglement in simple terms"],
-        ["Design a study plan for learning machine learning"],
-        ["Compare blockchain and traditional databases"],
-        ["How would you optimize AWS costs for a startup?"],
-        ["Explain the ethical implications of CRISPR technology"]
-    ]
-
-def main():
-    """Improved UI layout and interactions"""
-    global model, tokenizer
-    model, tokenizer = initialize_model()
-
-    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
-        gr.HTML(TITLE)
-
-        with gr.Row():
-            with gr.Column(scale=3):
-                chatbot = gr.Chatbot(
-                    elem_id="chatbot",
-                    bubble_full_width=False,
-                    show_copy_button=True,
-                    render=False
-                )
-                msg = gr.Textbox(
-                    placeholder="Enter your question...",
-                    label="Ask the Expert",
-                    container=False
-                )
-                with gr.Row():
-                    submit_btn = gr.Button("Send", variant="primary")
-                    clear_btn = gr.Button("Clear", variant="secondary")
-
-            with gr.Column(scale=1, elem_classes="control-panel"):
-                gr.Examples(
-                    examples=create_examples(),
-                    inputs=msg,
-                    label="Example Queries",
-                    examples_per_page=5
-                )
-
-                with gr.Accordion("⚙️ Generation Parameters", open=False):
-                    system_prompt = gr.TextArea(
-                        value=DEFAULT_SYSTEM_PROMPT,
-                        label="System Instructions",
-                        lines=5
-                    )
-                    temperature = gr.Slider(0, 2, value=0.7, label="Creativity")
-                    max_tokens = gr.Slider(128, 4096, value=2048, step=128, label="Max Tokens")
-                    top_p = gr.Slider(0, 1, value=0.9, step=0.05, label="Focus (Top-p)")
-                    penalty = gr.Slider(1, 2, value=1.2, step=0.1, label="Repetition Control")
-
-        # Event handling
-        msg.submit(
-            chat_response,
-            [msg, chatbot, system_prompt, temperature, max_tokens, top_p, penalty],
-            [msg, chatbot],
-            show_progress="hidden"
-        ).then(lambda: "", None, msg)
-
-        submit_btn.click(
-            chat_response,
-            [msg, chatbot, system_prompt, temperature, max_tokens, top_p, penalty],
-            [msg, chatbot],
-            show_progress="hidden"
-        ).then(lambda: "", None, msg)
-
-        clear_btn.click(lambda: None, None, chatbot, queue=False)
-
-    return demo
+    return text.replace("[Understand]", '\n<strong class="special-tag">[Understand]</strong>\n') \
+               .replace("[Plan]", '\n<strong class="special-tag">[Plan]</strong>\n') \
+               .replace("[Conclude]", '\n<strong class="special-tag">[Conclude]</strong>\n')
+
+@spaces.GPU
+def generate_response(message, chat_history, system_prompt, temperature, max_tokens):
+    # Create conversation history for model
+    conversation = [{"role": "system", "content": system_prompt}]
+    for user_msg, bot_msg in chat_history:
+        conversation.extend([
+            {"role": "user", "content": user_msg},
+            {"role": "assistant", "content": bot_msg}
+        ])
+    conversation.append({"role": "user", "content": message})
+
+    # Tokenize input
+    input_ids = tokenizer.apply_chat_template(
+        conversation,
+        add_generation_prompt=True,
+        return_tensors="pt"
+    ).to(model.device)
+
+    # Setup streaming
+    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        max_new_tokens=max_tokens,
+        temperature=temperature,
+        stopping_criteria=StoppingCriteriaList([StopOnTokens()])
+    )
+
+    # Start generation thread
+    Thread(target=model.generate, kwargs=generate_kwargs).start()
+
+    # Initialize response buffer
+    partial_message = ""
+    new_history = chat_history + [(message, "")]
+
+    # Stream response
+    for new_token in streamer:
+        partial_message += new_token
+        formatted = format_response(partial_message)
+        new_history[-1] = (message, formatted + "▌")
+        yield new_history
+
+    # Final update without cursor
+    new_history[-1] = (message, format_response(partial_message))
+    yield new_history
+
+model, tokenizer = initialize_model()
+
+with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
+    gr.Markdown("""
+    <h1 align="center">🧠 AI Reasoning Assistant</h1>
+    <p align="center">DeepSeek-R1-Distill-Qwen-14B</p>
+    """)
+
+    chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
+    msg = gr.Textbox(label="Your Question", placeholder="Type your question...")
+
+    with gr.Accordion("⚙️ Settings", open=False):
+        system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
+        temperature = gr.Slider(0, 1, value=0.7, label="Creativity")
+        max_tokens = gr.Slider(128, 4096, value=2048, label="Max Response Length")
+
+    clear = gr.Button("Clear History")
+
+    msg.submit(
+        generate_response,
+        [msg, chatbot, system_prompt, temperature, max_tokens],
+        [chatbot],
+        show_progress="hidden"
+    )
+    clear.click(lambda: None, None, chatbot, queue=False)
 
 if __name__ == "__main__":
-    demo = main()
-    demo.queue(max_size=20).launch()
+    demo.queue().launch()
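
For reference, a minimal self-contained sketch of the TextIteratorStreamer pattern the rewritten generate_response() relies on. The stand-in model "distilgpt2" and the prompt string are illustrative choices so the sketch runs without a GPU; the Space itself loads DeepSeek-R1-Distill-Qwen-14B in 4-bit. One caveat worth noting: the committed code constructs its streamer without skip_prompt=True (the old chat_response passed it), so its stream begins with the echoed prompt text; the sketch keeps the flag for comparison.

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# "distilgpt2" is a small placeholder model so this sketch runs on CPU.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

# skip_prompt=True keeps the echoed prompt out of the stream; the committed
# generate_response() omits it, so its stream starts with the prompt text.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# model.generate() blocks until generation finishes, so it runs on a worker
# thread while this thread consumes decoded text chunks as they arrive.
Thread(target=model.generate,
       kwargs=dict(input_ids=input_ids, max_new_tokens=30, streamer=streamer)).start()

for chunk in streamer:
    print(chunk, end="", flush=True)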