Abhaykoul committed on
Commit ac24bf9 · verified · 1 Parent(s): bb4af56

Create app.py

Files changed (1)
1. app.py +448 -0
app.py ADDED
@@ -0,0 +1,448 @@
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
import re

# Model configuration
model_name = "HelpingAI/Dhanishtha-2.0-preview"

# Global variables for model and tokenizer
model = None
tokenizer = None

def load_model():
    """Load the model and tokenizer"""
    global model, tokenizer

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Ensure pad token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

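    # Note: torch_dtype="auto" keeps the checkpoint's native precision, and
    # device_map="auto" relies on the `accelerate` package to place layers on the
    # available device(s).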
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )

    print("Model loaded successfully!")

def format_thinking_text(text):
    """Format text to properly display <think> tags in Gradio with better styling"""
    if not text:
        return text

    # Replace <think> and </think> tags with styled markdown
    formatted_text = text

    # Handle thinking blocks with proper markdown formatting
    thinking_pattern = r'<think>(.*?)</think>'

    def replace_thinking_block(match):
        thinking_content = match.group(1).strip()
        return f"\n\n💭 **Thinking Process:**\n\n```\n{thinking_content}\n```\n\n"

    formatted_text = re.sub(thinking_pattern, replace_thinking_block, formatted_text, flags=re.DOTALL)

    # Clean up any remaining raw tags that might not have been caught
    formatted_text = re.sub(r'</?think>', '', formatted_text)

    return formatted_text.strip()

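# Illustrative example (not part of the original file): for the input
#   "<think>re-check the arithmetic</think>The answer is 36."
# format_thinking_text() returns the thinking content under a "💭 **Thinking Process:**"
# heading inside a fenced code block, followed by "The answer is 36."

# @spaces.GPU() asks Hugging Face ZeroGPU Spaces to allocate a GPU for each call to
# the decorated function.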
@spaces.GPU()
def generate_response(message, history, max_tokens, temperature, top_p):
    """Generate a streaming response without threading"""
    global model, tokenizer

    if model is None or tokenizer is None:
        yield "Model is still loading. Please wait..."
        return

    # Prepare conversation history
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize input
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    try:
        with torch.no_grad():
            # Stream tokens with a custom sampling loop
            generated_text = ""
            current_input_ids = model_inputs["input_ids"]
            current_attention_mask = model_inputs["attention_mask"]

            for _ in range(max_tokens):
                # Generate next token
                outputs = model(
                    input_ids=current_input_ids,
                    attention_mask=current_attention_mask,
                    use_cache=True
                )
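                # Note: outputs.past_key_values is not reused on the next iteration, so
                # every step re-encodes the full prefix. Feeding the cache back in would
                # make streaming faster; the simpler form is kept here.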

                # Get logits for the last token
                logits = outputs.logits[0, -1, :]

                # Apply temperature
                if temperature != 1.0:
                    logits = logits / temperature

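                # The block below is nucleus (top-p) sampling: sort the logits, keep the
                # smallest set of tokens whose cumulative probability exceeds top_p, and
                # mask out the rest. The shift-by-one keeps the first token that crosses
                # the threshold, so at least one candidate always survives.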
                # Apply top-p sampling
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                    sorted_indices_to_remove[0] = 0
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    logits[indices_to_remove] = float('-inf')

                # Sample next token
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Check for EOS token
                if next_token.item() == tokenizer.eos_token_id:
                    break

                # Decode the new token (preserve special tokens like <think>)
                new_token_text = tokenizer.decode(next_token, skip_special_tokens=False)
                generated_text += new_token_text

                # Format and yield the current text
                formatted_text = format_thinking_text(generated_text)
                yield formatted_text

                # Update inputs for next iteration (keep the attention mask's dtype consistent)
                current_input_ids = torch.cat([current_input_ids, next_token.unsqueeze(0)], dim=-1)
                current_attention_mask = torch.cat(
                    [current_attention_mask,
                     torch.ones((1, 1), device=model.device, dtype=current_attention_mask.dtype)],
                    dim=-1
                )

    except Exception as e:
        yield f"Error generating response: {str(e)}"
        return

    # Final yield with complete formatted text
    final_text = format_thinking_text(generated_text) if generated_text else "No response generated."
    yield final_text

def chat_interface(message, history, max_tokens, temperature, top_p):
    """Main chat interface with improved streaming"""
    if not message.strip():
        # This function is a generator, so yield (rather than return) the unchanged state
        yield history, ""
        return

    # Add user message to history
    history.append([message, ""])

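    # history[:-1] excludes the pair appended above so the current message is not sent
    # to the model twice; it is passed separately as `message`.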
    # Generate response with streaming
    for partial_response in generate_response(message, history[:-1], max_tokens, temperature, top_p):
        history[-1][1] = partial_response
        yield history, ""

    return history, ""

# Load model on startup
print("Initializing model...")
load_model()

# Custom CSS for better styling and thinking blocks
custom_css = """
/* Main chatbot styling */
.chatbot {
    font-size: 14px;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}

/* Thinking block styling */
.thinking-block {
    background: linear-gradient(135deg, #f0f8ff 0%, #e6f3ff 100%);
    border-left: 4px solid #4a90e2;
    border-radius: 8px;
    padding: 12px 16px;
    margin: 12px 0;
    font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    position: relative;
}

.thinking-block::before {
    content: "🤔";
    position: absolute;
    top: -8px;
    left: 12px;
    background: white;
    padding: 0 4px;
    font-size: 16px;
}

/* Message styling */
.message {
    padding: 10px 14px;
    margin: 6px 0;
    border-radius: 12px;
    line-height: 1.5;
}

.user-message {
    background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%);
    margin-left: 15%;
    border-bottom-right-radius: 4px;
}

.assistant-message {
    background: linear-gradient(135deg, #f5f5f5 0%, #eeeeee 100%);
    margin-right: 15%;
    border-bottom-left-radius: 4px;
}

/* Code block styling */
pre {
    background-color: #f8f9fa;
    border: 1px solid #e9ecef;
    border-radius: 6px;
    padding: 12px;
    overflow-x: auto;
    font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
    font-size: 13px;
    line-height: 1.4;
}

/* Button styling */
.gradio-button {
    border-radius: 8px;
    font-weight: 500;
    transition: all 0.2s ease;
}

.gradio-button:hover {
    transform: translateY(-1px);
    box-shadow: 0 4px 8px rgba(0,0,0,0.15);
}

/* Input styling */
.gradio-textbox {
    border-radius: 8px;
    border: 2px solid #e0e0e0;
    transition: border-color 0.2s ease;
}

.gradio-textbox:focus {
    border-color: #4a90e2;
    box-shadow: 0 0 0 3px rgba(74, 144, 226, 0.1);
}

/* Slider styling */
.gradio-slider {
    margin: 8px 0;
}

/* Examples styling */
.gradio-examples {
    margin-top: 16px;
}

.gradio-examples .gradio-button {
    background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
    border: 1px solid #dee2e6;
    color: #495057;
    font-size: 13px;
    padding: 8px 12px;
}

.gradio-examples .gradio-button:hover {
    background: linear-gradient(135deg, #e9ecef 0%, #dee2e6 100%);
    color: #212529;
}
"""

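# Note: these selectors are styling hooks; classes such as .thinking-block,
# .user-message and .assistant-message are not attached to any element by the code in
# this file, so those rules only take effect if the rendered HTML happens to use them.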
# Create Gradio interface
with gr.Blocks(
    title="🤖 Dhanishtha-2.0-preview Chat",
    theme=gr.themes.Soft(),
    css=custom_css
) as demo:
    gr.Markdown(
        """
        # 🤖 Dhanishtha-2.0-preview Chat

        Chat with the **HelpingAI/Dhanishtha-2.0-preview** model - the world's first LLM designed to think between responses!

        ### ✨ Key Features:
        - 🧠 **Multi-step Reasoning**: Unlike other LLMs that think once, Dhanishtha can think, rethink, self-evaluate, and refine using multiple `<think>` blocks
        - 🔄 **Iterative Thinking**: Watch the model's thought process unfold in real time
        - 💡 **Enhanced Problem Solving**: Better reasoning capabilities through structured thinking

        **Note**: The `<think>` blocks show the model's internal reasoning process and will be displayed in a formatted way below.
        """
    )

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                bubble_full_width=False,
                height=600,
                show_copy_button=True,
                show_share_button=True,
                avatar_images=("👤", "🤖"),
                render_markdown=True,
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False}
                ]
            )

            with gr.Row():
                msg = gr.Textbox(
                    container=False,
                    placeholder="Ask me anything! I'll show you my thinking process...",
                    label="Message",
                    autofocus=True,
                    scale=8,
                    lines=1,
                    max_lines=5
                )
                send_btn = gr.Button("🚀 Send", variant="primary", scale=1, size="lg")

        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### ⚙️ Generation Parameters")

            max_tokens = gr.Slider(
                minimum=50,
                maximum=8192,
                value=2048,
                step=50,
                label="🎯 Max Tokens",
                info="Maximum number of tokens to generate"
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="🌡️ Temperature",
                info="Higher = more creative, Lower = more focused"
            )

            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="🎲 Top-p",
                info="Nucleus sampling threshold"
            )

            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", scale=1)
                stop_btn = gr.Button("⏹️ Stop", variant="stop", scale=1)

            gr.Markdown("### 📊 Model Info")
            gr.Markdown(
                """
                **Model**: HelpingAI/Dhanishtha-2.0-preview
                **Type**: Reasoning LLM with thinking blocks
                **Features**: Multi-step reasoning, self-evaluation
                """
            )

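    # Note: the ⏹️ Stop button created above is not wired to any event handler in this
    # file, so clicking it currently has no effect.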
    # Event handlers
    def submit_message(message, history, max_tokens, temperature, top_p):
        """Handle message submission"""
        # Yield through so Gradio keeps streaming the partial updates from chat_interface
        yield from chat_interface(message, history, max_tokens, temperature, top_p)

    def clear_chat():
        """Clear the chat history"""
        return [], ""

    # Message submission events
    msg.submit(
        submit_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1,
        show_progress="minimal"
    )

    send_btn.click(
        submit_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1,
        show_progress="minimal"
    )

    # Clear chat event
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg],
        show_progress=False
    )

    # Example prompts section
    with gr.Row():
        gr.Examples(
            examples=[
                ["Hello! Can you introduce yourself and show me how you think?"],
                ["Solve this step by step: What is 15% of 240?"],
                ["Explain quantum entanglement in simple terms"],
                ["Write a short Python function to find the factorial of a number"],
                ["What are the pros and cons of renewable energy?"],
                ["Help me understand the difference between AI and machine learning"],
                ["Create a haiku about artificial intelligence"],
                ["Explain why the sky is blue using physics principles"]
            ],
            inputs=msg,
            label="💡 Example Prompts - Try these to see the thinking process!",
            examples_per_page=4
        )

    # Footer with information
    gr.Markdown(
        """
        ---
        ### 🔧 Technical Details
        - **Model**: HelpingAI/Dhanishtha-2.0-preview
        - **Framework**: Transformers + Gradio
        - **Features**: Real-time streaming, thinking process visualization, custom sampling
        - **Reasoning**: Multi-step thinking with `<think>` blocks for transparent AI reasoning

        **Note**: This interface streams responses token by token and formats thinking blocks for better readability.
        The model's internal reasoning process is displayed in formatted code blocks.

        ---
        *Built with ❤️ using Gradio and Transformers*
        """
    )

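# Queue settings: default_concurrency_limit=1 runs one generation at a time, and
# max_size=20 caps how many requests may wait in Gradio's queue.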
if __name__ == "__main__":
    demo.queue(
        max_size=20,
        default_concurrency_limit=1
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False
    )