ford442 committed
Commit ffdb810 · verified · 1 Parent(s): 0956914

Update app.py

Files changed (1)
  app.py +66 -57
app.py CHANGED
@@ -122,55 +122,59 @@ def infer(
 ):
     seed = random.randint(0, MAX_SEED)
     generator = torch.Generator(device='cuda').manual_seed(seed)
-    system_prompt_rewrite = (
-        "You are an AI assistant that rewrites image prompts to be more descriptive and detailed."
-    )
-    user_prompt_rewrite = (
-        "Rewrite this prompt to be more descriptive and detailed and only return the rewritten text: "
-    )
-    user_prompt_rewrite_2 = (
-        "Rephrase this scene to have more elaborate details: "
-    )
-    input_text = f"{system_prompt_rewrite} {user_prompt_rewrite} {prompt}"
-    input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {prompt}"
-    print("-- got prompt --")
-    # Encode the input text and include the attention mask
-    encoded_inputs = tokenizer(input_text, return_tensors="pt", return_attention_mask=True)
-    encoded_inputs_2 = tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True)
-    # Ensure all values are on the correct device
-    input_ids = encoded_inputs["input_ids"].to(device)
-    input_ids_2 = encoded_inputs_2["input_ids"].to(device)
-    attention_mask = encoded_inputs["attention_mask"].to(device)
-    attention_mask_2 = encoded_inputs_2["attention_mask"].to(device)
-    print("-- tokenize prompt --")
-    # Google T5
-    #input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
-    outputs = model.generate(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        max_new_tokens=512,
-        temperature=0.2,
-        top_p=0.9,
-        do_sample=True,
-    )
-    outputs_2 = model.generate(
-        input_ids=input_ids_2,
-        attention_mask=attention_mask_2,
-        max_new_tokens=65,
-        temperature=0.2,
-        top_p=0.9,
-        do_sample=True,
-    )
-    # Use the encoded tensor 'text_inputs' here
-    enhanced_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    enhanced_prompt_2 = tokenizer.decode(outputs_2[0], skip_special_tokens=True)
-    print('-- generated prompt --')
-    enhanced_prompt = filter_text(enhanced_prompt,prompt)
-    enhanced_prompt_2 = filter_text(enhanced_prompt_2,prompt)
-    print('-- filtered prompt --')
-    print(enhanced_prompt)
-    print('-- filtered prompt 2 --')
-    print(enhanced_prompt_2)
+    if expanded:
+        system_prompt_rewrite = (
+            "You are an AI assistant that rewrites image prompts to be more descriptive and detailed."
+        )
+        user_prompt_rewrite = (
+            "Rewrite this prompt to be more descriptive and detailed and only return the rewritten text: "
+        )
+        user_prompt_rewrite_2 = (
+            "Rephrase this scene to have more elaborate details: "
+        )
+        input_text = f"{system_prompt_rewrite} {user_prompt_rewrite} {prompt}"
+        input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {prompt}"
+        print("-- got prompt --")
+        # Encode the input text and include the attention mask
+        encoded_inputs = tokenizer(input_text, return_tensors="pt", return_attention_mask=True)
+        encoded_inputs_2 = tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True)
+        # Ensure all values are on the correct device
+        input_ids = encoded_inputs["input_ids"].to(device)
+        input_ids_2 = encoded_inputs_2["input_ids"].to(device)
+        attention_mask = encoded_inputs["attention_mask"].to(device)
+        attention_mask_2 = encoded_inputs_2["attention_mask"].to(device)
+        print("-- tokenize prompt --")
+        # Google T5
+        #input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
+        outputs = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=512,
+            temperature=0.2,
+            top_p=0.9,
+            do_sample=True,
+        )
+        outputs_2 = model.generate(
+            input_ids=input_ids_2,
+            attention_mask=attention_mask_2,
+            max_new_tokens=65,
+            temperature=0.2,
+            top_p=0.9,
+            do_sample=True,
+        )
+        # Use the encoded tensor 'text_inputs' here
+        enhanced_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        enhanced_prompt_2 = tokenizer.decode(outputs_2[0], skip_special_tokens=True)
+        print('-- generated prompt --')
+        enhanced_prompt = filter_text(enhanced_prompt,prompt)
+        enhanced_prompt_2 = filter_text(enhanced_prompt_2,prompt)
+        print('-- filtered prompt --')
+        print(enhanced_prompt)
+        print('-- filtered prompt 2 --')
+        print(enhanced_prompt_2)
+    else:
+        enhanced_prompt = prompt
+        enhanced_prompt_2 = prompt
     if latent_file: # Check if a latent file is provided
         # initial_latents = pipe.prepare_latents(
         #     batch_size=1,
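This hunk gates the two LLM rewrite passes behind the new `expanded` flag and falls back to the raw prompt otherwise. For illustration only, a minimal sketch of that rewrite step factored into one helper, assuming the module-level `tokenizer`, `model`, `device`, and `filter_text` that app.py already uses (the function name `expand_prompt` is hypothetical):

```python
# Hypothetical helper; assumes app.py's global tokenizer, model, device, filter_text.
def expand_prompt(prompt: str, instruction: str, max_new_tokens: int = 512) -> str:
    system_prompt = (
        "You are an AI assistant that rewrites image prompts "
        "to be more descriptive and detailed."
    )
    # Tokenize with an explicit attention mask, as the hunk above does.
    inputs = tokenizer(
        f"{system_prompt} {instruction} {prompt}",
        return_tensors="pt",
        return_attention_mask=True,
    )
    outputs = model.generate(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        max_new_tokens=max_new_tokens,
        temperature=0.2,
        top_p=0.9,
        do_sample=True,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return filter_text(text, prompt)  # same post-filter as the hunk
```

With that in place, the `if expanded:` branch would reduce to two calls, one per rewrite instruction.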
@@ -216,13 +220,19 @@ def infer(
         max_sequence_length=512
     ).images[0]
     print('-- got image --')
-    #sd35_image = pipe.vae.decode(sd_image / 0.18215).sample
-    #sd35_image = sdxl_image.cpu().permute(0, 2, 3, 1).float().detach().numpy()
-    #sd35_image = (sdxl_image * 255).round().astype("uint8")
-    #image_pil = Image.fromarray(sd35_image[0])
-    sd35_path = f"sd35_{seed}.png"
-    sd_image.save(sd35_path,optimize=False,compress_level=0)
-    upload_to_ftp(sd35_path)
+
+    sd35_image = pipe.vae.decode(sd_image / 0.18215).sample
+    sd35_image = sd35_image.cpu().permute(0, 2, 3, 1).float().detach().numpy()
+    sd35_image = (sd35_image * 255).round().astype("uint8")
+    image_pil = Image.fromarray(sd35_image[0])
+    sd35_path = f"tst_rv_{seed}.png"
+    image_pil.save(sd35_path,optimize=False,compress_level=0)
+    upload_to_ftp(sd35_path)
+
+
+    #sd35_path = f"sd35_{seed}.png"
+    #sd_image.save(sd35_path,optimize=False,compress_level=0)
+    #upload_to_ftp(sd35_path)
 
     # Convert the generated image to a tensor
     #generated_image_tensor = torch.tensor([np.array(sd_image).transpose(2, 0, 1)]).to('cuda') / 255.0
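Two caveats about this hunk: `pipe(...).images[0]` is already a decoded PIL image, not a latent tensor, so feeding it back through `pipe.vae.decode` will not work as written; and 0.18215 is the SD1.x latent scale, whereas the SD3/SD3.5 VAE carries its own `scaling_factor` and `shift_factor` in its config. A hedged sketch of latent-to-PIL decoding for the case where the pipeline is instead run with `output_type="latent"` (variable names are illustrative; assumes `pipe` is a loaded diffusers `StableDiffusion3Pipeline` and `seed` as in app.py):

```python
# Sketch: decode SD3.5 latents using the VAE's configured factors rather than
# the SD1.x constant 0.18215. Assumes `latents` has shape (1, C, H, W) and
# came from a pipeline call with output_type="latent".
import torch
from PIL import Image

with torch.no_grad():
    lat = latents / pipe.vae.config.scaling_factor + pipe.vae.config.shift_factor
    decoded = pipe.vae.decode(lat).sample  # (1, 3, H, W), roughly in [-1, 1]
img = (decoded / 2 + 0.5).clamp(0, 1)      # rescale to [0, 1]
img = img.cpu().permute(0, 2, 3, 1).float().numpy()
image_pil = Image.fromarray((img[0] * 255).round().astype("uint8"))
image_pil.save(f"sd35_{seed}.png", optimize=False, compress_level=0)
```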
@@ -293,7 +303,6 @@ def repeat_infer(
         i += 1
     return result, seed, image_path, enhanced_prompt
 
-
 with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
     with gr.Column(elem_id="col-container"):
         gr.Markdown(" # Text-to-Text-to-Image StableDiffusion 3.5 Medium (with refine)")
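The new `expanded` argument to `infer` implies a matching control in the Blocks UI. A hypothetical wiring sketch in the style of the existing layout (the control and event names are assumptions, not part of this commit, and `infer` takes more inputs than shown):

```python
# Hypothetical: expose the new `expanded` flag as a checkbox and pass it to infer().
with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Text-to-Image StableDiffusion 3.5 Medium (with refine)")
        prompt = gr.Textbox(label="Prompt")
        expanded = gr.Checkbox(label="Expand prompt with LLM rewrite", value=True)
        run = gr.Button("Generate")
        result = gr.Image(label="Result")
        run.click(fn=infer, inputs=[prompt, expanded], outputs=[result])
```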
 