Anurag Bhardwaj commited on
Commit
a351ff0
·
verified ·
1 Parent(s): 188c627

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -7,16 +7,21 @@ from functools import lru_cache
7
  from PIL import Image
8
 
9
  from torchvision import transforms
10
- from transformers import CLIPImageProcessor # Updated import
 
11
 
12
 
13
  @lru_cache(maxsize=1)
14
  def load_pipeline():
15
- # Load base model
 
 
 
 
16
  base_model = "black-forest-labs/FLUX.1-dev"
17
  pipe = DiffusionPipeline.from_pretrained(
18
  base_model,
19
- torch_dtype=torch.bfloat16
20
  )
21
 
22
  # Load LoRA weights
@@ -29,12 +34,14 @@ def load_pipeline():
29
  )
30
  image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
31
 
32
- # Optimizations: enable memory efficient attention if using GPU
33
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
  if device.type == "cuda":
35
- pipe.enable_xformers_memory_efficient_attention()
36
- pipe = pipe.to(device)
 
 
37
 
 
38
  return pipe, safety_checker, image_processor
39
 
40
  pipe, safety_checker, image_processor = load_pipeline()
@@ -53,11 +60,11 @@ def generate_image(
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
  generator = torch.Generator(device=device).manual_seed(seed)
55
 
56
- # Auto-add trigger words if not present
57
  if "super realism" not in prompt.lower():
58
  prompt = f"Super Realism, {prompt}"
59
 
60
- # Define the callback function with the proper signature
61
  def update_progress(step, timestep, latents):
62
  progress((step + 1) / steps, desc="Generating image...")
63
 
@@ -74,11 +81,11 @@ def generate_image(
74
  image = result.images[0]
75
 
76
  progress(1, desc="Safety checking...")
77
- # Preprocess image for safety checking using the updated image processor
78
  safety_input = image_processor(image, return_tensors="pt")
79
  np_image = np.array(image)
80
 
81
- # Unpack safety checker results
82
  _, nsfw_detected = safety_checker(
83
  images=[np_image],
84
  clip_input=safety_input.pixel_values
@@ -115,9 +122,5 @@ with gr.Blocks() as app:
115
  outputs=[output_image, status]
116
  )
117
 
118
- # Rate limiting: 1 request at a time, with a max queue size of 3
119
  app.queue(max_size=3).launch()
120
-
121
- # Uncomment for advanced multiple GPU support:
122
- # pipe.enable_model_cpu_offload()
123
- # pipe.enable_sequential_cpu_offload()
 
7
  from PIL import Image
8
 
9
  from torchvision import transforms
10
+ from transformers import CLIPImageProcessor # Updated per deprecation warning
11
+ import os
12
 
13
 
14
  @lru_cache(maxsize=1)
15
  def load_pipeline():
16
+ # Determine device and set torch_dtype accordingly
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
19
+
20
+ # Load base model with the appropriate dtype
21
  base_model = "black-forest-labs/FLUX.1-dev"
22
  pipe = DiffusionPipeline.from_pretrained(
23
  base_model,
24
+ torch_dtype=torch_dtype
25
  )
26
 
27
  # Load LoRA weights
 
34
  )
35
  image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
36
 
37
+ # Enable GPU-only optimizations if a GPU is available
 
38
  if device.type == "cuda":
39
+ try:
40
+ pipe.enable_xformers_memory_efficient_attention()
41
+ except Exception as e:
42
+ print("Warning: Could not enable xformers memory efficient attention:", e)
43
 
44
+ pipe = pipe.to(device)
45
  return pipe, safety_checker, image_processor
46
 
47
  pipe, safety_checker, image_processor = load_pipeline()
 
60
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
  generator = torch.Generator(device=device).manual_seed(seed)
62
 
63
+ # Ensure the trigger word is present
64
  if "super realism" not in prompt.lower():
65
  prompt = f"Super Realism, {prompt}"
66
 
67
+ # Define a callback to update progress
68
  def update_progress(step, timestep, latents):
69
  progress((step + 1) / steps, desc="Generating image...")
70
 
 
81
  image = result.images[0]
82
 
83
  progress(1, desc="Safety checking...")
84
+ # Preprocess the image for safety checking
85
  safety_input = image_processor(image, return_tensors="pt")
86
  np_image = np.array(image)
87
 
88
+ # Run the safety checker
89
  _, nsfw_detected = safety_checker(
90
  images=[np_image],
91
  clip_input=safety_input.pixel_values
 
122
  outputs=[output_image, status]
123
  )
124
 
125
+ # Queue without GPU-specific arguments
126
  app.queue(max_size=3).launch()