Anurag Bhardwaj commited on
Commit
b994d40
·
verified ·
1 Parent(s): ac87317

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -35
app.py CHANGED
@@ -5,23 +5,19 @@ from diffusers import DiffusionPipeline
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from functools import lru_cache
7
  from PIL import Image
8
- from huggingface_hub import login
9
- from transformers import CLIPImageProcessor # Updated per deprecation warning
10
 
11
- # Initialize with your Hugging Face token
12
- login(token="YOUR_HF_TOKEN")
 
 
13
 
14
  @lru_cache(maxsize=1)
15
  def load_pipeline():
16
- # Determine device and appropriate torch_dtype
17
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
- torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
19
-
20
  base_model = "black-forest-labs/FLUX.1-dev"
21
  pipe = DiffusionPipeline.from_pretrained(
22
  base_model,
23
- torch_dtype=torch_dtype,
24
- low_cpu_mem_usage=True # Reduce memory usage during load
25
  )
26
 
27
  # Load LoRA weights
@@ -34,19 +30,12 @@ def load_pipeline():
34
  )
35
  image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
36
 
37
- # Enable GPU optimizations if on GPU; else, try sequential offloading on CPU
 
38
  if device.type == "cuda":
39
- try:
40
- pipe.enable_xformers_memory_efficient_attention()
41
- except Exception as e:
42
- print("Warning: Could not enable xformers memory efficient attention:", e)
43
- else:
44
- try:
45
- pipe.enable_sequential_cpu_offload()
46
- except Exception as e:
47
- print("Warning: Could not enable sequential CPU offload:", e)
48
-
49
  pipe = pipe.to(device)
 
50
  return pipe, safety_checker, image_processor
51
 
52
  pipe, safety_checker, image_processor = load_pipeline()
@@ -54,8 +43,8 @@ pipe, safety_checker, image_processor = load_pipeline()
54
  def generate_image(
55
  prompt,
56
  seed=42,
57
- width=512, # Lowered default resolution for reduced memory usage
58
- height=512, # Lowered default resolution for reduced memory usage
59
  guidance_scale=6,
60
  steps=28,
61
  progress=gr.Progress()
@@ -65,11 +54,14 @@ def generate_image(
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  generator = torch.Generator(device=device).manual_seed(seed)
67
 
68
- # Auto-add the trigger word if not already present
69
  if "super realism" not in prompt.lower():
70
  prompt = f"Super Realism, {prompt}"
71
 
72
- # Generation without a callback since it's not supported
 
 
 
73
  with torch.inference_mode():
74
  result = pipe(
75
  prompt=prompt,
@@ -77,25 +69,25 @@ def generate_image(
77
  height=height,
78
  guidance_scale=guidance_scale,
79
  num_inference_steps=steps,
80
- generator=generator
 
81
  )
82
  image = result.images[0]
83
 
84
- progress(0.8, desc="Generation complete. Running safety check...")
85
- # Preprocess the image for safety checking
86
  safety_input = image_processor(image, return_tensors="pt")
87
  np_image = np.array(image)
88
 
89
- # Run the safety checker; it returns a tuple where the second element is nsfw flags
90
  _, nsfw_detected = safety_checker(
91
- images=[np_image],
92
  clip_input=safety_input.pixel_values
93
  )
94
 
95
  if nsfw_detected[0]:
96
  return Image.new("RGB", (512, 512)), "NSFW content detected"
97
 
98
- progress(1, desc="Generation successful")
99
  return image, "Generation successful"
100
 
101
  except Exception as e:
@@ -108,9 +100,8 @@ with gr.Blocks() as app:
108
  with gr.Column():
109
  prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
110
  seed_input = gr.Slider(0, 1000, value=42, label="Seed")
111
- # Lower the resolution slider range for less memory-intensive generation
112
- width_input = gr.Slider(256, 1024, value=512, label="Width")
113
- height_input = gr.Slider(256, 1024, value=512, label="Height")
114
  guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
115
  steps_input = gr.Slider(10, 100, value=28, label="Steps")
116
  submit = gr.Button("Generate")
@@ -125,5 +116,9 @@ with gr.Blocks() as app:
125
  outputs=[output_image, status]
126
  )
127
 
128
- # Use queue without GPU-specific parameters
129
  app.queue(max_size=3).launch()
 
 
 
 
 
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from functools import lru_cache
7
  from PIL import Image
 
 
8
 
9
+ from torchvision import transforms
10
+ from transformers import CLIPImageProcessor # Updated import
11
+
12
+
13
 
14
  @lru_cache(maxsize=1)
15
  def load_pipeline():
16
+ # Load base model
 
 
 
17
  base_model = "black-forest-labs/FLUX.1-dev"
18
  pipe = DiffusionPipeline.from_pretrained(
19
  base_model,
20
+ torch_dtype=torch.bfloat16
 
21
  )
22
 
23
  # Load LoRA weights
 
30
  )
31
  image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
32
 
33
+ # Optimizations: enable memory efficient attention if using GPU
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
  if device.type == "cuda":
36
+ pipe.enable_xformers_memory_efficient_attention()
 
 
 
 
 
 
 
 
 
37
  pipe = pipe.to(device)
38
+
39
  return pipe, safety_checker, image_processor
40
 
41
  pipe, safety_checker, image_processor = load_pipeline()
 
43
  def generate_image(
44
  prompt,
45
  seed=42,
46
+ width=1024,
47
+ height=1024,
48
  guidance_scale=6,
49
  steps=28,
50
  progress=gr.Progress()
 
54
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
  generator = torch.Generator(device=device).manual_seed(seed)
56
 
57
+ # Auto-add trigger words if not present
58
  if "super realism" not in prompt.lower():
59
  prompt = f"Super Realism, {prompt}"
60
 
61
+ # Define the callback function with the proper signature
62
+ def update_progress(step, timestep, latents):
63
+ progress((step + 1) / steps, desc="Generating image...")
64
+
65
  with torch.inference_mode():
66
  result = pipe(
67
  prompt=prompt,
 
69
  height=height,
70
  guidance_scale=guidance_scale,
71
  num_inference_steps=steps,
72
+ generator=generator,
73
+
74
  )
75
  image = result.images[0]
76
 
77
+ progress(1, desc="Safety checking...")
78
+ # Preprocess image for safety checking using the updated image processor
79
  safety_input = image_processor(image, return_tensors="pt")
80
  np_image = np.array(image)
81
 
82
+ # Unpack safety checker results
83
  _, nsfw_detected = safety_checker(
84
+ images=[np_image],
85
  clip_input=safety_input.pixel_values
86
  )
87
 
88
  if nsfw_detected[0]:
89
  return Image.new("RGB", (512, 512)), "NSFW content detected"
90
 
 
91
  return image, "Generation successful"
92
 
93
  except Exception as e:
 
100
  with gr.Column():
101
  prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
102
  seed_input = gr.Slider(0, 1000, value=42, label="Seed")
103
+ width_input = gr.Slider(512, 2048, value=1024, label="Width")
104
+ height_input = gr.Slider(512, 2048, value=1024, label="Height")
 
105
  guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
106
  steps_input = gr.Slider(10, 100, value=28, label="Steps")
107
  submit = gr.Button("Generate")
 
116
  outputs=[output_image, status]
117
  )
118
 
119
+ # Rate limiting: 1 request at a time, with a max queue size of 3
120
  app.queue(max_size=3).launch()
121
+
122
+ # Uncomment for advanced multiple GPU support:
123
+ # pipe.enable_model_cpu_offload()
124
+ # pipe.enable_sequential_cpu_offload()