Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,14 +4,7 @@ import random
|
|
| 4 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
| 5 |
|
| 6 |
def load_model(model_path, dtype):
|
| 7 |
-
|
| 8 |
-
torch_dtype = torch.float32
|
| 9 |
-
elif dtype == "fp16":
|
| 10 |
-
torch_dtype = torch.float16
|
| 11 |
-
else:
|
| 12 |
-
raise ValueError("Invalid dtype. Only 'fp32' or 'fp16' are supported.")
|
| 13 |
-
|
| 14 |
-
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch_dtype)
|
| 15 |
return model
|
| 16 |
|
| 17 |
def generate(
|
|
@@ -120,7 +113,7 @@ additional_inputs = [
|
|
| 120 |
info="A starting point to initiate the generation process"
|
| 121 |
),
|
| 122 |
gr.Radio(
|
| 123 |
-
choices=["fp32", "fp16"],
|
| 124 |
value="fp16",
|
| 125 |
label="Model Precision",
|
| 126 |
info="fp32 is more precised, fp16 is faster and less memory consuming",
|
|
|
|
| 4 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
| 5 |
|
| 6 |
def load_model(model_path, dtype):
|
| 7 |
+
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
return model
|
| 9 |
|
| 10 |
def generate(
|
|
|
|
| 113 |
info="A starting point to initiate the generation process"
|
| 114 |
),
|
| 115 |
gr.Radio(
|
| 116 |
+
choices=[("fp32", torch.float32), ("fp16", torch.float16)],
|
| 117 |
value="fp16",
|
| 118 |
label="Model Precision",
|
| 119 |
info="fp32 is more precised, fp16 is faster and less memory consuming",
|