erwold committed on
Commit
8e99946
·
1 Parent(s): 0ded2d6

Initial Commit

Browse files
Files changed (1) hide show
  1. app.py +57 -19
app.py CHANGED
@@ -177,37 +177,59 @@ class FluxInterface:
177
 
178
  def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
179
  try:
 
 
 
 
 
180
  if seed is not None:
181
  torch.manual_seed(seed)
 
182
 
183
  self.load_models()
 
184
 
185
  # Get dimensions from aspect ratio
186
  if aspect_ratio not in ASPECT_RATIOS:
187
  raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
188
  width, height = ASPECT_RATIOS[aspect_ratio]
 
189
 
190
  # Process input image
191
- input_image = self.resize_image(input_image)
192
- qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
193
- pooled_prompt_embeds = self.compute_text_embeddings("")
 
 
 
 
194
 
195
- # Get T5 embeddings if prompt is provided
196
- t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
 
 
 
 
 
 
 
197
 
198
  # Generate images
199
- output_images = self.pipeline(
200
- prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
201
- pooled_prompt_embeds=pooled_prompt_embeds,
202
- t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
203
- num_inference_steps=num_inference_steps,
204
- guidance_scale=guidance_scale,
205
- height=height,
206
- width=width,
207
- ).images
208
-
209
- return output_images
210
-
 
 
 
211
  except Exception as e:
212
  print(f"Error during generation: {str(e)}")
213
  raise gr.Error(f"Generation failed: {str(e)}")
@@ -327,6 +349,7 @@ with gr.Blocks(
327
  allow_preview=True,
328
  preview=True
329
  )
 
330
 
331
  with gr.Row(elem_classes="footer"):
332
  gr.Markdown("""
@@ -339,8 +362,20 @@ with gr.Blocks(
339
  """)
340
 
341
  # Set up the generation function
 
 
 
 
 
 
 
 
 
 
 
 
342
  submit_btn.click(
343
- fn=interface.generate,
344
  inputs=[
345
  input_image,
346
  prompt,
@@ -350,7 +385,10 @@ with gr.Blocks(
350
  seed,
351
  aspect_ratio
352
  ],
353
- outputs=output_gallery,
 
 
 
354
  show_progress="minimal"
355
  )
356
 
 
177
 
178
  def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
179
  try:
180
+ print(f"Starting generation with prompt: {prompt}, guidance_scale: {guidance_scale}, steps: {num_inference_steps}")
181
+
182
+ if input_image is None:
183
+ raise ValueError("No input image provided")
184
+
185
  if seed is not None:
186
  torch.manual_seed(seed)
187
+ print(f"Set random seed to: {seed}")
188
 
189
  self.load_models()
190
+ print("Models loaded successfully")
191
 
192
  # Get dimensions from aspect ratio
193
  if aspect_ratio not in ASPECT_RATIOS:
194
  raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
195
  width, height = ASPECT_RATIOS[aspect_ratio]
196
+ print(f"Using dimensions: {width}x{height}")
197
 
198
  # Process input image
199
+ try:
200
+ input_image = self.resize_image(input_image)
201
+ print(f"Input image resized to: {input_image.size}")
202
+ qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
203
+ print("Input image processed successfully")
204
+ except Exception as e:
205
+ raise RuntimeError(f"Error processing input image: {str(e)}")
206
 
207
+ try:
208
+ pooled_prompt_embeds = self.compute_text_embeddings("")
209
+ print("Base text embeddings computed")
210
+
211
+ # Get T5 embeddings if prompt is provided
212
+ t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
213
+ print("T5 prompt embeddings computed")
214
+ except Exception as e:
215
+ raise RuntimeError(f"Error computing embeddings: {str(e)}")
216
 
217
  # Generate images
218
+ try:
219
+ output_images = self.pipeline(
220
+ prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
221
+ pooled_prompt_embeds=pooled_prompt_embeds,
222
+ t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
223
+ num_inference_steps=num_inference_steps,
224
+ guidance_scale=guidance_scale,
225
+ height=height,
226
+ width=width,
227
+ ).images
228
+
229
+ print("Images generated successfully")
230
+ return output_images
231
+ except Exception as e:
232
+ raise RuntimeError(f"Error generating images: {str(e)}")
233
  except Exception as e:
234
  print(f"Error during generation: {str(e)}")
235
  raise gr.Error(f"Generation failed: {str(e)}")
 
349
  allow_preview=True,
350
  preview=True
351
  )
352
+ error_message = gr.Textbox(visible=False)
353
 
354
  with gr.Row(elem_classes="footer"):
355
  gr.Markdown("""
 
362
  """)
363
 
364
  # Set up the generation function
365
+ def generate_with_error_handling(*args):
366
+ try:
367
+ with gr.Status() as status:
368
+ status.update(value="Loading models...", visible=True)
369
+ results = interface.generate(*args)
370
+ status.update(value="Generation complete!", visible=False)
371
+ return [results, None]
372
+ except Exception as e:
373
+ error_msg = str(e)
374
+ print(f"Error in generate_with_error_handling: {error_msg}")
375
+ return [None, gr.Error(error_msg)]
376
+
377
  submit_btn.click(
378
+ fn=generate_with_error_handling,
379
  inputs=[
380
  input_image,
381
  prompt,
 
385
  seed,
386
  aspect_ratio
387
  ],
388
+ outputs=[
389
+ output_gallery,
390
+ error_message
391
+ ],
392
  show_progress="minimal"
393
  )
394