Pijush2023 committed on
Commit
b8a572b
·
verified ·
1 Parent(s): 8c84430

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -14
app.py CHANGED
@@ -688,11 +688,49 @@ def generate_map(location_names):
688
  map_html = m._repr_html_()
689
  return map_html
690
 
691
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
692
- pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", torch_dtype=torch.float16)
693
- pipe.to(device)
694
-
695
- def generate_image(prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696
  with torch.cuda.amp.autocast():
697
  image = pipe(
698
  prompt,
@@ -701,14 +739,35 @@ def generate_image(prompt):
701
  ).images[0]
702
  return image
703
 
704
- hardcoded_prompt_1 = "A high quality cinematic image for Toyota Truck in Birmingham skyline shot in th style of Michael Mann"
705
- hardcoded_prompt_2 = "A high quality cinematic image for Alabama Quarterback close up emotional shot in th style of Michael Mann"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
  hardcoded_prompt_3 = "A high quality cinematic image for Taylor Swift concert in Birmingham skyline style of Michael Mann"
707
 
708
- def update_images():
709
- image_1 = generate_image(hardcoded_prompt_1)
710
- image_2 = generate_image(hardcoded_prompt_2)
711
- image_3 = generate_image(hardcoded_prompt_3)
 
712
  return image_1, image_2, image_3
713
 
714
 
@@ -1369,6 +1428,8 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1369
 
1370
  chat_input = gr.Textbox(show_copy_button=True, interactive=True, show_label=False, label="ASK Radar !!!", placeholder="Hey Radar...!!")
1371
  tts_choice = gr.Radio(label="Select TTS System", choices=["Alpha", "Beta"], value="Alpha")
 
 
1372
  retriever_button = gr.Button("Retriever")
1373
 
1374
  clear_button = gr.Button("Clear")
@@ -1449,9 +1510,14 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1449
  events_output = gr.HTML(value=fetch_local_events())
1450
 
1451
  with gr.Column():
1452
- image_output_1 = gr.Image(value=generate_image(hardcoded_prompt_1), width=400, height=400)
1453
- image_output_2 = gr.Image(value=generate_image(hardcoded_prompt_2), width=400, height=400)
1454
- image_output_3 = gr.Image(value=generate_image(hardcoded_prompt_3), width=400, height=400)
 
 
 
 
 
1455
 
1456
  refresh_button = gr.Button("Refresh Images")
1457
  refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3], api_name="update_image")
 
688
  map_html = m._repr_html_()
689
  return map_html
690
 
691
+ # device = "cuda:0" if torch.cuda.is_available() else "cpu"
692
+ # pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", torch_dtype=torch.float16)
693
+ # pipe.to(device)
694
+
695
+ # def generate_image(prompt):
696
+ # with torch.cuda.amp.autocast():
697
+ # image = pipe(
698
+ # prompt,
699
+ # num_inference_steps=28,
700
+ # guidance_scale=3.0,
701
+ # ).images[0]
702
+ # return image
703
+
704
+ # hardcoded_prompt_1 = "A high quality cinematic image for Toyota Truck in Birmingham skyline shot in th style of Michael Mann"
705
+ # hardcoded_prompt_2 = "A high quality cinematic image for Alabama Quarterback close up emotional shot in th style of Michael Mann"
706
+ # hardcoded_prompt_3 = "A high quality cinematic image for Taylor Swift concert in Birmingham skyline style of Michael Mann"
707
+
708
+ # def update_images():
709
+ # image_1 = generate_image(hardcoded_prompt_1)
710
+ # image_2 = generate_image(hardcoded_prompt_2)
711
+ # image_3 = generate_image(hardcoded_prompt_3)
712
+ # return image_1, image_2, image_3
713
+
714
+
715
+
716
+ from diffusers import StableDiffusionPipeline, FluxPipeline
717
+
718
# Function to initialize Stable Diffusion model
def initialize_stable_diffusion():
    """Return a StableDiffusionPipeline on the best available device.

    The pipeline weights are multi-gigabyte; this function is called for every
    generated image, so the constructed pipeline is cached on the function
    object and reused across calls instead of being re-downloaded/re-built.
    """
    pipe = getattr(initialize_stable_diffusion, "_pipe", None)
    if pipe is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2", torch_dtype=torch.float16
        )
        pipe.to(device)
        initialize_stable_diffusion._pipe = pipe  # cache for subsequent calls
    return pipe
724
+
725
# Function to initialize Flux bot model
def initialize_flux_bot():
    """Return a FLUX.1-schnell pipeline with model CPU offload enabled.

    Cached on the function object: the original rebuilt the pipeline on every
    image generation, which re-loads the full model each time.
    """
    pipe = getattr(initialize_flux_bot, "_pipe", None)
    if pipe is None:
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload()  # Saves VRAM by offloading the model to CPU
        initialize_flux_bot._pipe = pipe  # cache for subsequent calls
    return pipe
730
+
731
+ # Function to generate image using Stable Diffusion
732
+ def generate_image_stable_diffusion(prompt):
733
+ pipe = initialize_stable_diffusion()
734
  with torch.cuda.amp.autocast():
735
  image = pipe(
736
  prompt,
 
739
  ).images[0]
740
  return image
741
 
742
# Function to generate image using Flux bot
def generate_image_flux(prompt):
    """Generate one image for *prompt* with the FLUX.1-schnell pipeline.

    Uses a fixed CPU seed (0) so repeated calls with the same prompt are
    deterministic.
    """
    flux_pipe = initialize_flux_bot()
    result = flux_pipe(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0),
    )
    return result.images[0]
753
+
754
# Combined function to handle model switching based on radio button selection
def generate_image(prompt, model_choice):
    """Dispatch image generation by model choice.

    "IG-1" selects the Flux pipeline; anything else falls back to the
    Stable Diffusion pipeline (the default).
    """
    if model_choice != "IG-1":
        # Default to Stable Diffusion for any non-Flux selection.
        return generate_image_stable_diffusion(prompt)
    return generate_image_flux(prompt)
760
+
761
# Hardcoded prompts used to populate the three gallery images.
hardcoded_prompt_1 = "A high quality cinematic image for Toyota Truck in Birmingham skyline shot in the style of Michael Mann"
hardcoded_prompt_2 = "A high quality cinematic image for Alabama Quarterback close up emotional shot in the style of Michael Mann"
hardcoded_prompt_3 = "A high quality cinematic image for Taylor Swift concert in Birmingham skyline style of Michael Mann"
765
 
766
# Function to update images based on the selected model
def update_images(model_choice):
    """Regenerate all three gallery images with the selected model."""
    prompts = (hardcoded_prompt_1, hardcoded_prompt_2, hardcoded_prompt_3)
    image_1, image_2, image_3 = (generate_image(p, model_choice) for p in prompts)
    return image_1, image_2, image_3
772
 
773
 
 
1428
 
1429
  chat_input = gr.Textbox(show_copy_button=True, interactive=True, show_label=False, label="ASK Radar !!!", placeholder="Hey Radar...!!")
1430
  tts_choice = gr.Radio(label="Select TTS System", choices=["Alpha", "Beta"], value="Alpha")
1431
# Radio button to pick the image-generation model: IG-1 (Flux) or
# IG-2 (Stable Diffusion); IG-1 is the default.
model_choice = gr.Radio(label="Select Image Generation Model", choices=["IG-1", "IG-2"], value="IG-1")
1433
  retriever_button = gr.Button("Retriever")
1434
 
1435
  clear_button = gr.Button("Clear")
 
1510
  events_output = gr.HTML(value=fetch_local_events())
1511
 
1512
  with gr.Column():
1513
+ # image_output_1 = gr.Image(value=generate_image(hardcoded_prompt_1), width=400, height=400)
1514
+ # image_output_2 = gr.Image(value=generate_image(hardcoded_prompt_2), width=400, height=400)
1515
+ # image_output_3 = gr.Image(value=generate_image(hardcoded_prompt_3), width=400, height=400)
1516
+
1517
# Display images — generated once at startup with the default IG-1 (Flux) model.
image_output_1 = gr.Image(value=generate_image(hardcoded_prompt_1, "IG-1"), width=400, height=400)
image_output_2 = gr.Image(value=generate_image(hardcoded_prompt_2, "IG-1"), width=400, height=400)
image_output_3 = gr.Image(value=generate_image(hardcoded_prompt_3, "IG-1"), width=400, height=400)
1521
 
1522
refresh_button = gr.Button("Refresh Images")
# update_images(model_choice) now takes the selected model as a required
# positional argument, so the radio component must be wired in as an input.
# With inputs=None Gradio would invoke update_images() with no arguments and
# raise a TypeError on every click.
refresh_button.click(fn=update_images, inputs=[model_choice], outputs=[image_output_1, image_output_2, image_output_3], api_name="update_image")