Pijush2023 committed (verified)
Commit 7f2efad · Parent(s): b8a572b

Update app.py

Files changed (1):
  1. app.py (+16 -40)
app.py CHANGED
@@ -688,7 +688,7 @@ def generate_map(location_names):
     map_html = m._repr_html_()
     return map_html
 
-# device = "cuda:0" if torch.cuda.is_available() else "cpu"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 # pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", torch_dtype=torch.float16)
 # pipe.to(device)
 
@@ -711,16 +711,7 @@ def generate_map(location_names):
 # image_3 = generate_image(hardcoded_prompt_3)
 # return image_1, image_2, image_3
 
-
-
-from diffusers import StableDiffusionPipeline, FluxPipeline
-
-# Function to initialize Stable Diffusion model
-def initialize_stable_diffusion():
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", torch_dtype=torch.float16)
-    pipe.to(device)
-    return pipe
+from diffusers import FluxPipeline
 
 # Function to initialize Flux bot model
 def initialize_flux_bot():
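Note: the body of initialize_flux_bot() sits largely outside this hunk, so the checkpoint it loads is not visible in the diff. For orientation only, a minimal sketch of a typical diffusers Flux setup; the model id "black-forest-labs/FLUX.1-schnell" and the bfloat16 dtype are assumptions, not taken from app.py:

import torch
from diffusers import FluxPipeline

def initialize_flux_bot():
    # Assumed checkpoint and dtype; the ones actually used in app.py are not shown in this hunk.
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()  # Saves VRAM by offloading the model to CPU
    return pipe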
@@ -728,17 +719,6 @@ def initialize_flux_bot():
     pipe.enable_model_cpu_offload()  # Saves VRAM by offloading the model to CPU
     return pipe
 
-# Function to generate image using Stable Diffusion
-def generate_image_stable_diffusion(prompt):
-    pipe = initialize_stable_diffusion()
-    with torch.cuda.amp.autocast():
-        image = pipe(
-            prompt,
-            num_inference_steps=28,
-            guidance_scale=3.0,
-        ).images[0]
-    return image
-
 # Function to generate image using Flux bot
 def generate_image_flux(prompt):
     pipe = initialize_flux_bot()
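Note: the arguments generate_image_flux() passes to pipe(...) fall between this hunk and the next, so they do not appear in the diff. A minimal sketch of a typical Flux call, where every parameter value is an illustrative assumption rather than what app.py actually uses:

import torch

def generate_image_flux(prompt):
    pipe = initialize_flux_bot()
    image = pipe(
        prompt,
        num_inference_steps=4,   # illustrative; a fast, schnell-style setting
        guidance_scale=0.0,      # illustrative
        generator=torch.Generator("cpu").manual_seed(0),  # illustrative fixed seed
    ).images[0]
    return image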
@@ -751,26 +731,22 @@ def generate_image_flux(prompt):
     ).images[0]
     return image
 
-# Combined function to handle model switching based on radio button selection
-def generate_image(prompt, model_choice):
-    if model_choice == "IG-1":
-        return generate_image_flux(prompt)
-    else:  # Default to Stable Diffusion
-        return generate_image_stable_diffusion(prompt)
-
 # Hardcoded prompts for the images
 hardcoded_prompt_1 = "A high quality cinematic image for Toyota Truck in Birmingham skyline shot in the style of Michael Mann"
 hardcoded_prompt_2 = "A high quality cinematic image for Alabama Quarterback close up emotional shot in the style of Michael Mann"
 hardcoded_prompt_3 = "A high quality cinematic image for Taylor Swift concert in Birmingham skyline style of Michael Mann"
 
-# Function to update images based on the selected model
-def update_images(model_choice):
-    image_1 = generate_image(hardcoded_prompt_1, model_choice)
-    image_2 = generate_image(hardcoded_prompt_2, model_choice)
-    image_3 = generate_image(hardcoded_prompt_3, model_choice)
+# Function to update images
+def update_images():
+    image_1 = generate_image_flux(hardcoded_prompt_1)
+    image_2 = generate_image_flux(hardcoded_prompt_2)
+    image_3 = generate_image_flux(hardcoded_prompt_3)
     return image_1, image_2, image_3
 
 
+
+
+
 def fetch_local_news():
     api_key = os.environ['SERP_API']
     url = f'https://serpapi.com/search.json?engine=google_news&q=birmingham headline&api_key={api_key}'
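Note: with this change each generate_image_flux() call runs initialize_flux_bot(), so a single refresh via update_images() loads the pipeline three times. A purely illustrative way to reuse one pipeline (get_flux_pipeline is a hypothetical helper, not part of this commit):

from functools import lru_cache

@lru_cache(maxsize=1)
def get_flux_pipeline():
    # Hypothetical wrapper: builds the Flux pipeline once and reuses it on later calls.
    return initialize_flux_bot()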
@@ -1428,8 +1404,7 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
 
     chat_input = gr.Textbox(show_copy_button=True, interactive=True, show_label=False, label="ASK Radar !!!", placeholder="Hey Radar...!!")
     tts_choice = gr.Radio(label="Select TTS System", choices=["Alpha", "Beta"], value="Alpha")
-    # Add a radio button to choose between IG-1 (Flux) and IG-2 (Stable Diffusion), defaulting to IG-1
-    model_choice = gr.Radio(label="Select Image Generation Model", choices=["IG-1", "IG-2"], value="IG-1")
+
     retriever_button = gr.Button("Retriever")
 
     clear_button = gr.Button("Clear")
@@ -1515,12 +1490,13 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
     # image_output_3 = gr.Image(value=generate_image(hardcoded_prompt_3), width=400, height=400)
 
     # Display images
-    image_output_1 = gr.Image(value=generate_image(hardcoded_prompt_1, "IG-1"), width=400, height=400)
-    image_output_2 = gr.Image(value=generate_image(hardcoded_prompt_2, "IG-1"), width=400, height=400)
-    image_output_3 = gr.Image(value=generate_image(hardcoded_prompt_3, "IG-1"), width=400, height=400)
+    image_output_1 = gr.Image(value=generate_image_flux(hardcoded_prompt_1), width=400, height=400)
+    image_output_2 = gr.Image(value=generate_image_flux(hardcoded_prompt_2), width=400, height=400)
+    image_output_3 = gr.Image(value=generate_image_flux(hardcoded_prompt_3), width=400, height=400)
 
+    # Refresh button to update images
     refresh_button = gr.Button("Refresh Images")
-    refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3], api_name="update_image")
+    refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
 
 demo.queue()
 demo.launch(show_error=True)
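Note: the refresh wiring above follows the standard Gradio Blocks pattern of binding a button to a function whose return values map onto the listed outputs. A self-contained sketch of just that pattern, using a placeholder image function in place of the Flux generator (everything below is illustrative, not taken from app.py):

import gradio as gr
from PIL import Image

def make_placeholder(prompt):
    # Stand-in for generate_image_flux(prompt): returns a blank 400x400 image.
    return Image.new("RGB", (400, 400), color="gray")

def update_images():
    return (make_placeholder("prompt 1"),
            make_placeholder("prompt 2"),
            make_placeholder("prompt 3"))

with gr.Blocks() as demo:
    image_output_1 = gr.Image(width=400, height=400)
    image_output_2 = gr.Image(width=400, height=400)
    image_output_3 = gr.Image(width=400, height=400)
    refresh_button = gr.Button("Refresh Images")
    # inputs=None because update_images() takes no arguments after this commit.
    refresh_button.click(fn=update_images, inputs=None,
                         outputs=[image_output_1, image_output_2, image_output_3])

demo.queue()
demo.launch()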
 