prithivMLmods commited on
Commit
7bc6034
·
verified ·
1 Parent(s): 4a9305f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -17
app.py CHANGED
@@ -33,15 +33,22 @@ vae = AutoencoderKL.from_pretrained(
33
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
34
  ).to("cuda")
35
 
36
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
37
- "SG161222/RealVisXL_V5.0_Lightning",
38
- torch_dtype=torch.float16,
39
- vae=vae,
40
- controlnet=model,
41
- variant="fp16",
42
- ).to("cuda")
43
-
44
- pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
 
 
 
 
 
 
 
45
 
46
  def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
47
  target_size = (width, height)
@@ -237,11 +244,18 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
237
  label="Input Image"
238
  )
239
 
 
 
 
 
 
 
 
240
  with gr.Row():
241
  with gr.Column(scale=2):
242
  prompt_input = gr.Textbox(label="Prompt (Optional)")
243
  with gr.Column(scale=1):
244
- run_button = gr.Button("Generate")
245
 
246
  with gr.Row():
247
  target_ratio = gr.Radio(
@@ -351,15 +365,17 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
351
  queue=False
352
  )
353
 
354
- run_button.click(
355
  fn=clear_result,
356
  inputs=None,
357
  outputs=result,
358
  ).then(
359
  fn=infer,
360
- inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
361
- resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
362
- overlap_left, overlap_right, overlap_top, overlap_bottom],
 
 
363
  outputs=result,
364
  ).then(
365
  fn=lambda x, history: update_history(x, history),
@@ -373,9 +389,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
373
  outputs=result,
374
  ).then(
375
  fn=infer,
376
- inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
377
- resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
378
- overlap_left, overlap_right, overlap_top, overlap_bottom],
 
 
379
  outputs=result,
380
  ).then(
381
  fn=lambda x, history: update_history(x, history),
 
33
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
34
  ).to("cuda")
35
 
36
+ # --- Define available pipelines ---
37
+ model_mapping = {
38
+ "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
39
+ "RealVisXL V4.0 Lightning": "SG161222/RealVisXL_V4.0_Lightning",
40
+ }
41
+ pipelines = {}
42
+ for name, repo in model_mapping.items():
43
+ pipe = StableDiffusionXLFillPipeline.from_pretrained(
44
+ repo,
45
+ torch_dtype=torch.float16,
46
+ vae=vae,
47
+ controlnet=model,
48
+ variant="fp16",
49
+ ).to("cuda")
50
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
51
+ pipelines[name] = pipe
52
 
53
  def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
54
  target_size = (width, height)
 
244
  label="Input Image"
245
  )
246
 
247
+ with gr.Row():
248
+ model_selector = gr.Dropdown(
249
+ label="Select Model",
250
+ choices=list(pipelines.keys()),
251
+ value="RealVisXL V5.0 Lightning",
252
+ )
253
+
254
  with gr.Row():
255
  with gr.Column(scale=2):
256
  prompt_input = gr.Textbox(label="Prompt (Optional)")
257
  with gr.Column(scale=1):
258
+ run_button = gr.Button("Outpaint Image")
259
 
260
  with gr.Row():
261
  target_ratio = gr.Radio(
 
365
  queue=False
366
  )
367
 
368
+ run_button.click(
369
  fn=clear_result,
370
  inputs=None,
371
  outputs=result,
372
  ).then(
373
  fn=infer,
374
+ inputs=[input_image, width_slider, height_slider, overlap_percentage,
375
+ num_inference_steps, resize_option, custom_resize_percentage,
376
+ prompt_input, alignment_dropdown,
377
+ overlap_left, overlap_right, overlap_top, overlap_bottom,
378
+ model_selector],
379
  outputs=result,
380
  ).then(
381
  fn=lambda x, history: update_history(x, history),
 
389
  outputs=result,
390
  ).then(
391
  fn=infer,
392
+ inputs=[input_image, width_slider, height_slider, overlap_percentage,
393
+ num_inference_steps, resize_option, custom_resize_percentage,
394
+ prompt_input, alignment_dropdown,
395
+ overlap_left, overlap_right, overlap_top, overlap_bottom,
396
+ model_selector],
397
  outputs=result,
398
  ).then(
399
  fn=lambda x, history: update_history(x, history),