cakemus committed on
Commit
36bfcb2
·
1 Parent(s): ed5d53c
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -1,17 +1,15 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
- import torch
4
-
5
- # Check if a GPU is available
6
- device = 0 if torch.cuda.is_available() else -1
7
- print("Using GPU" if device == 0 else "Using CPU")
8
-
9
- # Load the model on the GPU if available
10
- model = pipeline("text-generation", model="gpt2", device=device)
11
 
 
 
12
def generate_text(prompt):
    """Return up to 50 tokens of GPT-2 text continuing *prompt*.

    Uses the module-level ``model`` pipeline; the first (and only)
    generated sequence's text is returned.
    """
    outputs = model(prompt, max_length=50)
    return outputs[0]["generated_text"]
14
 
 
15
  interface = gr.Interface(
16
  fn=generate_text,
17
  inputs=gr.Textbox(label="Enter your prompt here"),
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ from spaces import GPU # Import the GPU decorator for ZeroGPU
 
 
 
 
 
 
 
4
 
5
# ZeroGPU: the decorator requests a GPU only for the duration of each call.
@GPU
def generate_text(prompt):
    """Return up to 50 tokens of GPT-2 text continuing *prompt*.

    The text-generation pipeline is built inside the function so the
    model is only loaded while the GPU is attached (per the ZeroGPU
    pattern of doing GPU work only inside the decorated call).
    """
    generator = pipeline("text-generation", model="gpt2", device=0)
    results = generator(prompt, max_length=50)
    return results[0]["generated_text"]
11
 
12
+ # Create the Gradio interface
13
  interface = gr.Interface(
14
  fn=generate_text,
15
  inputs=gr.Textbox(label="Enter your prompt here"),