Update app.py
app.py CHANGED
@@ -11,12 +11,18 @@ else:
     print("Using CPU")
 
 tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
-model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype="auto")
-
-model.to(device)
 
 
-def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, top_p, top_k, seed):
+def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
+
+    if model_precision_type == "fp16":
+        dtype = torch.float16
+    elif model_precision_type == "fp32":
+        dtype = torch.float32
+
+    model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=dtype)
+    model.to(device)
+
     input_text = f"Expand the following prompt to add more detail: {your_prompt}"
     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
 
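Note on this hunk: loading the checkpoint inside generate() means the weights are re-read on every request, and dtype would be unbound if the precision string were anything other than "fp16" or "fp32" (the dropdown rules that out in practice). A minimal sketch of a more defensive variant, assuming the same model id; the _DTYPES/_MODEL_CACHE names and the get_model helper are hypothetical, not part of the commit:

# Illustrative sketch, not from the commit: map the dropdown string to a
# torch dtype and cache one model per precision so generate() does not
# reload weights on every call. _DTYPES, _MODEL_CACHE, get_model are
# hypothetical names.
import torch
from transformers import T5ForConditionalGeneration

_DTYPES = {"fp16": torch.float16, "fp32": torch.float32}
_MODEL_CACHE = {}

def get_model(model_precision_type, device):
    dtype = _DTYPES.get(model_precision_type, torch.float32)  # safe fallback
    if dtype not in _MODEL_CACHE:
        model = T5ForConditionalGeneration.from_pretrained(
            "roborovski/superprompt-v1", torch_dtype=dtype
        )
        _MODEL_CACHE[dtype] = model.to(device)
    return _MODEL_CACHE[dtype]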
@@ -49,6 +55,8 @@ repetition_penalty = gr.Slider(value=1.2, minimum=0, maximum=2, step=0.05, inter
 
 temperature = gr.Slider(value=0.5, minimum=0, maximum=1, step=0.05, interactive=True, label="Temperature", info="Higher values produce more diverse outputs")
 
+model_precision_type = gr.Dropdown(["fp16", "fp32"], value="fp16", label="Model Precision Type", info="The precision type to load the model, like fp16 which is faster, or fp32 which is more precise but more resource consuming")
+
 top_p = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, label="Top P", info="Higher values sample more low-probability tokens")
 
 top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
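The dropdown's info text says fp16 is faster; that mostly holds on GPU. A hypothetical guard, not in the commit, that falls back to fp32 on CPU, where half-precision kernels in PyTorch can be slow or missing:

# Hypothetical refinement (not in the commit): half precision mainly pays
# off on CUDA; on CPU some ops lack float16 kernels, so fall back to fp32.
import torch

def resolve_dtype(model_precision_type: str, device: str) -> torch.dtype:
    dtype = torch.float16 if model_precision_type == "fp16" else torch.float32
    if device == "cpu" and dtype is torch.float16:
        dtype = torch.float32
    return dtype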
@@ -69,7 +77,7 @@ examples = [
 
 gr.Interface(
     fn=generate,
-    inputs=[your_prompt, max_new_tokens, repetition_penalty, temperature, top_p, top_k, seed],
+    inputs=[your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed],
     outputs=gr.Textbox(label="Better Prompt"),
     title="SuperPrompt-v1",
     description='Make your prompts more detailed! <br> <a href="https://huggingface.co/roborovski/superprompt-v1">Model used</a> <br> <a href="https://brianfitzgerald.xyz/prompt-augmentation/">Model Blog</a> <br> Task Prefix: "Expand the following prompt to add more detail:" is already setted! <br> Hugging Face Space made by [Nick088](https://linktr.ee/Nick088)',
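On the last hunk: gr.Interface passes component values to fn positionally, so model_precision_type must occupy the same slot in `inputs` as in generate()'s parameter list. A runnable toy example with hypothetical stub names, not the Space's actual code:

# Toy example (hypothetical stub) showing the positional contract between
# the `inputs` list and the handler's parameters.
import gradio as gr

def generate_stub(your_prompt, model_precision_type):
    return f"[{model_precision_type}] {your_prompt}"

gr.Interface(
    fn=generate_stub,
    inputs=[
        gr.Textbox(label="Your Prompt"),                # -> your_prompt
        gr.Dropdown(["fp16", "fp32"], value="fp16",
                    label="Model Precision Type"),      # -> model_precision_type
    ],
    outputs=gr.Textbox(label="Better Prompt"),
).launch()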