Anurag Bhardwaj committed
Commit 5e66759 · verified · 1 Parent(s): a351ff0

Update app.py

Files changed (1)
  1. app.py +19 -16
app.py CHANGED
@@ -6,22 +6,21 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from functools import lru_cache
 from PIL import Image
 
-from torchvision import transforms
 from transformers import CLIPImageProcessor  # Updated per deprecation warning
-import os
+
 
 
 @lru_cache(maxsize=1)
 def load_pipeline():
-    # Determine device and set torch_dtype accordingly
+    # Determine device and appropriate torch_dtype
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
 
-    # Load base model with the appropriate dtype
     base_model = "black-forest-labs/FLUX.1-dev"
     pipe = DiffusionPipeline.from_pretrained(
         base_model,
-        torch_dtype=torch_dtype
+        torch_dtype=torch_dtype,
+        low_cpu_mem_usage=True  # Reduce memory usage during load
     )
 
     # Load LoRA weights
@@ -34,12 +33,17 @@ def load_pipeline():
     )
     image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
-    # Enable GPU-only optimizations if a GPU is available
+    # Enable GPU optimizations if on GPU; else, try sequential offloading on CPU
     if device.type == "cuda":
         try:
             pipe.enable_xformers_memory_efficient_attention()
         except Exception as e:
             print("Warning: Could not enable xformers memory efficient attention:", e)
+    else:
+        try:
+            pipe.enable_sequential_cpu_offload()
+        except Exception as e:
+            print("Warning: Could not enable sequential CPU offload:", e)
 
     pipe = pipe.to(device)
     return pipe, safety_checker, image_processor
@@ -49,8 +53,8 @@ pipe, safety_checker, image_processor = load_pipeline()
 def generate_image(
     prompt,
     seed=42,
-    width=1024,
-    height=1024,
+    width=512,   # Lowered default resolution
+    height=512,  # Lowered default resolution
     guidance_scale=6,
     steps=28,
     progress=gr.Progress()
@@ -60,11 +64,10 @@ def generate_image(
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     generator = torch.Generator(device=device).manual_seed(seed)
 
-    # Ensure the trigger word is present
+    # Auto-add the trigger word if not already present
    if "super realism" not in prompt.lower():
         prompt = f"Super Realism, {prompt}"
 
-    # Define a callback to update progress
     def update_progress(step, timestep, latents):
         progress((step + 1) / steps, desc="Generating image...")
 
@@ -81,13 +84,12 @@ def generate_image(
     image = result.images[0]
 
     progress(1, desc="Safety checking...")
-    # Preprocess the image for safety checking
     safety_input = image_processor(image, return_tensors="pt")
     np_image = np.array(image)
 
-    # Run the safety checker
+    # Run safety checker; it returns a tuple where the second element is nsfw flags
     _, nsfw_detected = safety_checker(
-        images=[np_image],
+        images=[np_image],
         clip_input=safety_input.pixel_values
     )
 
@@ -106,8 +108,9 @@ with gr.Blocks() as app:
         with gr.Column():
             prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
             seed_input = gr.Slider(0, 1000, value=42, label="Seed")
-            width_input = gr.Slider(512, 2048, value=1024, label="Width")
-            height_input = gr.Slider(512, 2048, value=1024, label="Height")
+            # Lower the resolution slider range for less memory-intensive generation
+            width_input = gr.Slider(256, 1024, value=512, label="Width")
+            height_input = gr.Slider(256, 1024, value=512, label="Height")
             guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
             steps_input = gr.Slider(10, 100, value=28, label="Steps")
             submit = gr.Button("Generate")
@@ -122,5 +125,5 @@ with gr.Blocks() as app:
         outputs=[output_image, status]
     )
 
-# Queue without GPU-specific arguments
+# Use queue without GPU-specific parameters
 app.queue(max_size=3).launch()
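Notes on the changes above follow. For readers who want to reuse the memory strategy this commit adopts, here is a minimal, self-contained sketch of the loading path. It assumes the public black-forest-labs/FLUX.1-dev checkpoint, a diffusers build that supports low_cpu_mem_usage, and accelerate for offloading; the load_flux helper name is illustrative, not part of the app. Unlike the diff, it skips pipe.to(device) once sequential offload is active, since diffusers then manages device placement itself.

import torch
from diffusers import DiffusionPipeline

def load_flux(base_model: str = "black-forest-labs/FLUX.1-dev"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # bfloat16 halves weight memory on GPU; float32 stays the safe CPU default.
    torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
    pipe = DiffusionPipeline.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,  # load into empty-initialized weights to cap peak RAM
    )
    if device.type == "cuda":
        try:
            # Fused attention kernels reduce activation memory at high resolutions.
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print("Warning: could not enable xformers attention:", e)
        pipe = pipe.to(device)
    else:
        try:
            # Keeps only the active submodule resident; needs accelerate, and it
            # may legitimately fail on pure-CPU hosts with no device to offload to.
            pipe.enable_sequential_cpu_offload()
        except Exception as e:
            print("Warning: could not enable sequential CPU offload:", e)
    return pipe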
 
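The safety-check hunk shows how safety_checker is called but not where it is created (that line is outside the diff context). A self-contained sketch of the same check follows; the CompVis/stable-diffusion-safety-checker checkpoint is an assumption, since the commit never names it, while the CLIPImageProcessor checkpoint and the call signature are taken from the diff.

import numpy as np
from PIL import Image
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPImageProcessor

safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"  # assumed checkpoint, not shown in the diff
)
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

def is_nsfw(image: Image.Image) -> bool:
    safety_input = image_processor(image, return_tensors="pt")
    np_image = np.array(image)
    # forward() returns (images with flagged entries blacked out, per-image NSFW flags).
    _, nsfw_detected = safety_checker(
        images=[np_image],
        clip_input=safety_input.pixel_values,
    )
    return bool(nsfw_detected[0])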