Abhaykoul committed (verified)
Commit e6dc671 · 1 Parent(s): 0ac3d6d

Create app.py

Files changed (1)
  1. app.py +251 -0
app.py ADDED
@@ -0,0 +1,251 @@
+import gradio as gr
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+import threading
+import queue
+
+# Model configuration
+model_name = "HelpingAI/Dhanishtha-2.0-preview"
+
+# Global variables for model and tokenizer
+model = None
+tokenizer = None
+
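+# Note: device_map="auto" requires the `accelerate` package at runtime, and
+# torch_dtype="auto" loads the weights in the dtype stored in the checkpoint config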
+def load_model():
+    """Load the model and tokenizer"""
+    global model, tokenizer
+
+    print("Loading tokenizer...")
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    print("Loading model...")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype="auto",
+        device_map="auto",
+        trust_remote_code=True
+    )
+
+    print("Model loaded successfully!")
+
+class GradioTextStreamer(TextStreamer):
+    """Custom TextStreamer for Gradio integration"""
+    def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
+        # TextStreamer accepts decode kwargs only by keyword
+        # (its signature is (tokenizer, skip_prompt=False, **decode_kwargs))
+        super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
+        self.text_queue = queue.Queue()
+        self.generated_text = ""
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Called when text is finalized"""
+        self.generated_text += text
+        self.text_queue.put(text)
+        if stream_end:
+            self.text_queue.put(None)
+
+    def get_generated_text(self):
+        """Get all generated text so far"""
+        return self.generated_text
+
+    def reset(self):
+        """Reset the streamer"""
+        self.generated_text = ""
+        # Clear the queue
+        while not self.text_queue.empty():
+            try:
+                self.text_queue.get_nowait()
+            except queue.Empty:
+                break
+
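+# Note: transformers also ships TextIteratorStreamer, which provides the same
+# queue-based streaming out of the box; the subclass above additionally keeps
+# the full accumulated text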
+def generate_response(message, history, max_tokens, temperature, top_p):
+    """Generate streaming response"""
+    global model, tokenizer
+
+    if model is None or tokenizer is None:
+        yield "Model is still loading. Please wait..."
+        return
+
+    # Prepare conversation history
+    messages = []
+    for user_msg, assistant_msg in history:
+        messages.append({"role": "user", "content": user_msg})
+        if assistant_msg:
+            messages.append({"role": "assistant", "content": assistant_msg})
+
+    # Add current message
+    messages.append({"role": "user", "content": message})
+
+    # Apply chat template
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+
+    # Tokenize input
+    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+    # Create and set up streamer
+    streamer = GradioTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    streamer.reset()
+
+    # Generation arguments; the streamer receives tokens as they are produced
+    generation_kwargs = {
+        **model_inputs,
+        "max_new_tokens": max_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "do_sample": True,
+        "pad_token_id": tokenizer.eos_token_id,
+        "streamer": streamer,
+        "return_dict_in_generate": True
+    }
+
+    # Run generation in a background thread so this generator can consume
+    # the text queue while tokens are still being produced
+    def generate():
+        try:
+            with torch.no_grad():
+                model.generate(**generation_kwargs)
+        except Exception as e:
+            streamer.text_queue.put(f"Error: {str(e)}")
+            streamer.text_queue.put(None)
+
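+    # A None sentinel on the queue marks the end of generation: pushed by
+    # on_finalized_text on success, or by the except branch above on failure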
+    thread = threading.Thread(target=generate)
+    thread.start()
+
+    # Stream the results
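+    # (the 30 s timeout guards against the generation thread dying without
+    # ever pushing the None sentinel)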
+    generated_text = ""
+    while True:
+        try:
+            new_text = streamer.text_queue.get(timeout=30)
+            if new_text is None:
+                break
+            generated_text += new_text
+            yield generated_text
+        except queue.Empty:
+            break
+
+    thread.join(timeout=1)
+
+    # Final yield with complete text
+    if generated_text:
+        yield generated_text
+    else:
+        yield "No response generated."
+
+def chat_interface(message, history, max_tokens, temperature, top_p):
+    """Main chat interface"""
+    if not message.strip():
+        # Yield the unchanged history so the whitespace-only textbox is cleared
+        yield history, ""
+        return
+
+    # Add user message to history
+    history.append([message, ""])
+
+    # Generate response
+    for partial_response in generate_response(message, history[:-1], max_tokens, temperature, top_p):
+        history[-1][1] = partial_response
+        yield history, ""
+
+# Load model on startup
+print("Initializing model...")
+load_model()
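+# (loading at import time blocks startup until the weights are downloaded,
+# which is the usual pattern for a Gradio Space)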
+
+# Create Gradio interface
+with gr.Blocks(title="Dhanishtha-2.0-preview Chat", theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        # 🤖 Dhanishtha-2.0-preview Chat
+
+        Chat with the **HelpingAI/Dhanishtha-2.0-preview** model!
+        """
+    )
+
+    with gr.Row():
+        with gr.Column(scale=4):
+            chatbot = gr.Chatbot(
+                [],
+                elem_id="chatbot",
+                bubble_full_width=False,
+                height=500,
+                show_copy_button=True
+            )
+
+            with gr.Row():
+                msg = gr.Textbox(
+                    container=False,
+                    placeholder="Type your message here...",
+                    label="Message",
+                    autofocus=True,
+                    scale=7
+                )
+                send_btn = gr.Button("Send", variant="primary", scale=1)
+
+        with gr.Column(scale=1):
+            gr.Markdown("### ⚙️ Parameters")
+
+            max_tokens = gr.Slider(
+                minimum=1,
+                maximum=4096,
+                value=2048,
+                step=1,
+                label="Max Tokens",
+                info="Maximum number of tokens to generate"
+            )
+
+            temperature = gr.Slider(
+                minimum=0.1,
+                maximum=2.0,
+                value=0.7,
+                step=0.1,
+                label="Temperature",
+                info="Controls randomness in generation"
+            )
+
+            top_p = gr.Slider(
+                minimum=0.1,
+                maximum=1.0,
+                value=0.9,
+                step=0.05,
+                label="Top-p",
+                info="Controls diversity of generation"
+            )
+
+            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
+
+    # Event handlers
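+    # concurrency_limit=1 allows only one generation per event at a time,
+    # since the single in-process model cannot stream parallel requests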
+    msg.submit(
+        chat_interface,
+        inputs=[msg, chatbot, max_tokens, temperature, top_p],
+        outputs=[chatbot, msg],
+        concurrency_limit=1
+    )
+
+    send_btn.click(
+        chat_interface,
+        inputs=[msg, chatbot, max_tokens, temperature, top_p],
+        outputs=[chatbot, msg],
+        concurrency_limit=1
+    )
+
+    clear_btn.click(
+        lambda: ([], ""),
+        outputs=[chatbot, msg]
+    )
+
+    # Example prompts
+    gr.Examples(
+        examples=[
+            ["Hello! Who are you?"],
+            ["Explain quantum computing in simple terms"],
+            ["Write a short story about a robot learning to paint"],
+            ["What are the benefits of renewable energy?"],
+            ["Help me write a Python function to sort a list"]
+        ],
+        inputs=msg,
+        label="💡 Example Prompts"
+    )
+
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
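
Running the app locally (a sketch, assuming an environment outside the Space: a CUDA GPU with enough memory for the preview checkpoint, plus `accelerate`, which device_map="auto" requires):

pip install gradio torch transformers accelerate
python app.py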