Spaces:
Runtime error
Runtime error
erwold
commited on
Commit
·
8e99946
1
Parent(s):
0ded2d6
Initial Commit
Browse files
app.py
CHANGED
@@ -177,37 +177,59 @@ class FluxInterface:
|
|
177 |
|
178 |
def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
|
179 |
try:
|
|
|
|
|
|
|
|
|
|
|
180 |
if seed is not None:
|
181 |
torch.manual_seed(seed)
|
|
|
182 |
|
183 |
self.load_models()
|
|
|
184 |
|
185 |
# Get dimensions from aspect ratio
|
186 |
if aspect_ratio not in ASPECT_RATIOS:
|
187 |
raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
|
188 |
width, height = ASPECT_RATIOS[aspect_ratio]
|
|
|
189 |
|
190 |
# Process input image
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
|
|
194 |
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
# Generate images
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
211 |
except Exception as e:
|
212 |
print(f"Error during generation: {str(e)}")
|
213 |
raise gr.Error(f"Generation failed: {str(e)}")
|
@@ -327,6 +349,7 @@ with gr.Blocks(
|
|
327 |
allow_preview=True,
|
328 |
preview=True
|
329 |
)
|
|
|
330 |
|
331 |
with gr.Row(elem_classes="footer"):
|
332 |
gr.Markdown("""
|
@@ -339,8 +362,20 @@ with gr.Blocks(
|
|
339 |
""")
|
340 |
|
341 |
# Set up the generation function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
submit_btn.click(
|
343 |
-
fn=
|
344 |
inputs=[
|
345 |
input_image,
|
346 |
prompt,
|
@@ -350,7 +385,10 @@ with gr.Blocks(
|
|
350 |
seed,
|
351 |
aspect_ratio
|
352 |
],
|
353 |
-
outputs=
|
|
|
|
|
|
|
354 |
show_progress="minimal"
|
355 |
)
|
356 |
|
|
|
177 |
|
178 |
def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
|
179 |
try:
|
180 |
+
print(f"Starting generation with prompt: {prompt}, guidance_scale: {guidance_scale}, steps: {num_inference_steps}")
|
181 |
+
|
182 |
+
if input_image is None:
|
183 |
+
raise ValueError("No input image provided")
|
184 |
+
|
185 |
if seed is not None:
|
186 |
torch.manual_seed(seed)
|
187 |
+
print(f"Set random seed to: {seed}")
|
188 |
|
189 |
self.load_models()
|
190 |
+
print("Models loaded successfully")
|
191 |
|
192 |
# Get dimensions from aspect ratio
|
193 |
if aspect_ratio not in ASPECT_RATIOS:
|
194 |
raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
|
195 |
width, height = ASPECT_RATIOS[aspect_ratio]
|
196 |
+
print(f"Using dimensions: {width}x{height}")
|
197 |
|
198 |
# Process input image
|
199 |
+
try:
|
200 |
+
input_image = self.resize_image(input_image)
|
201 |
+
print(f"Input image resized to: {input_image.size}")
|
202 |
+
qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
|
203 |
+
print("Input image processed successfully")
|
204 |
+
except Exception as e:
|
205 |
+
raise RuntimeError(f"Error processing input image: {str(e)}")
|
206 |
|
207 |
+
try:
|
208 |
+
pooled_prompt_embeds = self.compute_text_embeddings("")
|
209 |
+
print("Base text embeddings computed")
|
210 |
+
|
211 |
+
# Get T5 embeddings if prompt is provided
|
212 |
+
t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
|
213 |
+
print("T5 prompt embeddings computed")
|
214 |
+
except Exception as e:
|
215 |
+
raise RuntimeError(f"Error computing embeddings: {str(e)}")
|
216 |
|
217 |
# Generate images
|
218 |
+
try:
|
219 |
+
output_images = self.pipeline(
|
220 |
+
prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
|
221 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
222 |
+
t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
|
223 |
+
num_inference_steps=num_inference_steps,
|
224 |
+
guidance_scale=guidance_scale,
|
225 |
+
height=height,
|
226 |
+
width=width,
|
227 |
+
).images
|
228 |
+
|
229 |
+
print("Images generated successfully")
|
230 |
+
return output_images
|
231 |
+
except Exception as e:
|
232 |
+
raise RuntimeError(f"Error generating images: {str(e)}")
|
233 |
except Exception as e:
|
234 |
print(f"Error during generation: {str(e)}")
|
235 |
raise gr.Error(f"Generation failed: {str(e)}")
|
|
|
349 |
allow_preview=True,
|
350 |
preview=True
|
351 |
)
|
352 |
+
error_message = gr.Textbox(visible=False)
|
353 |
|
354 |
with gr.Row(elem_classes="footer"):
|
355 |
gr.Markdown("""
|
|
|
362 |
""")
|
363 |
|
364 |
# Set up the generation function
|
365 |
+
def generate_with_error_handling(*args):
|
366 |
+
try:
|
367 |
+
with gr.Status() as status:
|
368 |
+
status.update(value="Loading models...", visible=True)
|
369 |
+
results = interface.generate(*args)
|
370 |
+
status.update(value="Generation complete!", visible=False)
|
371 |
+
return [results, None]
|
372 |
+
except Exception as e:
|
373 |
+
error_msg = str(e)
|
374 |
+
print(f"Error in generate_with_error_handling: {error_msg}")
|
375 |
+
return [None, gr.Error(error_msg)]
|
376 |
+
|
377 |
submit_btn.click(
|
378 |
+
fn=generate_with_error_handling,
|
379 |
inputs=[
|
380 |
input_image,
|
381 |
prompt,
|
|
|
385 |
seed,
|
386 |
aspect_ratio
|
387 |
],
|
388 |
+
outputs=[
|
389 |
+
output_gallery,
|
390 |
+
error_message
|
391 |
+
],
|
392 |
show_progress="minimal"
|
393 |
)
|
394 |
|