Update app.py
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import torch
 import random
 from transformers import T5Tokenizer, T5ForConditionalGeneration
+import spaces
 
 if torch.cuda.is_available():
     device = "cuda"
@@ -12,9 +13,8 @@ else:
 
 tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
 
-
+@spaces.GPU()
 def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
-
     if model_precision_type == "fp16":
         dtype = torch.float16
     elif model_precision_type == "fp32":
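
The change imports the spaces package and wraps generate with @spaces.GPU(), which is how a Hugging Face Space requests ZeroGPU hardware only for the duration of the decorated call. The rest of app.py is not visible in this diff, so the sketch below is a minimal, hypothetical reconstruction of how the decorated function could be wired up; the model loading, seeding, and generation arguments are assumptions rather than the Space's actual code.

import random

import spaces
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")

@spaces.GPU()  # request ZeroGPU hardware only while this call runs
def generate(your_prompt, max_new_tokens, repetition_penalty, temperature,
             model_precision_type, top_p, top_k, seed):
    # Precision handling mirrors the lines visible in the diff.
    if model_precision_type == "fp16":
        dtype = torch.float16
    elif model_precision_type == "fp32":
        dtype = torch.float32

    # Everything below is assumed: load the model at the requested precision,
    # seed the generators, and sample a completion.
    model = T5ForConditionalGeneration.from_pretrained(
        "google/flan-t5-small", torch_dtype=dtype
    ).to(device)

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    inputs = tokenizer(your_prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        do_sample=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

If a call needs more than the default GPU window, the decorator also accepts a duration argument, e.g. @spaces.GPU(duration=120).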