Nick088 commited on
Commit
5709663
·
verified ·
1 Parent(s): fb4ba64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -9
app.py CHANGED
@@ -11,12 +11,20 @@ else:
11
  print("Using CPU")
12
 
13
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
 
14
 
15
- model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", device_map="auto", torch_dtype="auto")
16
  model.to(device)
17
-
18
- def generate(prompt, model_precision_type, max_new_tokens, repetition_penalty, temperature, top_p, top_k, seed):
19
-
 
 
 
 
 
 
 
 
20
  input_text = f"Expand the following prompt to add more detail: {prompt}"
21
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
22
 
@@ -25,7 +33,7 @@ def generate(prompt, model_precision_type, max_new_tokens, repetition_penalty, t
25
  torch.manual_seed(seed)
26
  else:
27
  torch.manual_seed(seed)
28
-
29
  outputs = model.generate(
30
  input_ids,
31
  max_new_tokens=max_new_tokens,
@@ -37,15 +45,16 @@ def generate(prompt, model_precision_type, max_new_tokens, repetition_penalty, t
37
  )
38
 
39
  better_prompt = tokenizer.decode(outputs[0])
40
- better_prompt = better_prompt.replace("<pad>", "").replace("<|endoftext|>", "")
41
  return better_prompt
42
 
 
43
  prompt = gr.Textbox(label="Prompt", interactive=True)
44
 
45
  max_new_tokens = gr.Slider(value=512, minimum=250, maximum=512, step=1, interactive=True, label="Max New Tokens", info="The maximum numbers of new tokens, controls how long is the output")
46
-
47
  repetition_penalty = gr.Slider(value=1.2, minimum=0, maximum=2, step=0.05, interactive=True, label="Repetition Penalty", info="Penalize repeated tokens, making the AI repeat less itself")
48
-
49
  temperature = gr.Slider(value=0.5, minimum=0, maximum=1, step=0.05, interactive=True, label="Temperature", info="Higher values produce more diverse outputs")
50
 
51
  top_p = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, label="Top P", info="Higher values sample more low-probability tokens")
@@ -55,7 +64,15 @@ top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, lab
55
  seed = gr.Number(value=42, interactive=True, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")
56
 
57
  examples = [
58
- ["A storefront with 'Text to Image' written on it.", 512, 1.2, 0.5, 1, 50, 42]
 
 
 
 
 
 
 
 
59
  ]
60
 
61
  gr.Interface(
 
11
  print("Using CPU")
12
 
13
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
14
+ model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype="auto")
15
 
 
16
  model.to(device)
17
+
18
+
19
+ def generate(
20
+ prompt,
21
+ max_new_tokens,
22
+ repetition_penalty,
23
+ temperature,
24
+ top_p,
25
+ top_k,
26
+ seed
27
+ ):
28
  input_text = f"Expand the following prompt to add more detail: {prompt}"
29
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
30
 
 
33
  torch.manual_seed(seed)
34
  else:
35
  torch.manual_seed(seed)
36
+
37
  outputs = model.generate(
38
  input_ids,
39
  max_new_tokens=max_new_tokens,
 
45
  )
46
 
47
  better_prompt = tokenizer.decode(outputs[0])
48
+ better_prompt = better_prompt.replace("<pad>", "").replace("</s>", "")
49
  return better_prompt
50
 
51
+
52
  prompt = gr.Textbox(label="Prompt", interactive=True)
53
 
54
  max_new_tokens = gr.Slider(value=512, minimum=250, maximum=512, step=1, interactive=True, label="Max New Tokens", info="The maximum numbers of new tokens, controls how long is the output")
55
+
56
  repetition_penalty = gr.Slider(value=1.2, minimum=0, maximum=2, step=0.05, interactive=True, label="Repetition Penalty", info="Penalize repeated tokens, making the AI repeat less itself")
57
+
58
  temperature = gr.Slider(value=0.5, minimum=0, maximum=1, step=0.05, interactive=True, label="Temperature", info="Higher values produce more diverse outputs")
59
 
60
  top_p = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, label="Top P", info="Higher values sample more low-probability tokens")
 
64
  seed = gr.Number(value=42, interactive=True, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")
65
 
66
  examples = [
67
+ [
68
+ "A storefront with 'Text to Image' written on it.",
69
+ 512,
70
+ 1.2,
71
+ 0.5,
72
+ 1,
73
+ 50,
74
+ 42,
75
+ ]
76
  ]
77
 
78
  gr.Interface(