cakemus commited on
Commit
ed5d53c
·
1 Parent(s): dee82dd
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -1,13 +1,17 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
3
 
4
- # Load an available model for text generation
5
- model = pipeline("text-generation", model="gpt2")
 
 
 
 
6
 
7
  def generate_text(prompt):
8
  return model(prompt, max_length=50)[0]["generated_text"]
9
 
10
- # Custom interface with a textbox and description
11
  interface = gr.Interface(
12
  fn=generate_text,
13
  inputs=gr.Textbox(label="Enter your prompt here"),
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ import torch
4
 
5
+ # Check if a GPU is available
6
+ device = 0 if torch.cuda.is_available() else -1
7
+ print("Using GPU" if device == 0 else "Using CPU")
8
+
9
+ # Load the model on the GPU if available
10
+ model = pipeline("text-generation", model="gpt2", device=device)
11
 
12
  def generate_text(prompt):
13
  return model(prompt, max_length=50)[0]["generated_text"]
14
 
 
15
  interface = gr.Interface(
16
  fn=generate_text,
17
  inputs=gr.Textbox(label="Enter your prompt here"),