Pijush2023 commited on
Commit
376f0d5
·
verified ·
1 Parent(s): 7f2efad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -688,7 +688,7 @@ def generate_map(location_names):
688
  map_html = m._repr_html_()
689
  return map_html
690
 
691
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
692
  # pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", torch_dtype=torch.float16)
693
  # pipe.to(device)
694
 
@@ -711,15 +711,20 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
711
  # image_3 = generate_image(hardcoded_prompt_3)
712
  # return image_1, image_2, image_3
713
 
 
 
714
  from diffusers import FluxPipeline
715
 
 
 
 
716
  # Function to initialize Flux bot model
717
  def initialize_flux_bot():
718
  pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
719
- pipe.enable_model_cpu_offload() # Saves VRAM by offloading the model to CPU
720
  return pipe
721
 
722
- # Function to generate image using Flux bot
723
  def generate_image_flux(prompt):
724
  pipe = initialize_flux_bot()
725
  image = pipe(
@@ -727,7 +732,7 @@ def generate_image_flux(prompt):
727
  guidance_scale=0.0,
728
  num_inference_steps=4,
729
  max_sequence_length=256,
730
- generator=torch.Generator("cpu").manual_seed(0)
731
  ).images[0]
732
  return image
733
 
 
688
  map_html = m._repr_html_()
689
  return map_html
690
 
691
+ # device = "cuda:0" if torch.cuda.is_available() else "cpu"
692
  # pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", torch_dtype=torch.float16)
693
  # pipe.to(device)
694
 
 
711
  # image_3 = generate_image(hardcoded_prompt_3)
712
  # return image_1, image_2, image_3
713
 
714
+ import gradio as gr
715
+ import torch
716
  from diffusers import FluxPipeline
717
 
718
+ # Check if CUDA (GPU) is available, otherwise fallback to CPU
719
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
720
+
721
  # Function to initialize Flux bot model
722
  def initialize_flux_bot():
723
  pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
724
+ pipe.to(device) # Move the model to the correct device (GPU/CPU)
725
  return pipe
726
 
727
+ # Function to generate image using Flux bot on the specified device
728
  def generate_image_flux(prompt):
729
  pipe = initialize_flux_bot()
730
  image = pipe(
 
732
  guidance_scale=0.0,
733
  num_inference_steps=4,
734
  max_sequence_length=256,
735
+ generator=torch.Generator(device).manual_seed(0)
736
  ).images[0]
737
  return image
738