ajsbsd committed on
Commit
27d1197
·
verified ·
1 Parent(s): 8e3de44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -82
app.py CHANGED
@@ -39,14 +39,14 @@ class PipelineManager:
39
  self.nsfw_detector_loaded = False
40
 
41
  def clear_memory(self):
42
- """Aggressive memory cleanup"""
43
  if torch.cuda.is_available():
44
  torch.cuda.empty_cache()
45
  torch.cuda.synchronize()
46
  gc.collect()
47
 
48
  def load_nsfw_detector(self) -> bool:
49
- """Load NSFW detection model"""
50
  if self.nsfw_detector_loaded:
51
  return True
52
 
@@ -70,12 +70,13 @@ class PipelineManager:
70
  def is_nsfw(self, image: Image.Image, prompt: str = "") -> Tuple[bool, float]:
71
  """
72
  Detects NSFW content using CLIP-based zero-shot classification.
73
- Falls back to prompt-based detection if CLIP model fails.
74
  """
75
  try:
76
  # Load NSFW detector if not already loaded
77
  if not self.nsfw_detector_loaded:
78
  if not self.load_nsfw_detector():
 
79
  return self._fallback_nsfw_detection(prompt)
80
 
81
  # CLIP-based NSFW detection
@@ -132,11 +133,12 @@ class PipelineManager:
132
  return is_nsfw_result, confidence
133
 
134
  except Exception as e:
135
- logger.error(f"NSFW detection error: {e}")
 
136
  return self._fallback_nsfw_detection(prompt)
137
 
138
  def _fallback_nsfw_detection(self, prompt: str = "") -> Tuple[bool, float]:
139
- """Fallback NSFW detection based on prompt analysis"""
140
  nsfw_keywords = [
141
  'nude', 'naked', 'nsfw', 'explicit', 'sexual', 'erotic', 'porn',
142
  'adult', 'xxx', 'sex', 'breast', 'nipple', 'genital', 'provocative'
@@ -148,13 +150,15 @@ class PipelineManager:
148
  logger.warning(f"🚨 NSFW content detected (prompt-based: '{keyword}' found)")
149
  return True, random.uniform(0.7, 0.95)
150
 
151
- # Random chance for demonstration (remove in production)
152
  if random.random() < 0.02: # 2% chance for demo
153
  logger.warning("🚨 NSFW content detected (random demo detection)")
154
  return True, random.uniform(0.6, 0.8)
155
 
156
  return False, random.uniform(0.1, 0.3)
157
- """Load models with enhanced error handling and memory optimization"""
 
 
158
  if self.model_loaded:
159
  return True
160
 
@@ -176,14 +180,14 @@ class PipelineManager:
176
  torch_dtype=DTYPE,
177
  use_safetensors=True,
178
  variant="fp16" if DEVICE == "cuda" else None,
179
- safety_checker=None, # Disable for faster loading
180
  requires_safety_checker=False
181
  )
182
 
183
- # Memory optimizations
184
  self._optimize_pipeline(self.txt2img_pipe)
185
 
186
- # Create img2img pipeline sharing components
187
  self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
188
  vae=self.txt2img_pipe.vae,
189
  text_encoder=self.txt2img_pipe.text_encoder,
@@ -192,10 +196,11 @@ class PipelineManager:
192
  tokenizer_2=self.txt2img_pipe.tokenizer_2,
193
  unet=self.txt2img_pipe.unet,
194
  scheduler=self.txt2img_pipe.scheduler,
195
- safety_checker=None,
196
  requires_safety_checker=False
197
  )
198
 
 
199
  self._optimize_pipeline(self.img2img_pipe)
200
 
201
  self.model_loaded = True
@@ -208,22 +213,23 @@ class PipelineManager:
208
  return False
209
 
210
  def _optimize_pipeline(self, pipeline):
211
- """Apply memory optimizations to pipeline"""
212
  pipeline.enable_attention_slicing()
213
  pipeline.enable_vae_slicing()
214
 
215
  if DEVICE == "cuda":
216
- # Use sequential CPU offloading for better memory management
217
  pipeline.enable_sequential_cpu_offload()
218
- # Enable memory efficient attention if available
219
  try:
220
  pipeline.enable_xformers_memory_efficient_attention()
221
- except:
222
  logger.info("xformers not available, using default attention")
223
  else:
 
224
  pipeline = pipeline.to(DEVICE)
225
 
226
- # Global pipeline manager
227
  pipe_manager = PipelineManager()
228
 
229
  # Enhanced prompt templates
@@ -246,11 +252,13 @@ EXAMPLE_PROMPTS = [
246
  ]
247
 
248
  def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
249
- """Smart prompt enhancement"""
 
 
250
  if not prompt.strip():
251
  return ""
252
 
253
- # Don't add quality tags if they're already present
254
  if any(tag in prompt.lower() for tag in ["score_", "masterpiece", "best quality"]):
255
  return prompt
256
 
@@ -259,22 +267,28 @@ def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
259
  return prompt
260
 
261
  def validate_and_fix_dimensions(width: int, height: int) -> Tuple[int, int]:
262
- """Ensure SDXL-compatible dimensions with better aspect ratio handling"""
 
 
 
263
  # Round to nearest multiple of 64
264
  width = max(512, min(1024, ((width + 31) // 64) * 64))
265
  height = max(512, min(1024, ((height + 31) // 64) * 64))
266
 
267
  # Ensure reasonable aspect ratios (prevent extremely wide/tall images)
268
  aspect_ratio = width / height
269
- if aspect_ratio > 2.0: # Too wide
270
  height = width // 2
271
- elif aspect_ratio < 0.5: # Too tall
272
  width = height // 2
273
 
274
  return width, height
275
 
276
  def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
277
- """Create PNG with embedded metadata"""
 
 
 
278
  temp_path = tempfile.mktemp(suffix=".png", prefix="cyberrealistic_")
279
 
280
  meta = PngImagePlugin.PngInfo()
@@ -282,7 +296,7 @@ def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
282
  if value is not None:
283
  meta.add_text(key, str(value))
284
 
285
- # Add generation timestamp
286
  meta.add_text("Generated", datetime.now().strftime("%Y-%m-%d %H:%M:%S UTC"))
287
  meta.add_text("Model", f"{MODEL_REPO}/{MODEL_FILENAME}")
288
 
@@ -290,7 +304,9 @@ def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
290
  return temp_path
291
 
292
  def format_generation_info(params: Dict[str, Any], generation_time: float) -> str:
293
- """Format generation information display"""
 
 
294
  info_lines = [
295
  f"✅ Generated in {generation_time:.1f}s",
296
  f"📐 Resolution: {params.get('width', 'N/A')}×{params.get('height', 'N/A')}",
@@ -305,20 +321,23 @@ def format_generation_info(params: Dict[str, Any], generation_time: float) -> st
305
 
306
  return "\n".join(info_lines)
307
 
308
- @spaces.GPU(duration=120) # Increased duration for model loading
309
  def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_scale: float,
310
  width: int, height: int, seed: int, add_quality: bool) -> Tuple:
311
- """Text-to-image generation with enhanced error handling"""
 
 
 
312
 
313
  if not prompt.strip():
314
- return None, None, "❌ Please enter a prompt"
315
 
316
- # Lazy load models
317
  if not pipe_manager.load_models():
318
  return None, None, "❌ Failed to load model. Please try again."
319
 
320
  try:
321
- pipe_manager.clear_memory()
322
 
323
  # Process parameters
324
  width, height = validate_and_fix_dimensions(width, height)
@@ -328,12 +347,12 @@ def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_sca
328
  enhanced_prompt = enhance_prompt(prompt, add_quality)
329
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
330
 
331
- # Generation parameters
332
  gen_params = {
333
  "prompt": enhanced_prompt,
334
  "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
335
- "num_inference_steps": min(max(steps, 10), 50), # Clamp steps
336
- "guidance_scale": max(1.0, min(guidance_scale, 20.0)), # Clamp guidance
337
  "width": width,
338
  "height": height,
339
  "generator": generator,
@@ -348,11 +367,11 @@ def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_sca
348
 
349
  generation_time = time.time() - start_time
350
 
351
- # NSFW Detection
352
  is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
353
 
354
  if is_nsfw_result:
355
- # Create a blurred/censored version or return error
356
  blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
357
  warning_msg = f"⚠️ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
358
 
@@ -376,7 +395,7 @@ def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_sca
376
 
377
  return blurred_image, png_path, info_text
378
 
379
- # Prepare metadata
380
  metadata = {
381
  "prompt": enhanced_prompt,
382
  "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
@@ -402,38 +421,42 @@ def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_sca
402
  logger.error(f"Generation error: {e}")
403
  return None, None, f"❌ Generation failed: {str(e)}"
404
  finally:
405
- pipe_manager.clear_memory()
406
 
407
- @spaces.GPU(duration=120)
408
  def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str,
409
  steps: int, guidance_scale: float, strength: float, seed: int,
410
  add_quality: bool) -> Tuple:
411
- """Image-to-image generation with enhanced preprocessing"""
 
 
 
412
 
413
  if input_image is None:
414
- return None, None, "❌ Please upload an input image"
415
 
416
  if not prompt.strip():
417
- return None, None, "❌ Please enter a prompt"
418
 
 
419
  if not pipe_manager.load_models():
420
  return None, None, "❌ Failed to load model. Please try again."
421
 
422
  try:
423
- pipe_manager.clear_memory()
424
 
425
- # Process input image
426
  if input_image.mode != 'RGB':
427
  input_image = input_image.convert('RGB')
428
 
429
- # Smart resizing maintaining aspect ratio
430
  original_size = input_image.size
431
  max_dimension = 1024
432
 
433
  if max(original_size) > max_dimension:
434
  input_image.thumbnail((max_dimension, max_dimension), Image.Resampling.LANCZOS)
435
 
436
- # Ensure SDXL compatible dimensions
437
  w, h = validate_and_fix_dimensions(*input_image.size)
438
  input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)
439
 
@@ -444,14 +467,14 @@ def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str
444
  enhanced_prompt = enhance_prompt(prompt, add_quality)
445
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
446
 
447
- # Generation parameters
448
  gen_params = {
449
  "prompt": enhanced_prompt,
450
  "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
451
  "image": input_image,
452
- "num_inference_steps": min(max(steps, 10), 50),
453
- "guidance_scale": max(1.0, min(guidance_scale, 20.0)),
454
- "strength": max(0.1, min(strength, 1.0)),
455
  "generator": generator,
456
  "output_type": "pil"
457
  }
@@ -464,11 +487,11 @@ def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str
464
 
465
  generation_time = time.time() - start_time
466
 
467
- # NSFW Detection
468
  is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
469
 
470
  if is_nsfw_result:
471
- # Create blurred version for inappropriate content
472
  blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
473
  warning_msg = f"⚠️ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
474
 
@@ -492,7 +515,7 @@ def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str
492
 
493
  return blurred_image, png_path, info_text
494
 
495
- # Prepare metadata
496
  metadata = {
497
  "prompt": enhanced_prompt,
498
  "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
@@ -518,15 +541,18 @@ def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str
518
  logger.error(f"Generation error: {e}")
519
  return None, None, f"❌ Generation failed: {str(e)}"
520
  finally:
521
- pipe_manager.clear_memory()
522
 
523
  def get_random_prompt():
524
- """Get a random example prompt"""
525
  return random.choice(EXAMPLE_PROMPTS)
526
 
527
  # Enhanced Gradio interface
528
  def create_interface():
529
- """Create the Gradio interface"""
 
 
 
530
 
531
  with gr.Blocks(
532
  title="CyberRealistic Pony - SDXL Generator",
@@ -729,42 +755,17 @@ def create_interface():
729
 
730
  return demo
731
 
732
- # Initialize and launch
733
  if __name__ == "__main__":
734
  logger.info(f"🚀 Initializing CyberRealistic Pony Generator on {DEVICE}")
735
  logger.info(f"📱 PyTorch version: {torch.__version__}")
736
  logger.info(f"🛡️ NSFW Content Filter: Enabled")
737
 
738
  demo = create_interface()
739
- demo.queue(max_size=20) # Enable queuing for better UX
740
  demo.launch(
741
  server_name="0.0.0.0",
742
  server_port=7860,
743
  show_error=True,
744
- share=False # Set to True if you want a public link
745
  )
746
-
747
- # Example prompt buttons
748
- txt_example_btn.click(fn=get_random_prompt, outputs=[txt_prompt])
749
- img_example_btn.click(fn=get_random_prompt, outputs=[img_prompt])
750
-
751
- # Clear buttons
752
- txt_clear_btn.click(lambda: "", outputs=[txt_prompt])
753
- img_clear_btn.click(lambda: "", outputs=[img_prompt])
754
-
755
- return demo
756
-
757
- # Initialize and launch
758
- if __name__ == "__main__":
759
- logger.info(f"🚀 Initializing CyberRealistic Pony Generator on {DEVICE}")
760
- logger.info(f"📱 PyTorch version: {torch.__version__}")
761
- logger.info(f"🛡️ NSFW Content Filter: Enabled")
762
-
763
- demo = create_interface()
764
- demo.queue(max_size=20) # Enable queuing for better UX
765
- demo.launch(
766
- server_name="0.0.0.0",
767
- server_port=7860,
768
- show_error=True,
769
- share=False # Set to True if you want a public link
770
- )
 
39
  self.nsfw_detector_loaded = False
40
 
41
  def clear_memory(self):
42
+ """Aggressive memory cleanup to free up GPU/CPU memory."""
43
  if torch.cuda.is_available():
44
  torch.cuda.empty_cache()
45
  torch.cuda.synchronize()
46
  gc.collect()
47
 
48
  def load_nsfw_detector(self) -> bool:
49
+ """Load NSFW detection model (CLIP) with error handling."""
50
  if self.nsfw_detector_loaded:
51
  return True
52
 
 
70
  def is_nsfw(self, image: Image.Image, prompt: str = "") -> Tuple[bool, float]:
71
  """
72
  Detects NSFW content using CLIP-based zero-shot classification.
73
+ Falls back to prompt-based detection if CLIP model fails or is not loaded.
74
  """
75
  try:
76
  # Load NSFW detector if not already loaded
77
  if not self.nsfw_detector_loaded:
78
  if not self.load_nsfw_detector():
79
+ # If NSFW detector cannot be loaded, fall back to prompt-based
80
  return self._fallback_nsfw_detection(prompt)
81
 
82
  # CLIP-based NSFW detection
 
133
  return is_nsfw_result, confidence
134
 
135
  except Exception as e:
136
+ logger.error(f"NSFW detection error (CLIP model failed): {e}")
137
+ # Fallback to prompt-based detection if CLIP model encounters an error
138
  return self._fallback_nsfw_detection(prompt)
139
 
140
  def _fallback_nsfw_detection(self, prompt: str = "") -> Tuple[bool, float]:
141
+ """Fallback NSFW detection based on prompt keyword analysis."""
142
  nsfw_keywords = [
143
  'nude', 'naked', 'nsfw', 'explicit', 'sexual', 'erotic', 'porn',
144
  'adult', 'xxx', 'sex', 'breast', 'nipple', 'genital', 'provocative'
 
150
  logger.warning(f"🚨 NSFW content detected (prompt-based: '{keyword}' found)")
151
  return True, random.uniform(0.7, 0.95)
152
 
153
+ # Random chance for demonstration (consider removing in production)
154
  if random.random() < 0.02: # 2% chance for demo
155
  logger.warning("🚨 NSFW content detected (random demo detection)")
156
  return True, random.uniform(0.6, 0.8)
157
 
158
  return False, random.uniform(0.1, 0.3)
159
+
160
+ def load_models(self) -> bool:
161
+ """Load Stable Diffusion XL models (txt2img and img2img) with enhanced error handling and memory optimization."""
162
  if self.model_loaded:
163
  return True
164
 
 
180
  torch_dtype=DTYPE,
181
  use_safetensors=True,
182
  variant="fp16" if DEVICE == "cuda" else None,
183
+ safety_checker=None, # Disable for faster loading, using custom NSFW check
184
  requires_safety_checker=False
185
  )
186
 
187
+ # Apply memory optimizations to txt2img pipeline
188
  self._optimize_pipeline(self.txt2img_pipe)
189
 
190
+ # Create img2img pipeline sharing components to save memory
191
  self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
192
  vae=self.txt2img_pipe.vae,
193
  text_encoder=self.txt2img_pipe.text_encoder,
 
196
  tokenizer_2=self.txt2img_pipe.tokenizer_2,
197
  unet=self.txt2img_pipe.unet,
198
  scheduler=self.txt2img_pipe.scheduler,
199
+ safety_checker=None, # Disable for faster loading, using custom NSFW check
200
  requires_safety_checker=False
201
  )
202
 
203
+ # Apply memory optimizations to img2img pipeline
204
  self._optimize_pipeline(self.img2img_pipe)
205
 
206
  self.model_loaded = True
 
213
  return False
214
 
215
  def _optimize_pipeline(self, pipeline):
216
+ """Apply memory optimizations to a given diffusion pipeline."""
217
  pipeline.enable_attention_slicing()
218
  pipeline.enable_vae_slicing()
219
 
220
  if DEVICE == "cuda":
221
+ # Use sequential CPU offloading for better memory management on GPU
222
  pipeline.enable_sequential_cpu_offload()
223
+ # Enable memory efficient attention if xformers is available
224
  try:
225
  pipeline.enable_xformers_memory_efficient_attention()
226
+ except Exception: # Catch any error if xformers is not installed/configured
227
  logger.info("xformers not available, using default attention")
228
  else:
229
+ # Move pipeline to CPU if CUDA is not available
230
  pipeline = pipeline.to(DEVICE)
231
 
232
+ # Global pipeline manager instance
233
  pipe_manager = PipelineManager()
234
 
235
  # Enhanced prompt templates
 
252
  ]
253
 
254
  def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
255
+ """
256
+ Enhances the given prompt with quality tags unless they are already present.
257
+ """
258
  if not prompt.strip():
259
  return ""
260
 
261
+ # Don't add quality tags if they're already present in the prompt (case-insensitive)
262
  if any(tag in prompt.lower() for tag in ["score_", "masterpiece", "best quality"]):
263
  return prompt
264
 
 
267
  return prompt
268
 
269
  def validate_and_fix_dimensions(width: int, height: int) -> Tuple[int, int]:
270
+ """
271
+ Ensures SDXL-compatible dimensions (multiples of 64) and reasonable aspect ratios.
272
+ Clamps dimensions between 512 and 1024.
273
+ """
274
  # Round to nearest multiple of 64
275
  width = max(512, min(1024, ((width + 31) // 64) * 64))
276
  height = max(512, min(1024, ((height + 31) // 64) * 64))
277
 
278
  # Ensure reasonable aspect ratios (prevent extremely wide/tall images)
279
  aspect_ratio = width / height
280
+ if aspect_ratio > 2.0: # Too wide, adjust height
281
  height = width // 2
282
+ elif aspect_ratio < 0.5: # Too tall, adjust width
283
  width = height // 2
284
 
285
  return width, height
286
 
287
  def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
288
+ """
289
+ Creates a temporary PNG file with embedded metadata from the generation parameters.
290
+ Returns the path to the created PNG file.
291
+ """
292
  temp_path = tempfile.mktemp(suffix=".png", prefix="cyberrealistic_")
293
 
294
  meta = PngImagePlugin.PngInfo()
 
296
  if value is not None:
297
  meta.add_text(key, str(value))
298
 
299
+ # Add generation timestamp and model info
300
  meta.add_text("Generated", datetime.now().strftime("%Y-%m-%d %H:%M:%S UTC"))
301
  meta.add_text("Model", f"{MODEL_REPO}/{MODEL_FILENAME}")
302
 
 
304
  return temp_path
305
 
306
  def format_generation_info(params: Dict[str, Any], generation_time: float) -> str:
307
+ """
308
+ Formats the generation information into a human-readable string for display.
309
+ """
310
  info_lines = [
311
  f"✅ Generated in {generation_time:.1f}s",
312
  f"📐 Resolution: {params.get('width', 'N/A')}×{params.get('height', 'N/A')}",
 
321
 
322
  return "\n".join(info_lines)
323
 
324
+ @spaces.GPU(duration=120) # Increased duration for model loading and generation
325
  def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_scale: float,
326
  width: int, height: int, seed: int, add_quality: bool) -> Tuple:
327
+ """
328
+ Handles text-to-image generation, including parameter processing, model inference,
329
+ NSFW detection, and metadata creation.
330
+ """
331
 
332
  if not prompt.strip():
333
+ return None, None, "❌ Please enter a prompt."
334
 
335
+ # Lazy load models if not already loaded
336
  if not pipe_manager.load_models():
337
  return None, None, "❌ Failed to load model. Please try again."
338
 
339
  try:
340
+ pipe_manager.clear_memory() # Clear memory before generation
341
 
342
  # Process parameters
343
  width, height = validate_and_fix_dimensions(width, height)
 
347
  enhanced_prompt = enhance_prompt(prompt, add_quality)
348
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
349
 
350
+ # Generation parameters dictionary
351
  gen_params = {
352
  "prompt": enhanced_prompt,
353
  "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
354
+ "num_inference_steps": min(max(steps, 10), 50), # Clamp steps to a reasonable range
355
+ "guidance_scale": max(1.0, min(guidance_scale, 20.0)), # Clamp guidance scale
356
  "width": width,
357
  "height": height,
358
  "generator": generator,
 
367
 
368
  generation_time = time.time() - start_time
369
 
370
+ # Perform NSFW Detection on the generated image
371
  is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
372
 
373
  if is_nsfw_result:
374
+ # If NSFW, blur the image and return a warning message
375
  blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
376
  warning_msg = f"⚠️ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
377
 
 
395
 
396
  return blurred_image, png_path, info_text
397
 
398
+ # If not NSFW, prepare metadata and save the original image
399
  metadata = {
400
  "prompt": enhanced_prompt,
401
  "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
 
421
  logger.error(f"Generation error: {e}")
422
  return None, None, f"❌ Generation failed: {str(e)}"
423
  finally:
424
+ pipe_manager.clear_memory() # Ensure memory is cleared even if an error occurs
425
 
426
+ @spaces.GPU(duration=120) # Increased duration for model loading and generation
427
  def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str,
428
  steps: int, guidance_scale: float, strength: float, seed: int,
429
  add_quality: bool) -> Tuple:
430
+ """
431
+ Handles image-to-image generation, including image preprocessing, parameter processing,
432
+ model inference, NSFW detection, and metadata creation.
433
+ """
434
 
435
  if input_image is None:
436
+ return None, None, "❌ Please upload an input image."
437
 
438
  if not prompt.strip():
439
+ return None, None, "❌ Please enter a prompt."
440
 
441
+ # Lazy load models if not already loaded
442
  if not pipe_manager.load_models():
443
  return None, None, "❌ Failed to load model. Please try again."
444
 
445
  try:
446
+ pipe_manager.clear_memory() # Clear memory before generation
447
 
448
+ # Process input image: convert to RGB if necessary
449
  if input_image.mode != 'RGB':
450
  input_image = input_image.convert('RGB')
451
 
452
+ # Smart resizing maintaining aspect ratio to fit within max_dimension
453
  original_size = input_image.size
454
  max_dimension = 1024
455
 
456
  if max(original_size) > max_dimension:
457
  input_image.thumbnail((max_dimension, max_dimension), Image.Resampling.LANCZOS)
458
 
459
+ # Ensure SDXL compatible dimensions (multiples of 64)
460
  w, h = validate_and_fix_dimensions(*input_image.size)
461
  input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)
462
 
 
467
  enhanced_prompt = enhance_prompt(prompt, add_quality)
468
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
469
 
470
+ # Generation parameters dictionary
471
  gen_params = {
472
  "prompt": enhanced_prompt,
473
  "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
474
  "image": input_image,
475
+ "num_inference_steps": min(max(steps, 10), 50), # Clamp steps
476
+ "guidance_scale": max(1.0, min(guidance_scale, 20.0)), # Clamp guidance scale
477
+ "strength": max(0.1, min(strength, 1.0)), # Clamp strength
478
  "generator": generator,
479
  "output_type": "pil"
480
  }
 
487
 
488
  generation_time = time.time() - start_time
489
 
490
+ # Perform NSFW Detection on the transformed image
491
  is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
492
 
493
  if is_nsfw_result:
494
+ # If NSFW, blur the image and return a warning message
495
  blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
496
  warning_msg = f"⚠️ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
497
 
 
515
 
516
  return blurred_image, png_path, info_text
517
 
518
+ # If not NSFW, prepare metadata and save the original image
519
  metadata = {
520
  "prompt": enhanced_prompt,
521
  "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
 
541
  logger.error(f"Generation error: {e}")
542
  return None, None, f"❌ Generation failed: {str(e)}"
543
  finally:
544
+ pipe_manager.clear_memory() # Ensure memory is cleared even if an error occurs
545
 
546
  def get_random_prompt():
547
+ """Returns a random example prompt from a predefined list."""
548
  return random.choice(EXAMPLE_PROMPTS)
549
 
550
  # Enhanced Gradio interface
551
  def create_interface():
552
+ """
553
+ Creates and returns the Gradio Blocks interface for the CyberRealistic Pony Generator.
554
+ This includes tabs for Text-to-Image and Image-to-Image, along with controls and outputs.
555
+ """
556
 
557
  with gr.Blocks(
558
  title="CyberRealistic Pony - SDXL Generator",
 
755
 
756
  return demo
757
 
758
+ # Initialize and launch the Gradio application
759
  if __name__ == "__main__":
760
  logger.info(f"🚀 Initializing CyberRealistic Pony Generator on {DEVICE}")
761
  logger.info(f"📱 PyTorch version: {torch.__version__}")
762
  logger.info(f"🛡️ NSFW Content Filter: Enabled")
763
 
764
  demo = create_interface()
765
+ demo.queue(max_size=20) # Enable queuing for better user experience
766
  demo.launch(
767
  server_name="0.0.0.0",
768
  server_port=7860,
769
  show_error=True,
770
+ share=False # Set to True if you want a public link (e.g., for Hugging Face Spaces)
771
  )