Shilpaj commited on
Commit
921561a
·
1 Parent(s): 100d65e

Feat: Zero GPU for inference

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -6,6 +6,7 @@ from dataclasses import dataclass
6
  import torch.nn as nn
7
  import math
8
  import inspect
 
9
 
10
  # Configuration class (same as in training)
11
  @dataclass
@@ -129,7 +130,8 @@ def load_model():
129
  return model, device
130
 
131
  # Text generation function
132
- def generate_text(prompt, num_tokens, model, device, temperature=0.8):
 
133
  enc = tiktoken.get_encoding('gpt2')
134
  x = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=device)
135
 
@@ -148,13 +150,9 @@ def generate_text(prompt, num_tokens, model, device, temperature=0.8):
148
  # Load the model globally
149
  model, device = load_model()
150
 
151
- # Gradio interface
152
- def gradio_interface(prompt, num_tokens, temperature):
153
- return generate_text(prompt, num_tokens, model, device, temperature)
154
-
155
  # Create the Gradio interface
156
- iface = gr.Interface(
157
- fn=gradio_interface,
158
  inputs=[
159
  gr.Textbox(label="Enter your prompt", value="Once upon a time"),
160
  gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Number of tokens to generate"),
@@ -163,7 +161,9 @@ iface = gr.Interface(
163
  outputs=gr.Textbox(label="Generated Text"),
164
  title="NanoGPT Text Generator",
165
  description="Generate Shakespeare-style text using a trained NanoGPT model",
 
 
166
  )
167
 
168
  if __name__ == "__main__":
169
- iface.launch()
 
6
  import torch.nn as nn
7
  import math
8
  import inspect
9
+ import spaces
10
 
11
  # Configuration class (same as in training)
12
  @dataclass
 
130
  return model, device
131
 
132
  # Text generation function
133
+ @spaces.GPU(enable_queue=True)
134
+ def generate_text(prompt, num_tokens, temperature=0.8):
135
  enc = tiktoken.get_encoding('gpt2')
136
  x = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=device)
137
 
 
150
  # Load the model globally
151
  model, device = load_model()
152
 
 
 
 
 
153
  # Create the Gradio interface
154
+ demo = gr.Interface(
155
+ fn=generate_text,
156
  inputs=[
157
  gr.Textbox(label="Enter your prompt", value="Once upon a time"),
158
  gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Number of tokens to generate"),
 
161
  outputs=gr.Textbox(label="Generated Text"),
162
  title="NanoGPT Text Generator",
163
  description="Generate Shakespeare-style text using a trained NanoGPT model",
164
+ allow_flagging="never",
165
+ cache_examples=True
166
  )
167
 
168
  if __name__ == "__main__":
169
+ demo.launch()