Anurag Bhardwaj committed on
Commit
4570550
·
verified ·
1 Parent(s): fb72caf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -27
app.py CHANGED
@@ -5,10 +5,12 @@ from diffusers import DiffusionPipeline
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from functools import lru_cache
7
  from PIL import Image
 
8
  from torchvision import transforms
 
 
9
 
10
 
11
- # Cache pipeline loading to improve performance
12
  @lru_cache(maxsize=1)
13
  def load_pipeline():
14
  # Load base model
@@ -22,7 +24,7 @@ def load_pipeline():
22
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
23
  pipe.load_lora_weights(lora_repo)
24
 
25
- # Load safety checker
26
  safety_checker = StableDiffusionSafetyChecker.from_pretrained(
27
  "CompVis/stable-diffusion-safety-checker"
28
  )
@@ -30,9 +32,11 @@ def load_pipeline():
30
  "openai/clip-vit-base-patch32"
31
  )
32
 
33
- # Optimizations
34
- pipe.enable_xformers_memory_efficient_attention()
35
- pipe = pipe.to("cuda")
 
 
36
 
37
  return pipe, safety_checker, feature_extractor
38
 
@@ -49,19 +53,19 @@ def generate_image(
49
  ):
50
  try:
51
  progress(0, desc="Initializing...")
52
- generator = torch.Generator(device="cuda").manual_seed(seed)
 
53
 
54
- # Auto-add trigger words
55
  if "super realism" not in prompt.lower():
56
  prompt = f"Super Realism, {prompt}"
57
 
58
- # Create callback for progress updates
59
- def update_progress(step, _, __):
60
  progress((step + 1) / steps, desc="Generating image...")
61
 
62
- # Generate image
63
  with torch.inference_mode():
64
- image = pipe(
65
  prompt=prompt,
66
  width=width,
67
  height=height,
@@ -69,18 +73,21 @@ def generate_image(
69
  num_inference_steps=steps,
70
  generator=generator,
71
  callback=update_progress
72
- ).images[0]
 
73
 
74
- # Safety check
75
  progress(1, desc="Safety checking...")
 
76
  safety_input = feature_extractor(image, return_tensors="pt")
77
  np_image = np.array(image)
78
- safety_result = safety_checker(
 
 
79
  images=[np_image],
80
  clip_input=safety_input.pixel_values
81
  )
82
 
83
- if safety_result.nsfw[0]:
84
  return Image.new("RGB", (512, 512)), "NSFW content detected"
85
 
86
  return image, "Generation successful"
@@ -88,18 +95,17 @@ def generate_image(
88
  except Exception as e:
89
  return Image.new("RGB", (512, 512)), f"Error: {str(e)}"
90
 
91
- # Create Gradio interface with rate limiting
92
  with gr.Blocks() as app:
93
  gr.Markdown("# Flux Super Realism Generator")
94
 
95
  with gr.Row():
96
  with gr.Column():
97
- prompt = gr.Textbox(label="Prompt", value="A portrait of a person")
98
- seed = gr.Slider(0, 1000, value=42, label="Seed")
99
- width = gr.Slider(512, 2048, value=1024, label="Width")
100
- height = gr.Slider(512, 2048, value=1024, label="Height")
101
- guidance = gr.Slider(1, 20, value=6, label="Guidance Scale")
102
- steps = gr.Slider(10, 100, value=28, label="Steps")
103
  submit = gr.Button("Generate")
104
 
105
  with gr.Column():
@@ -108,13 +114,13 @@ with gr.Blocks() as app:
108
 
109
  submit.click(
110
  generate_image,
111
- inputs=[prompt, seed, width, height, guidance, steps],
112
  outputs=[output_image, status]
113
  )
114
 
115
- # Rate limiting example (1 request every 30 seconds)
116
  app.queue(concurrency_count=1, max_size=3).launch()
117
 
118
- # For multiple GPU support (advanced)
119
- # pipe.enable_model_cpu_offload()
120
- # pipe.enable_sequential_cpu_offload()
 
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
  from functools import lru_cache
7
  from PIL import Image
8
+
9
  from torchvision import transforms
10
+ from transformers import CLIPFeatureExtractor # Added missing import
11
+
12
 
13
 
 
14
  @lru_cache(maxsize=1)
15
  def load_pipeline():
16
  # Load base model
 
24
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
25
  pipe.load_lora_weights(lora_repo)
26
 
27
+ # Load safety checker and feature extractor
28
  safety_checker = StableDiffusionSafetyChecker.from_pretrained(
29
  "CompVis/stable-diffusion-safety-checker"
30
  )
 
32
  "openai/clip-vit-base-patch32"
33
  )
34
 
35
+ # Optimizations: enable memory efficient attention if using GPU
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ if device.type == "cuda":
38
+ pipe.enable_xformers_memory_efficient_attention()
39
+ pipe = pipe.to(device)
40
 
41
  return pipe, safety_checker, feature_extractor
42
 
 
53
  ):
54
  try:
55
  progress(0, desc="Initializing...")
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ generator = torch.Generator(device=device).manual_seed(seed)
58
 
59
+ # Auto-add trigger words if not present
60
  if "super realism" not in prompt.lower():
61
  prompt = f"Super Realism, {prompt}"
62
 
63
+ # Define the callback function with the proper signature
64
+ def update_progress(step, timestep, latents):
65
  progress((step + 1) / steps, desc="Generating image...")
66
 
 
67
  with torch.inference_mode():
68
+ result = pipe(
69
  prompt=prompt,
70
  width=width,
71
  height=height,
 
73
  num_inference_steps=steps,
74
  generator=generator,
75
  callback=update_progress
76
+ )
77
+ image = result.images[0]
78
 
 
79
  progress(1, desc="Safety checking...")
80
+ # Preprocess image for safety checking
81
  safety_input = feature_extractor(image, return_tensors="pt")
82
  np_image = np.array(image)
83
+
84
+ # Unpack safety checker results (the safety checker returns a tuple)
85
+ _, nsfw_detected = safety_checker(
86
  images=[np_image],
87
  clip_input=safety_input.pixel_values
88
  )
89
 
90
+ if nsfw_detected[0]:
91
  return Image.new("RGB", (512, 512)), "NSFW content detected"
92
 
93
  return image, "Generation successful"
 
95
  except Exception as e:
96
  return Image.new("RGB", (512, 512)), f"Error: {str(e)}"
97
 
 
98
  with gr.Blocks() as app:
99
  gr.Markdown("# Flux Super Realism Generator")
100
 
101
  with gr.Row():
102
  with gr.Column():
103
+ prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
104
+ seed_input = gr.Slider(0, 1000, value=42, label="Seed")
105
+ width_input = gr.Slider(512, 2048, value=1024, label="Width")
106
+ height_input = gr.Slider(512, 2048, value=1024, label="Height")
107
+ guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
108
+ steps_input = gr.Slider(10, 100, value=28, label="Steps")
109
  submit = gr.Button("Generate")
110
 
111
  with gr.Column():
 
114
 
115
  submit.click(
116
  generate_image,
117
+ inputs=[prompt_input, seed_input, width_input, height_input, guidance_input, steps_input],
118
  outputs=[output_image, status]
119
  )
120
 
121
+ # Rate limiting: 1 request at a time, with a max queue size of 3
122
  app.queue(concurrency_count=1, max_size=3).launch()
123
 
124
+ # Uncomment the lines below for advanced multiple GPU support
125
+ # pipe.enable_model_cpu_offload()
126
+ # pipe.enable_sequential_cpu_offload()