Anurag Bhardwaj commited on
Commit
674cc6d
·
verified ·
1 Parent(s): 5e66759

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -29
app.py CHANGED
@@ -6,21 +6,16 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from functools import lru_cache
7
  from PIL import Image
8
 
9
- from transformers import CLIPImageProcessor # Updated per deprecation warning
10
-
11
-
12
 
13
  @lru_cache(maxsize=1)
14
  def load_pipeline():
15
- # Determine device and appropriate torch_dtype
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
18
-
19
  base_model = "black-forest-labs/FLUX.1-dev"
20
  pipe = DiffusionPipeline.from_pretrained(
21
  base_model,
22
- torch_dtype=torch_dtype,
23
- low_cpu_mem_usage=True # Reduce memory usage during load
24
  )
25
 
26
  # Load LoRA weights
@@ -33,19 +28,12 @@ def load_pipeline():
33
  )
34
  image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
35
 
36
- # Enable GPU optimizations if on GPU; else, try sequential offloading on CPU
 
37
  if device.type == "cuda":
38
- try:
39
- pipe.enable_xformers_memory_efficient_attention()
40
- except Exception as e:
41
- print("Warning: Could not enable xformers memory efficient attention:", e)
42
- else:
43
- try:
44
- pipe.enable_sequential_cpu_offload()
45
- except Exception as e:
46
- print("Warning: Could not enable sequential CPU offload:", e)
47
-
48
  pipe = pipe.to(device)
 
49
  return pipe, safety_checker, image_processor
50
 
51
  pipe, safety_checker, image_processor = load_pipeline()
@@ -53,8 +41,8 @@ pipe, safety_checker, image_processor = load_pipeline()
53
  def generate_image(
54
  prompt,
55
  seed=42,
56
- width=512, # Lowered default resolution
57
- height=512, # Lowered default resolution
58
  guidance_scale=6,
59
  steps=28,
60
  progress=gr.Progress()
@@ -64,10 +52,11 @@ def generate_image(
64
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
  generator = torch.Generator(device=device).manual_seed(seed)
66
 
67
- # Auto-add the trigger word if not already present
68
  if "super realism" not in prompt.lower():
69
  prompt = f"Super Realism, {prompt}"
70
 
 
71
  def update_progress(step, timestep, latents):
72
  progress((step + 1) / steps, desc="Generating image...")
73
 
@@ -84,12 +73,13 @@ def generate_image(
84
  image = result.images[0]
85
 
86
  progress(1, desc="Safety checking...")
 
87
  safety_input = image_processor(image, return_tensors="pt")
88
  np_image = np.array(image)
89
 
90
- # Run safety checker; it returns a tuple where the second element is nsfw flags
91
  _, nsfw_detected = safety_checker(
92
- images=[np_image],
93
  clip_input=safety_input.pixel_values
94
  )
95
 
@@ -108,9 +98,8 @@ with gr.Blocks() as app:
108
  with gr.Column():
109
  prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
110
  seed_input = gr.Slider(0, 1000, value=42, label="Seed")
111
- # Lower the resolution slider range for less memory-intensive generation
112
- width_input = gr.Slider(256, 1024, value=512, label="Width")
113
- height_input = gr.Slider(256, 1024, value=512, label="Height")
114
  guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
115
  steps_input = gr.Slider(10, 100, value=28, label="Steps")
116
  submit = gr.Button("Generate")
@@ -125,5 +114,9 @@ with gr.Blocks() as app:
125
  outputs=[output_image, status]
126
  )
127
 
128
- # Use queue without GPU-specific parameters
129
  app.queue(max_size=3).launch()
 
 
 
 
 
6
  from functools import lru_cache
7
  from PIL import Image
8
 
9
+ from torchvision import transforms
10
+ from transformers import CLIPImageProcessor # Updated import
 
11
 
12
  @lru_cache(maxsize=1)
13
  def load_pipeline():
14
+ # Load base model
 
 
 
15
  base_model = "black-forest-labs/FLUX.1-dev"
16
  pipe = DiffusionPipeline.from_pretrained(
17
  base_model,
18
+ torch_dtype=torch.bfloat16
 
19
  )
20
 
21
  # Load LoRA weights
 
28
  )
29
  image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
30
 
31
+ # Optimizations: enable memory efficient attention if using GPU
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
  if device.type == "cuda":
34
+ pipe.enable_xformers_memory_efficient_attention()
 
 
 
 
 
 
 
 
 
35
  pipe = pipe.to(device)
36
+
37
  return pipe, safety_checker, image_processor
38
 
39
  pipe, safety_checker, image_processor = load_pipeline()
 
41
  def generate_image(
42
  prompt,
43
  seed=42,
44
+ width=1024,
45
+ height=1024,
46
  guidance_scale=6,
47
  steps=28,
48
  progress=gr.Progress()
 
52
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
  generator = torch.Generator(device=device).manual_seed(seed)
54
 
55
+ # Auto-add trigger words if not present
56
  if "super realism" not in prompt.lower():
57
  prompt = f"Super Realism, {prompt}"
58
 
59
+ # Define the callback function with the proper signature
60
  def update_progress(step, timestep, latents):
61
  progress((step + 1) / steps, desc="Generating image...")
62
 
 
73
  image = result.images[0]
74
 
75
  progress(1, desc="Safety checking...")
76
+ # Preprocess image for safety checking using the updated image processor
77
  safety_input = image_processor(image, return_tensors="pt")
78
  np_image = np.array(image)
79
 
80
+ # Unpack safety checker results
81
  _, nsfw_detected = safety_checker(
82
+ images=[np_image],
83
  clip_input=safety_input.pixel_values
84
  )
85
 
 
98
  with gr.Column():
99
  prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
100
  seed_input = gr.Slider(0, 1000, value=42, label="Seed")
101
+ width_input = gr.Slider(512, 2048, value=1024, label="Width")
102
+ height_input = gr.Slider(512, 2048, value=1024, label="Height")
 
103
  guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
104
  steps_input = gr.Slider(10, 100, value=28, label="Steps")
105
  submit = gr.Button("Generate")
 
114
  outputs=[output_image, status]
115
  )
116
 
117
+ # Rate limiting: 1 request at a time, with a max queue size of 3
118
  app.queue(max_size=3).launch()
119
+
120
+ # Uncomment for advanced multiple GPU support:
121
+ # pipe.enable_model_cpu_offload()
122
+ # pipe.enable_sequential_cpu_offload()