Nick088 commited on
Commit
7e99bc1
·
verified ·
1 Parent(s): 2319f58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -11,12 +11,18 @@ else:
11
  print("Using CPU")
12
 
13
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
14
- model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype="auto")
15
-
16
- model.to(device)
17
 
18
 
19
- def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, top_p, top_k, seed):
 
 
 
 
 
 
 
 
 
20
  input_text = f"Expand the following prompt to add more detail: {your_prompt}"
21
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
22
 
@@ -49,6 +55,8 @@ repetition_penalty = gr.Slider(value=1.2, minimum=0, maximum=2, step=0.05, inter
49
 
50
  temperature = gr.Slider(value=0.5, minimum=0, maximum=1, step=0.05, interactive=True, label="Temperature", info="Higher values produce more diverse outputs")
51
 
 
 
52
  top_p = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, label="Top P", info="Higher values sample more low-probability tokens")
53
 
54
  top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
@@ -69,7 +77,7 @@ examples = [
69
 
70
  gr.Interface(
71
  fn=generate,
72
- inputs=[your_prompt, max_new_tokens, repetition_penalty, temperature, top_p, top_k, seed],
73
  outputs=gr.Textbox(label="Better Prompt"),
74
  title="SuperPrompt-v1",
75
  description='Make your prompts more detailed! <br> <a href="https://huggingface.co/roborovski/superprompt-v1">Model used</a> <br> <a href="https://brianfitzgerald.xyz/prompt-augmentation/">Model Blog</a> <br> Task Prefix: "Expand the following prompt to add more detail:" is already setted! <br> Hugging Face Space made by [Nick088](https://linktr.ee/Nick088)',
 
11
  print("Using CPU")
12
 
13
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
 
 
 
14
 
15
 
16
+ def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
17
+
18
+ if model_precision_type == "fp16":
19
+ dtype = torch.float16
20
+ elif model_precision_type == "fp32":
21
+ dtype = torch.float32
22
+
23
+ model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=dtype)
24
+ model.to(device)
25
+
26
  input_text = f"Expand the following prompt to add more detail: {your_prompt}"
27
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
28
 
 
55
 
56
  temperature = gr.Slider(value=0.5, minimum=0, maximum=1, step=0.05, interactive=True, label="Temperature", info="Higher values produce more diverse outputs")
57
 
58
+ model_precision_type = gr.Dropdown(["fp16", "fp32"], value="fp16", label="Model Precision Type", info="The precision type to load the model, like fp16 which is faster, or fp32 which is more precise but more resource consuming")
59
+
60
  top_p = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, label="Top P", info="Higher values sample more low-probability tokens")
61
 
62
  top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
 
77
 
78
  gr.Interface(
79
  fn=generate,
80
+ inputs=[your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed],
81
  outputs=gr.Textbox(label="Better Prompt"),
82
  title="SuperPrompt-v1",
83
  description='Make your prompts more detailed! <br> <a href="https://huggingface.co/roborovski/superprompt-v1">Model used</a> <br> <a href="https://brianfitzgerald.xyz/prompt-augmentation/">Model Blog</a> <br> Task Prefix: "Expand the following prompt to add more detail:" is already setted! <br> Hugging Face Space made by [Nick088](https://linktr.ee/Nick088)',