Anurag Bhardwaj commited on
Commit
ac87317
·
verified ·
1 Parent(s): 674cc6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -28
app.py CHANGED
@@ -5,17 +5,23 @@ from diffusers import DiffusionPipeline
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from functools import lru_cache
7
  from PIL import Image
 
 
8
 
9
- from torchvision import transforms
10
- from transformers import CLIPImageProcessor # Updated import
11
 
12
  @lru_cache(maxsize=1)
13
  def load_pipeline():
14
- # Load base model
 
 
 
15
  base_model = "black-forest-labs/FLUX.1-dev"
16
  pipe = DiffusionPipeline.from_pretrained(
17
  base_model,
18
- torch_dtype=torch.bfloat16
 
19
  )
20
 
21
  # Load LoRA weights
@@ -28,12 +34,19 @@ def load_pipeline():
28
  )
29
  image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
30
 
31
- # Optimizations: enable memory efficient attention if using GPU
32
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
  if device.type == "cuda":
34
- pipe.enable_xformers_memory_efficient_attention()
35
- pipe = pipe.to(device)
 
 
 
 
 
 
 
36
 
 
37
  return pipe, safety_checker, image_processor
38
 
39
  pipe, safety_checker, image_processor = load_pipeline()
@@ -41,8 +54,8 @@ pipe, safety_checker, image_processor = load_pipeline()
41
  def generate_image(
42
  prompt,
43
  seed=42,
44
- width=1024,
45
- height=1024,
46
  guidance_scale=6,
47
  steps=28,
48
  progress=gr.Progress()
@@ -52,14 +65,11 @@ def generate_image(
52
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
  generator = torch.Generator(device=device).manual_seed(seed)
54
 
55
- # Auto-add trigger words if not present
56
  if "super realism" not in prompt.lower():
57
  prompt = f"Super Realism, {prompt}"
58
 
59
- # Define the callback function with the proper signature
60
- def update_progress(step, timestep, latents):
61
- progress((step + 1) / steps, desc="Generating image...")
62
-
63
  with torch.inference_mode():
64
  result = pipe(
65
  prompt=prompt,
@@ -67,25 +77,25 @@ def generate_image(
67
  height=height,
68
  guidance_scale=guidance_scale,
69
  num_inference_steps=steps,
70
- generator=generator,
71
- callback=update_progress
72
  )
73
  image = result.images[0]
74
 
75
- progress(1, desc="Safety checking...")
76
- # Preprocess image for safety checking using the updated image processor
77
  safety_input = image_processor(image, return_tensors="pt")
78
  np_image = np.array(image)
79
 
80
- # Unpack safety checker results
81
  _, nsfw_detected = safety_checker(
82
- images=[np_image],
83
  clip_input=safety_input.pixel_values
84
  )
85
 
86
  if nsfw_detected[0]:
87
  return Image.new("RGB", (512, 512)), "NSFW content detected"
88
 
 
89
  return image, "Generation successful"
90
 
91
  except Exception as e:
@@ -98,8 +108,9 @@ with gr.Blocks() as app:
98
  with gr.Column():
99
  prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
100
  seed_input = gr.Slider(0, 1000, value=42, label="Seed")
101
- width_input = gr.Slider(512, 2048, value=1024, label="Width")
102
- height_input = gr.Slider(512, 2048, value=1024, label="Height")
 
103
  guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
104
  steps_input = gr.Slider(10, 100, value=28, label="Steps")
105
  submit = gr.Button("Generate")
@@ -114,9 +125,5 @@ with gr.Blocks() as app:
114
  outputs=[output_image, status]
115
  )
116
 
117
- # Rate limiting: 1 request at a time, with a max queue size of 3
118
  app.queue(max_size=3).launch()
119
-
120
- # Uncomment for advanced multiple GPU support:
121
- # pipe.enable_model_cpu_offload()
122
- # pipe.enable_sequential_cpu_offload()
 
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from functools import lru_cache
7
  from PIL import Image
8
+ from huggingface_hub import login
9
+ from transformers import CLIPImageProcessor # Updated per deprecation warning
10
 
11
+ # Initialize with your Hugging Face token
12
+ login(token="YOUR_HF_TOKEN")
13
 
14
  @lru_cache(maxsize=1)
15
  def load_pipeline():
16
+ # Determine device and appropriate torch_dtype
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
19
+
20
  base_model = "black-forest-labs/FLUX.1-dev"
21
  pipe = DiffusionPipeline.from_pretrained(
22
  base_model,
23
+ torch_dtype=torch_dtype,
24
+ low_cpu_mem_usage=True # Reduce memory usage during load
25
  )
26
 
27
  # Load LoRA weights
 
34
  )
35
  image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
36
 
37
+ # Enable GPU optimizations if on GPU; else, try sequential offloading on CPU
 
38
  if device.type == "cuda":
39
+ try:
40
+ pipe.enable_xformers_memory_efficient_attention()
41
+ except Exception as e:
42
+ print("Warning: Could not enable xformers memory efficient attention:", e)
43
+ else:
44
+ try:
45
+ pipe.enable_sequential_cpu_offload()
46
+ except Exception as e:
47
+ print("Warning: Could not enable sequential CPU offload:", e)
48
 
49
+ pipe = pipe.to(device)
50
  return pipe, safety_checker, image_processor
51
 
52
  pipe, safety_checker, image_processor = load_pipeline()
 
54
  def generate_image(
55
  prompt,
56
  seed=42,
57
+ width=512, # Lowered default resolution for reduced memory usage
58
+ height=512, # Lowered default resolution for reduced memory usage
59
  guidance_scale=6,
60
  steps=28,
61
  progress=gr.Progress()
 
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  generator = torch.Generator(device=device).manual_seed(seed)
67
 
68
+ # Auto-add the trigger word if not already present
69
  if "super realism" not in prompt.lower():
70
  prompt = f"Super Realism, {prompt}"
71
 
72
+ # Generation without a callback since it's not supported
 
 
 
73
  with torch.inference_mode():
74
  result = pipe(
75
  prompt=prompt,
 
77
  height=height,
78
  guidance_scale=guidance_scale,
79
  num_inference_steps=steps,
80
+ generator=generator
 
81
  )
82
  image = result.images[0]
83
 
84
+ progress(0.8, desc="Generation complete. Running safety check...")
85
+ # Preprocess the image for safety checking
86
  safety_input = image_processor(image, return_tensors="pt")
87
  np_image = np.array(image)
88
 
89
+ # Run the safety checker; it returns a tuple where the second element is nsfw flags
90
  _, nsfw_detected = safety_checker(
91
+ images=[np_image],
92
  clip_input=safety_input.pixel_values
93
  )
94
 
95
  if nsfw_detected[0]:
96
  return Image.new("RGB", (512, 512)), "NSFW content detected"
97
 
98
+ progress(1, desc="Generation successful")
99
  return image, "Generation successful"
100
 
101
  except Exception as e:
 
108
  with gr.Column():
109
  prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
110
  seed_input = gr.Slider(0, 1000, value=42, label="Seed")
111
+ # Lower the resolution slider range for less memory-intensive generation
112
+ width_input = gr.Slider(256, 1024, value=512, label="Width")
113
+ height_input = gr.Slider(256, 1024, value=512, label="Height")
114
  guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
115
  steps_input = gr.Slider(10, 100, value=28, label="Steps")
116
  submit = gr.Button("Generate")
 
125
  outputs=[output_image, status]
126
  )
127
 
128
+ # Use queue without GPU-specific parameters
129
  app.queue(max_size=3).launch()