Update app.py
Browse files
app.py
CHANGED
|
@@ -11,12 +11,18 @@ else:
|
|
| 11 |
print("Using CPU")
|
| 12 |
|
| 13 |
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
|
| 14 |
-
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype="auto")
|
| 15 |
-
|
| 16 |
-
model.to(device)
|
| 17 |
|
| 18 |
|
| 19 |
-
def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, top_p, top_k, seed):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
input_text = f"Expand the following prompt to add more detail: {your_prompt}"
|
| 21 |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
| 22 |
|
|
@@ -49,6 +55,8 @@ repetition_penalty = gr.Slider(value=1.2, minimum=0, maximum=2, step=0.05, inter
|
|
| 49 |
|
| 50 |
temperature = gr.Slider(value=0.5, minimum=0, maximum=1, step=0.05, interactive=True, label="Temperature", info="Higher values produce more diverse outputs")
|
| 51 |
|
|
|
|
|
|
|
| 52 |
top_p = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, label="Top P", info="Higher values sample more low-probability tokens")
|
| 53 |
|
| 54 |
top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
|
|
@@ -69,7 +77,7 @@ examples = [
|
|
| 69 |
|
| 70 |
gr.Interface(
|
| 71 |
fn=generate,
|
| 72 |
-
inputs=[your_prompt, max_new_tokens, repetition_penalty, temperature, top_p, top_k, seed],
|
| 73 |
outputs=gr.Textbox(label="Better Prompt"),
|
| 74 |
title="SuperPrompt-v1",
|
| 75 |
description='Make your prompts more detailed! <br> <a href="https://huggingface.co/roborovski/superprompt-v1">Model used</a> <br> <a href="https://brianfitzgerald.xyz/prompt-augmentation/">Model Blog</a> <br> Task Prefix: "Expand the following prompt to add more detail:" is already setted! <br> Hugging Face Space made by [Nick088](https://linktr.ee/Nick088)',
|
|
|
|
| 11 |
print("Using CPU")
|
| 12 |
|
| 13 |
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
+
def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
|
| 17 |
+
|
| 18 |
+
if model_precision_type == "fp16":
|
| 19 |
+
dtype = torch.float16
|
| 20 |
+
elif model_precision_type == "fp32":
|
| 21 |
+
dtype = torch.float32
|
| 22 |
+
|
| 23 |
+
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=dtype)
|
| 24 |
+
model.to(device)
|
| 25 |
+
|
| 26 |
input_text = f"Expand the following prompt to add more detail: {your_prompt}"
|
| 27 |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
| 28 |
|
|
|
|
| 55 |
|
| 56 |
temperature = gr.Slider(value=0.5, minimum=0, maximum=1, step=0.05, interactive=True, label="Temperature", info="Higher values produce more diverse outputs")
|
| 57 |
|
| 58 |
+
model_precision_type = gr.Dropdown(["fp16", "fp32"], value="fp16", label="Model Precision Type", info="The precision type to load the model, like fp16 which is faster, or fp32 which is more precise but more resource consuming")
|
| 59 |
+
|
| 60 |
top_p = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, label="Top P", info="Higher values sample more low-probability tokens")
|
| 61 |
|
| 62 |
top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
|
|
|
|
| 77 |
|
| 78 |
gr.Interface(
|
| 79 |
fn=generate,
|
| 80 |
+
inputs=[your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed],
|
| 81 |
outputs=gr.Textbox(label="Better Prompt"),
|
| 82 |
title="SuperPrompt-v1",
|
| 83 |
description='Make your prompts more detailed! <br> <a href="https://huggingface.co/roborovski/superprompt-v1">Model used</a> <br> <a href="https://brianfitzgerald.xyz/prompt-augmentation/">Model Blog</a> <br> Task Prefix: "Expand the following prompt to add more detail:" is already setted! <br> Hugging Face Space made by [Nick088](https://linktr.ee/Nick088)',
|