sagar007 committed
Commit c78be87 · verified · 1 Parent(s): a075fee

Update app.py

Files changed (1)
  1. app.py +63 -13
app.py CHANGED
@@ -120,6 +120,15 @@ model = load_model('gpt_model.pth') # Replace with the actual path to your .pt
 enc = tiktoken.get_encoding('gpt2')
 
 # Improved text generation function
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import tiktoken
+import gradio as gr
+
+# [Your existing model code remains unchanged]
+
+# Modified text generation function to yield tokens
 def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
     input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
     generated = []
@@ -128,32 +137,70 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
     for _ in range(max_length):
         outputs, _ = model(input_ids)
         next_token_logits = outputs[:, -1, :]
-
-        # Apply temperature
         next_token_logits = next_token_logits / temperature
-
-        # Apply top-k filtering
         top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
         next_token_probs = F.softmax(top_k_logits, dim=-1)
-
-        # Sample from the filtered distribution
         next_token_index = torch.multinomial(next_token_probs, num_samples=1)
         next_token = top_k_indices.gather(-1, next_token_index)
 
         input_ids = torch.cat([input_ids, next_token], dim=-1)
         generated.append(next_token.item())
 
-        # Stop if we generate a newline, but only after generating at least 20 tokens
+        yield enc.decode([next_token.item()])
+
         if next_token.item() == enc.encode('\n')[0] and len(generated) > 20:
             break
-
-    generated_text = enc.decode(generated)
-    return prompt + generated_text
 
 # Gradio interface
 def gradio_generate(prompt, max_length, temperature, top_k):
     return generate_text(prompt, max_length, temperature, top_k)
 
+# Custom CSS for the animation effect
+custom_css = """
+<style>
+.output-box {
+    border: 1px solid #e0e0e0;
+    border-radius: 8px;
+    padding: 20px;
+    font-family: Arial, sans-serif;
+    line-height: 1.6;
+    height: 300px;
+    overflow-y: auto;
+    background-color: #f9f9f9;
+}
+.blinking-cursor {
+    display: inline-block;
+    width: 10px;
+    height: 20px;
+    background-color: #333;
+    animation: blink 0.7s infinite;
+}
+@keyframes blink {
+    0% { opacity: 0; }
+    50% { opacity: 1; }
+    100% { opacity: 0; }
+}
+</style>
+"""
+
+# JavaScript for the typing animation
+js_code = """
+function typeText(text, element) {
+    let index = 0;
+    element.innerHTML = '';
+    function type() {
+        if (index < text.length) {
+            element.innerHTML += text[index];
+            index++;
+            setTimeout(type, 50); // Adjust typing speed here
+        } else {
+            element.innerHTML += '<span class="blinking-cursor"></span>';
+        }
+    }
+    type();
+}
+"""
+
 iface = gr.Interface(
     fn=gradio_generate,
     inputs=[
@@ -162,9 +209,12 @@ iface = gr.Interface(
         gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
     ],
-    outputs=gr.Textbox(label="Generated Text"),
-    title="GPT Text Generator",
-    description="Enter a prompt and adjust parameters to generate text using a fine-tuned GPT model."
+    outputs=gr.HTML(label="Generated Text"),
+    title="Animated GPT Text Generator",
+    description="Enter a prompt and adjust parameters to generate text using a fine-tuned GPT model.",
+    css=custom_css,
+    js=js_code,
+    live=True
 )
 
 # Launch the app
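
Note on the streaming change: after this commit, generate_text is a generator that yields one decoded token at a time, but gradio_generate still returns the generator object itself. A minimal sketch of a wrapper that accumulates the yielded pieces and streams the growing string (not part of the commit; it assumes the model and enc objects defined earlier in app.py, and relies on Gradio's support for generator functions as incremental outputs):

# Sketch only: yield the running string so the output component updates
# as each new token arrives.
def gradio_generate(prompt, max_length, temperature, top_k):
    text = prompt
    for piece in generate_text(prompt, int(max_length), temperature, int(top_k)):
        text += piece
        yield text

With a generator-based wrapper like this, Gradio re-renders the output on every yield, so live=True should not be needed for the typing effect. As far as I recall, gr.Interface's css argument also expects raw CSS rules, so the <style>…</style> wrapper inside custom_css may need to be dropped.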