ajsbsd commited on
Commit
8e3de44
Β·
verified Β·
1 Parent(s): fd21bc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -719
app.py CHANGED
@@ -245,725 +245,6 @@ EXAMPLE_PROMPTS = [
245
  "steampunk inventor's workshop, brass gears, mechanical contraptions, warm lighting"
246
  ]
247
 
248
- def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
249
- """Smart prompt enhancement"""
250
- if not prompt.strip():
251
- return ""
252
-
253
- # Don't add quality tags if they're already present
254
- if any(tag in prompt.lower() for tag in ["score_", "masterpiece", "best quality"]):
255
- return prompt
256
-
257
- if add_quality:
258
- return f"{QUALITY_TAGS}, {prompt}"
259
- return prompt
260
-
261
- def validate_and_fix_dimensions(width: int, height: int) -> Tuple[int, int]:
262
- """Ensure SDXL-compatible dimensions with better aspect ratio handling"""
263
- # Round to nearest multiple of 64
264
- width = max(512, min(1024, ((width + 31) // 64) * 64))
265
- height = max(512, min(1024, ((height + 31) // 64) * 64))
266
-
267
- # Ensure reasonable aspect ratios (prevent extremely wide/tall images)
268
- aspect_ratio = width / height
269
- if aspect_ratio > 2.0: # Too wide
270
- height = width // 2
271
- elif aspect_ratio < 0.5: # Too tall
272
- width = height // 2
273
-
274
- return width, height
275
-
276
- def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
277
- """Create PNG with embedded metadata"""
278
- temp_path = tempfile.mktemp(suffix=".png", prefix="cyberrealistic_")
279
-
280
- meta = PngImagePlugin.PngInfo()
281
- for key, value in params.items():
282
- if value is not None:
283
- meta.add_text(key, str(value))
284
-
285
- # Add generation timestamp
286
- meta.add_text("Generated", datetime.now().strftime("%Y-%m-%d %H:%M:%S UTC"))
287
- meta.add_text("Model", f"{MODEL_REPO}/{MODEL_FILENAME}")
288
-
289
- image.save(temp_path, "PNG", pnginfo=meta, optimize=True)
290
- return temp_path
291
-
292
- def format_generation_info(params: Dict[str, Any], generation_time: float) -> str:
293
- """Format generation information display"""
294
- info_lines = [
295
- f"βœ… Generated in {generation_time:.1f}s",
296
- f"πŸ“ Resolution: {params.get('width', 'N/A')}Γ—{params.get('height', 'N/A')}",
297
- f"🎯 Prompt: {params.get('prompt', '')[:60]}{'...' if len(params.get('prompt', '')) > 60 else ''}",
298
- f"🚫 Negative: {params.get('negative_prompt', 'None')[:40]}{'...' if len(params.get('negative_prompt', '')) > 40 else ''}",
299
- f"🎲 Seed: {params.get('seed', 'N/A')}",
300
- f"πŸ“Š Steps: {params.get('steps', 'N/A')} | CFG: {params.get('guidance_scale', 'N/A')}"
301
- ]
302
-
303
- if 'strength' in params:
304
- info_lines.append(f"πŸ’ͺ Strength: {params['strength']}")
305
-
306
- return "\n".join(info_lines)
307
-
308
- @spaces.GPU(duration=120) # Increased duration for model loading
309
- def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_scale: float,
310
- width: int, height: int, seed: int, add_quality: bool) -> Tuple:
311
- """Text-to-image generation with enhanced error handling"""
312
-
313
- if not prompt.strip():
314
- return None, None, "❌ Please enter a prompt"
315
-
316
- # Lazy load models
317
- if not pipe_manager.load_models(): # <--- Change from load_models() to _load_models()
318
- return None, None, "❌ Failed to load model. Please try again."
319
-
320
- try:
321
- pipe_manager.clear_memory()
322
-
323
- # Process parameters
324
- width, height = validate_and_fix_dimensions(width, height)
325
- if seed == -1:
326
- seed = random.randint(0, MAX_SEED)
327
-
328
- enhanced_prompt = enhance_prompt(prompt, add_quality)
329
- generator = torch.Generator(device=DEVICE).manual_seed(seed)
330
-
331
- # Generation parameters
332
- gen_params = {
333
- "prompt": enhanced_prompt,
334
- "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
335
- "num_inference_steps": min(max(steps, 10), 50), # Clamp steps
336
- "guidance_scale": max(1.0, min(guidance_scale, 20.0)), # Clamp guidance
337
- "width": width,
338
- "height": height,
339
- "generator": generator,
340
- "output_type": "pil"
341
- }
342
-
343
- logger.info(f"Generating: {enhanced_prompt[:50]}...")
344
- start_time = time.time()
345
-
346
- with torch.inference_mode():
347
- result = pipe_manager.txt2img_pipe(**gen_params)
348
-
349
- generation_time = time.time() - start_time
350
-
351
- # NSFW Detection
352
- is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
353
-
354
- if is_nsfw_result:
355
- # Create a blurred/censored version or return error
356
- blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
357
- warning_msg = f"⚠️ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
358
-
359
- # Still save metadata but mark as filtered
360
- metadata = {
361
- "prompt": enhanced_prompt,
362
- "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
363
- "steps": gen_params["num_inference_steps"],
364
- "guidance_scale": gen_params["guidance_scale"],
365
- "width": width,
366
- "height": height,
367
- "seed": seed,
368
- "sampler": "Euler Ancestral",
369
- "model_hash": "cyberrealistic_pony_v110",
370
- "nsfw_filtered": "true",
371
- "nsfw_confidence": f"{nsfw_confidence:.3f}"
372
- }
373
-
374
- png_path = create_metadata_png(blurred_image, metadata)
375
- info_text = f"{warning_msg}\n\n{format_generation_info(metadata, generation_time)}"
376
-
377
- return blurred_image, png_path, info_text
378
-
379
- # Prepare metadata
380
- metadata = {
381
- "prompt": enhanced_prompt,
382
- "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
383
- "steps": gen_params["num_inference_steps"],
384
- "guidance_scale": gen_params["guidance_scale"],
385
- "width": width,
386
- "height": height,
387
- "seed": seed,
388
- "sampler": "Euler Ancestral",
389
- "model_hash": "cyberrealistic_pony_v110"
390
- }
391
-
392
- # Save with metadata
393
- png_path = create_metadata_png(result.images[0], metadata)
394
- info_text = format_generation_info(metadata, generation_time)
395
-
396
- return result.images[0], png_path, info_text
397
-
398
- except torch.cuda.OutOfMemoryError:
399
- pipe_manager.clear_memory()
400
- return None, None, "❌ GPU out of memory. Try smaller dimensions or fewer steps."
401
- except Exception as e:
402
- logger.error(f"Generation error: {e}")
403
- return None, None, f"❌ Generation failed: {str(e)}"
404
- finally:
405
- pipe_manager.clear_memory()
406
-
407
- @spaces.GPU(duration=120)
408
- def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str,
409
- steps: int, guidance_scale: float, strength: float, seed: int,
410
- add_quality: bool) -> Tuple:
411
- """Image-to-image generation with enhanced preprocessing"""
412
-
413
- if input_image is None:
414
- return None, None, "❌ Please upload an input image"
415
-
416
- if not prompt.strip():
417
- return None, None, "❌ Please enter a prompt"
418
-
419
- if not pipe_manager.load_models(): # <--- Change from load_models() to _load_models()
420
- return None, None, "❌ Failed to load model. Please try again."
421
-
422
- try:
423
- pipe_manager.clear_memory()
424
-
425
- # Process input image
426
- if input_image.mode != 'RGB':
427
- input_image = input_image.convert('RGB')
428
-
429
- # Smart resizing maintaining aspect ratio
430
- original_size = input_image.size
431
- max_dimension = 1024
432
-
433
- if max(original_size) > max_dimension:
434
- input_image.thumbnail((max_dimension, max_dimension), Image.Resampling.LANCZOS)
435
-
436
- # Ensure SDXL compatible dimensions
437
- w, h = validate_and_fix_dimensions(*input_image.size)
438
- input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)
439
-
440
- # Process other parameters
441
- if seed == -1:
442
- seed = random.randint(0, MAX_SEED)
443
-
444
- enhanced_prompt = enhance_prompt(prompt, add_quality)
445
- generator = torch.Generator(device=DEVICE).manual_seed(seed)
446
-
447
- # Generation parameters
448
- gen_params = {
449
- "prompt": enhanced_prompt,
450
- "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
451
- "image": input_image,
452
- "num_inference_steps": min(max(steps, 10), 50),
453
- "guidance_scale": max(1.0, min(guidance_scale, 20.0)),
454
- "strength": max(0.1, min(strength, 1.0)),
455
- "generator": generator,
456
- "output_type": "pil"
457
- }
458
-
459
- logger.info(f"Transforming: {enhanced_prompt[:50]}...")
460
- start_time = time.time()
461
-
462
- with torch.inference_mode():
463
- result = pipe_manager.img2img_pipe(**gen_params)
464
-
465
- generation_time = time.time() - start_time
466
-
467
- # NSFW Detection
468
- is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
469
-
470
- if is_nsfw_result:
471
- # Create blurred version for inappropriate content
472
- blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
473
- warning_msg = f"⚠️ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
474
-
475
- metadata = {
476
- "prompt": enhanced_prompt,
477
- "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
478
- "steps": gen_params["num_inference_steps"],
479
- "guidance_scale": gen_params["guidance_scale"],
480
- "strength": gen_params["strength"],
481
- "width": w,
482
- "height": h,
483
- "seed": seed,
484
- "sampler": "Euler Ancestral",
485
- "model_hash": "cyberrealistic_pony_v110",
486
- "nsfw_filtered": "true",
487
- "nsfw_confidence": f"{nsfw_confidence:.3f}"
488
- }
489
-
490
- png_path = create_metadata_png(blurred_image, metadata)
491
- info_text = f"{warning_msg}\n\n{format_generation_info(metadata, generation_time)}"
492
-
493
- return blurred_image, png_path, info_text
494
-
495
- # Prepare metadata
496
- metadata = {
497
- "prompt": enhanced_prompt,
498
- "negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
499
- "steps": gen_params["num_inference_steps"],
500
- "guidance_scale": gen_params["guidance_scale"],
501
- "strength": gen_params["strength"],
502
- "width": w,
503
- "height": h,
504
- "seed": seed,
505
- "sampler": "Euler Ancestral",
506
- "model_hash": "cyberrealistic_pony_v110"
507
- }
508
-
509
- png_path = create_metadata_png(result.images[0], metadata)
510
- info_text = format_generation_info(metadata, generation_time)
511
-
512
- return result.images[0], png_path, info_text
513
-
514
- except torch.cuda.OutOfMemoryError:
515
- pipe_manager.clear_memory()
516
- return None, None, "❌ GPU out of memory. Try lower strength or fewer steps."
517
- except Exception as e:
518
- logger.error(f"Generation error: {e}")
519
- return None, None, f"❌ Generation failed: {str(e)}"
520
- finally:
521
- pipe_manager.clear_memory()
522
-
523
- def get_random_prompt():
524
- """Get a random example prompt"""
525
- return random.choice(EXAMPLE_PROMPTS)
526
-
527
- # Enhanced Gradio interface
528
- def create_interface():
529
- """Create the Gradio interface"""
530
-
531
- with gr.Blocks(
532
- title="CyberRealistic Pony - SDXL Generator",
533
- theme=gr.themes.Soft(primary_hue="blue"),
534
- css="""
535
- .generate-btn {
536
- background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
537
- border: none !important;
538
- }
539
- .generate-btn:hover {
540
- transform: translateY(-2px);
541
- box-shadow: 0 4px 12px rgba(0,0,0,0.2);
542
- }
543
- """
544
- ) as demo:
545
-
546
- gr.Markdown("""
547
- # 🎨 CyberRealistic Pony Generator
548
-
549
- **High-quality SDXL image generation** β€’ Optimized for HuggingFace Spaces β€’ **NSFW Content Filter Enabled**
550
-
551
- > ⚑ **First generation takes longer** (model loading) β€’ πŸ“‹ **Metadata embedded** in all outputs β€’ πŸ›‘οΈ **Content filtered for safety**
552
- """)
553
-
554
- with gr.Tabs():
555
- # Text to Image Tab
556
- with gr.TabItem("🎨 Text to Image", id="txt2img"):
557
- with gr.Row():
558
- with gr.Column(scale=1):
559
- with gr.Group():
560
- txt_prompt = gr.Textbox(
561
- label="✨ Prompt",
562
- placeholder="A beautiful landscape with mountains and sunset...",
563
- lines=3,
564
- max_lines=5
565
- )
566
-
567
- with gr.Row():
568
- txt_example_btn = gr.Button("🎲 Random", size="sm")
569
- txt_clear_btn = gr.Button("πŸ—‘οΈ Clear", size="sm")
570
-
571
- with gr.Accordion("βš™οΈ Advanced Settings", open=False):
572
- txt_negative = gr.Textbox(
573
- label="❌ Negative Prompt",
574
- value=DEFAULT_NEGATIVE,
575
- lines=2,
576
- max_lines=3
577
- )
578
-
579
- txt_quality = gr.Checkbox(
580
- label="✨ Add Quality Tags",
581
- value=True,
582
- info="Automatically enhance prompt with quality tags"
583
- )
584
-
585
- with gr.Row():
586
- txt_steps = gr.Slider(
587
- 10, 50, 25, step=1,
588
- label="πŸ“Š Steps",
589
- info="More steps = better quality, slower generation"
590
- )
591
- txt_guidance = gr.Slider(
592
- 1.0, 15.0, 7.5, step=0.5,
593
- label="πŸŽ›οΈ CFG Scale",
594
- info="How closely to follow the prompt"
595
- )
596
-
597
- with gr.Row():
598
- txt_width = gr.Slider(
599
- 512, 1024, 768, step=64,
600
- label="πŸ“ Width"
601
- )
602
- txt_height = gr.Slider(
603
- 512, 1024, 768, step=64,
604
- label="πŸ“ Height"
605
- )
606
-
607
- txt_seed = gr.Slider(
608
- -1, MAX_SEED, -1, step=1,
609
- label="🎲 Seed (-1 = random)",
610
- info="Use same seed for reproducible results"
611
- )
612
-
613
- txt_generate_btn = gr.Button(
614
- "🎨 Generate Image",
615
- variant="primary",
616
- size="lg",
617
- elem_classes=["generate-btn"]
618
- )
619
-
620
- with gr.Column(scale=1):
621
- txt_output_image = gr.Image(
622
- label="πŸ–ΌοΈ Generated Image",
623
- height=500,
624
- show_download_button=True
625
- )
626
- txt_download_file = gr.File(
627
- label="πŸ“₯ Download PNG (with metadata)",
628
- file_types=[".png"]
629
- )
630
- txt_info = gr.Textbox(
631
- label="ℹ️ Generation Info",
632
- lines=6,
633
- max_lines=8,
634
- interactive=False
635
- )
636
-
637
- # Image to Image Tab
638
- with gr.TabItem("πŸ–ΌοΈ Image to Image", id="img2img"):
639
- with gr.Row():
640
- with gr.Column(scale=1):
641
- img_input = gr.Image(
642
- label="πŸ“€ Input Image",
643
- type="pil",
644
- height=300
645
- )
646
-
647
- with gr.Group():
648
- img_prompt = gr.Textbox(
649
- label="✨ Transformation Prompt",
650
- placeholder="digital art style, vibrant colors...",
651
- lines=3
652
- )
653
-
654
- with gr.Row():
655
- img_example_btn = gr.Button("🎲 Random", size="sm")
656
- img_clear_btn = gr.Button("πŸ—‘οΈ Clear", size="sm")
657
-
658
- with gr.Accordion("βš™οΈ Advanced Settings", open=False):
659
- img_negative = gr.Textbox(
660
- label="❌ Negative Prompt",
661
- value=DEFAULT_NEGATIVE,
662
- lines=2
663
- )
664
-
665
- img_quality = gr.Checkbox(
666
- label="✨ Add Quality Tags",
667
- value=True
668
- )
669
-
670
- with gr.Row():
671
- img_steps = gr.Slider(10, 50, 25, step=1, label="πŸ“Š Steps")
672
- img_guidance = gr.Slider(1.0, 15.0, 7.5, step=0.5, label="πŸŽ›οΈ CFG")
673
-
674
- img_strength = gr.Slider(
675
- 0.1, 1.0, 0.75, step=0.05,
676
- label="πŸ’ͺ Transformation Strength",
677
- info="Higher = more creative, lower = more faithful to input"
678
- )
679
-
680
- img_seed = gr.Slider(-1, MAX_SEED, -1, step=1, label="🎲 Seed")
681
-
682
- img_generate_btn = gr.Button(
683
- "πŸ–ΌοΈ Transform Image",
684
- variant="primary",
685
- size="lg",
686
- elem_classes=["generate-btn"]
687
- )
688
-
689
- with gr.Column(scale=1):
690
- img_output_image = gr.Image(
691
- label="πŸ–ΌοΈ Transformed Image",
692
- height=500,
693
- show_download_button=True
694
- )
695
- img_download_file = gr.File(
696
- label="πŸ“₯ Download PNG (with metadata)",
697
- file_types=[".png"]
698
- )
699
- img_info = gr.Textbox(
700
- label="ℹ️ Generation Info",
701
- lines=6,
702
- interactive=False
703
- )
704
-
705
- # Event handlers
706
- txt_generate_btn.click(
707
- fn=generate_txt2img,
708
- inputs=[txt_prompt, txt_negative, txt_steps, txt_guidance,
709
- txt_width, txt_height, txt_seed, txt_quality],
710
- outputs=[txt_output_image, txt_download_file, txt_info],
711
- show_progress=True
712
- )
713
-
714
- img_generate_btn.click(
715
- fn=generate_img2img,
716
- inputs=[img_input, img_prompt, img_negative, img_steps, img_guidance,
717
- img_strength, img_seed, img_quality],
718
- outputs=[img_output_image, img_download_file, img_info],
719
- show_progress=True
720
- )import gradio as gr
721
- import torch
722
- from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, EulerAncestralDiscreteScheduler
723
- from PIL import Image, PngImagePlugin, ImageFilter
724
- from datetime import datetime
725
- import os
726
- import gc
727
- import time
728
- import spaces
729
- from typing import Optional, Tuple, Dict, Any
730
- from huggingface_hub import hf_hub_download
731
- import tempfile
732
- import random
733
- import logging
734
- import torch.nn.functional as F
735
- from transformers import CLIPProcessor, CLIPModel
736
-
737
- # Configure logging
738
- logging.basicConfig(level=logging.INFO)
739
- logger = logging.getLogger(__name__)
740
-
741
- # Constants
742
- MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
743
- MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
744
- NSFW_MODEL_ID = "openai/clip-vit-base-patch32" # CLIP model for NSFW detection
745
- MAX_SEED = 2**32 - 1
746
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
747
- DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
748
- NSFW_THRESHOLD = 0.25 # Threshold for NSFW detection
749
-
750
- # Global pipeline state
751
- class PipelineManager:
752
- def __init__(self):
753
- self.txt2img_pipe = None
754
- self.img2img_pipe = None
755
- self.nsfw_detector_model = None
756
- self.nsfw_detector_processor = None
757
- self.model_loaded = False
758
- self.nsfw_detector_loaded = False
759
-
760
- def clear_memory(self):
761
- """Aggressive memory cleanup"""
762
- if torch.cuda.is_available():
763
- torch.cuda.empty_cache()
764
- torch.cuda.synchronize()
765
- gc.collect()
766
-
767
- def load_nsfw_detector(self) -> bool:
768
- """Load NSFW detection model"""
769
- if self.nsfw_detector_loaded:
770
- return True
771
-
772
- try:
773
- logger.info("Loading NSFW detector...")
774
- self.nsfw_detector_processor = CLIPProcessor.from_pretrained(NSFW_MODEL_ID)
775
- self.nsfw_detector_model = CLIPModel.from_pretrained(NSFW_MODEL_ID)
776
-
777
- if DEVICE == "cuda":
778
- self.nsfw_detector_model = self.nsfw_detector_model.to(DEVICE)
779
-
780
- self.nsfw_detector_loaded = True
781
- logger.info("NSFW detector loaded successfully!")
782
- return True
783
-
784
- except Exception as e:
785
- logger.error(f"Failed to load NSFW detector: {e}")
786
- self.nsfw_detector_loaded = False
787
- return False
788
-
789
- def is_nsfw(self, image: Image.Image, prompt: str = "") -> Tuple[bool, float]:
790
- """
791
- Detects NSFW content using CLIP-based zero-shot classification.
792
- Falls back to prompt-based detection if CLIP model fails.
793
- """
794
- try:
795
- # Load NSFW detector if not already loaded
796
- if not self.nsfw_detector_loaded:
797
- if not self.load_nsfw_detector():
798
- return self._fallback_nsfw_detection(prompt)
799
-
800
- # CLIP-based NSFW detection
801
- inputs = self.nsfw_detector_processor(images=image, return_tensors="pt").to(DEVICE)
802
-
803
- with torch.no_grad():
804
- image_features = self.nsfw_detector_model.get_image_features(**inputs)
805
-
806
- # Define text prompts for classification
807
- safe_prompts = [
808
- "a safe family-friendly image",
809
- "a general photo",
810
- "appropriate content",
811
- "artistic photography"
812
- ]
813
- unsafe_prompts = [
814
- "explicit adult content",
815
- "nudity",
816
- "inappropriate sexual content",
817
- "pornographic material"
818
- ]
819
-
820
- # Get text features
821
- safe_inputs = self.nsfw_detector_processor(
822
- text=safe_prompts, return_tensors="pt", padding=True
823
- ).to(DEVICE)
824
- unsafe_inputs = self.nsfw_detector_processor(
825
- text=unsafe_prompts, return_tensors="pt", padding=True
826
- ).to(DEVICE)
827
-
828
- safe_features = self.nsfw_detector_model.get_text_features(**safe_inputs)
829
- unsafe_features = self.nsfw_detector_model.get_text_features(**unsafe_inputs)
830
-
831
- # Normalize features for cosine similarity
832
- image_features = F.normalize(image_features, p=2, dim=-1)
833
- safe_features = F.normalize(safe_features, p=2, dim=-1)
834
- unsafe_features = F.normalize(unsafe_features, p=2, dim=-1)
835
-
836
- # Calculate similarities
837
- safe_similarity = (image_features @ safe_features.T).mean().item()
838
- unsafe_similarity = (image_features @ unsafe_features.T).mean().item()
839
-
840
- # Classification logic
841
- is_nsfw_result = (
842
- unsafe_similarity > safe_similarity and
843
- unsafe_similarity > NSFW_THRESHOLD
844
- )
845
-
846
- confidence = unsafe_similarity if is_nsfw_result else safe_similarity
847
-
848
- if is_nsfw_result:
849
- logger.warning(f"🚨 NSFW content detected (CLIP-based: {unsafe_similarity:.3f} > {safe_similarity:.3f})")
850
-
851
- return is_nsfw_result, confidence
852
-
853
- except Exception as e:
854
- logger.error(f"NSFW detection error: {e}")
855
- return self._fallback_nsfw_detection(prompt)
856
-
857
- def _fallback_nsfw_detection(self, prompt: str = "") -> Tuple[bool, float]:
858
- """Fallback NSFW detection based on prompt analysis"""
859
- nsfw_keywords = [
860
- 'nude', 'naked', 'nsfw', 'explicit', 'sexual', 'erotic', 'porn',
861
- 'adult', 'xxx', 'sex', 'breast', 'nipple', 'genital', 'provocative'
862
- ]
863
-
864
- prompt_lower = prompt.lower()
865
- for keyword in nsfw_keywords:
866
- if keyword in prompt_lower:
867
- logger.warning(f"🚨 NSFW content detected (prompt-based: '{keyword}' found)")
868
- return True, random.uniform(0.7, 0.95)
869
-
870
- # Random chance for demonstration (remove in production)
871
- if random.random() < 0.02: # 2% chance for demo
872
- logger.warning("🚨 NSFW content detected (random demo detection)")
873
- return True, random.uniform(0.6, 0.8)
874
-
875
- return False, random.uniform(0.1, 0.3)
876
- """Load models with enhanced error handling and memory optimization"""
877
- if self.model_loaded:
878
- return True
879
-
880
- try:
881
- logger.info("Loading CyberRealistic Pony models...")
882
-
883
- # Download model with better error handling
884
- model_path = hf_hub_download(
885
- repo_id=MODEL_REPO,
886
- filename=MODEL_FILENAME,
887
- cache_dir=os.environ.get("HF_CACHE_DIR", "/tmp/hf_cache"),
888
- resume_download=True
889
- )
890
- logger.info(f"Model downloaded to: {model_path}")
891
-
892
- # Load txt2img pipeline with optimizations
893
- self.txt2img_pipe = StableDiffusionXLPipeline.from_single_file(
894
- model_path,
895
- torch_dtype=DTYPE,
896
- use_safetensors=True,
897
- variant="fp16" if DEVICE == "cuda" else None,
898
- safety_checker=None, # Disable for faster loading
899
- requires_safety_checker=False
900
- )
901
-
902
- # Memory optimizations
903
- self._optimize_pipeline(self.txt2img_pipe)
904
-
905
- # Create img2img pipeline sharing components
906
- self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
907
- vae=self.txt2img_pipe.vae,
908
- text_encoder=self.txt2img_pipe.text_encoder,
909
- text_encoder_2=self.txt2img_pipe.text_encoder_2,
910
- tokenizer=self.txt2img_pipe.tokenizer,
911
- tokenizer_2=self.txt2img_pipe.tokenizer_2,
912
- unet=self.txt2img_pipe.unet,
913
- scheduler=self.txt2img_pipe.scheduler,
914
- safety_checker=None,
915
- requires_safety_checker=False
916
- )
917
-
918
- self._optimize_pipeline(self.img2img_pipe)
919
-
920
- self.model_loaded = True
921
- logger.info("Models loaded successfully!")
922
- return True
923
-
924
- except Exception as e:
925
- logger.error(f"Failed to load models: {e}")
926
- self.model_loaded = False
927
- return False
928
-
929
- def _optimize_pipeline(self, pipeline):
930
- """Apply memory optimizations to pipeline"""
931
- pipeline.enable_attention_slicing()
932
- pipeline.enable_vae_slicing()
933
-
934
- if DEVICE == "cuda":
935
- # Use sequential CPU offloading for better memory management
936
- pipeline.enable_sequential_cpu_offload()
937
- # Enable memory efficient attention if available
938
- try:
939
- pipeline.enable_xformers_memory_efficient_attention()
940
- except:
941
- logger.info("xformers not available, using default attention")
942
- else:
943
- pipeline = pipeline.to(DEVICE)
944
-
945
- # Global pipeline manager
946
- pipe_manager = PipelineManager()
947
-
948
- # Enhanced prompt templates
949
- QUALITY_TAGS = "score_9, score_8_up, score_7_up, masterpiece, best quality, ultra detailed, 8k"
950
-
951
- DEFAULT_NEGATIVE = """(worst quality:1.4), (low quality:1.4), (normal quality:1.2),
952
- lowres, bad anatomy, bad hands, signature, watermarks, ugly, imperfect eyes,
953
- skewed eyes, unnatural face, unnatural body, error, extra limb, missing limbs,
954
- painting by bad-artist, 3d, render"""
955
-
956
- EXAMPLE_PROMPTS = [
957
- "beautiful anime girl with long flowing silver hair, sakura petals, soft morning light",
958
- "cyberpunk street scene, neon lights reflecting on wet pavement, futuristic cityscape",
959
- "majestic dragon soaring through storm clouds, lightning, epic fantasy scene",
960
- "cute anthropomorphic fox girl, fluffy tail, forest clearing, magical sparkles",
961
- "elegant Victorian lady in ornate dress, portrait, vintage photography style",
962
- "futuristic mech suit, glowing energy core, sci-fi laboratory background",
963
- "mystical unicorn with rainbow mane, enchanted forest, ethereal atmosphere",
964
- "steampunk inventor's workshop, brass gears, mechanical contraptions, warm lighting"
965
- ]
966
-
967
  def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
968
  """Smart prompt enhancement"""
969
  if not prompt.strip():
 
245
  "steampunk inventor's workshop, brass gears, mechanical contraptions, warm lighting"
246
  ]
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
249
  """Smart prompt enhancement"""
250
  if not prompt.strip():