Anurag Bhardwaj commited on
Commit
796e120
·
verified ·
1 Parent(s): b46eefe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -5,11 +5,12 @@ from diffusers import DiffusionPipeline
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from functools import lru_cache
7
  from PIL import Image
8
-
9
  from torchvision import transforms
10
- from transformers import CLIPFeatureExtractor # Added missing import
11
-
12
 
 
 
13
 
14
  @lru_cache(maxsize=1)
15
  def load_pipeline():
@@ -24,13 +25,11 @@ def load_pipeline():
24
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
25
  pipe.load_lora_weights(lora_repo)
26
 
27
- # Load safety checker and feature extractor
28
  safety_checker = StableDiffusionSafetyChecker.from_pretrained(
29
  "CompVis/stable-diffusion-safety-checker"
30
  )
31
- feature_extractor = CLIPFeatureExtractor.from_pretrained(
32
- "openai/clip-vit-base-patch32"
33
- )
34
 
35
  # Optimizations: enable memory efficient attention if using GPU
36
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -38,9 +37,9 @@ def load_pipeline():
38
  pipe.enable_xformers_memory_efficient_attention()
39
  pipe = pipe.to(device)
40
 
41
- return pipe, safety_checker, feature_extractor
42
 
43
- pipe, safety_checker, feature_extractor = load_pipeline()
44
 
45
  def generate_image(
46
  prompt,
@@ -77,11 +76,11 @@ def generate_image(
77
  image = result.images[0]
78
 
79
  progress(1, desc="Safety checking...")
80
- # Preprocess image for safety checking
81
- safety_input = feature_extractor(image, return_tensors="pt")
82
  np_image = np.array(image)
83
 
84
- # Unpack safety checker results (the safety checker returns a tuple)
85
  _, nsfw_detected = safety_checker(
86
  images=[np_image],
87
  clip_input=safety_input.pixel_values
@@ -120,6 +119,7 @@ with gr.Blocks() as app:
120
 
121
  # Rate limiting: 1 request at a time, with a max queue size of 3
122
  app.queue(max_size=3).launch()
123
- # Uncomment the lines below for advanced multiple GPU support
124
- pipe.enable_model_cpu_offload()
125
- pipe.enable_sequential_cpu_offload()
 
 
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from functools import lru_cache
7
  from PIL import Image
8
+ from huggingface_hub import login
9
  from torchvision import transforms
10
+ from transformers import CLIPImageProcessor # Updated import
 
11
 
12
+ # Initialize with your Hugging Face token
13
+ login(token="YOUR_HF_TOKEN")
14
 
15
  @lru_cache(maxsize=1)
16
  def load_pipeline():
 
25
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
26
  pipe.load_lora_weights(lora_repo)
27
 
28
+ # Load safety checker and image processor
29
  safety_checker = StableDiffusionSafetyChecker.from_pretrained(
30
  "CompVis/stable-diffusion-safety-checker"
31
  )
32
+ image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 
33
 
34
  # Optimizations: enable memory efficient attention if using GPU
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
37
  pipe.enable_xformers_memory_efficient_attention()
38
  pipe = pipe.to(device)
39
 
40
+ return pipe, safety_checker, image_processor
41
 
42
+ pipe, safety_checker, image_processor = load_pipeline()
43
 
44
  def generate_image(
45
  prompt,
 
76
  image = result.images[0]
77
 
78
  progress(1, desc="Safety checking...")
79
+ # Preprocess image for safety checking using the updated image processor
80
+ safety_input = image_processor(image, return_tensors="pt")
81
  np_image = np.array(image)
82
 
83
+ # Unpack safety checker results
84
  _, nsfw_detected = safety_checker(
85
  images=[np_image],
86
  clip_input=safety_input.pixel_values
 
119
 
120
  # Rate limiting: 1 request at a time, with a max queue size of 3
121
  app.queue(max_size=3).launch()
122
+
123
+ # Uncomment for advanced multiple GPU support:
124
+ # pipe.enable_model_cpu_offload()
125
+ # pipe.enable_sequential_cpu_offload()