Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -39,14 +39,14 @@ class PipelineManager:
|
|
39 |
self.nsfw_detector_loaded = False
|
40 |
|
41 |
def clear_memory(self):
|
42 |
-
"""Aggressive memory cleanup"""
|
43 |
if torch.cuda.is_available():
|
44 |
torch.cuda.empty_cache()
|
45 |
torch.cuda.synchronize()
|
46 |
gc.collect()
|
47 |
|
48 |
def load_nsfw_detector(self) -> bool:
|
49 |
-
"""Load NSFW detection model"""
|
50 |
if self.nsfw_detector_loaded:
|
51 |
return True
|
52 |
|
@@ -70,12 +70,13 @@ class PipelineManager:
|
|
70 |
def is_nsfw(self, image: Image.Image, prompt: str = "") -> Tuple[bool, float]:
|
71 |
"""
|
72 |
Detects NSFW content using CLIP-based zero-shot classification.
|
73 |
-
Falls back to prompt-based detection if CLIP model fails.
|
74 |
"""
|
75 |
try:
|
76 |
# Load NSFW detector if not already loaded
|
77 |
if not self.nsfw_detector_loaded:
|
78 |
if not self.load_nsfw_detector():
|
|
|
79 |
return self._fallback_nsfw_detection(prompt)
|
80 |
|
81 |
# CLIP-based NSFW detection
|
@@ -132,11 +133,12 @@ class PipelineManager:
|
|
132 |
return is_nsfw_result, confidence
|
133 |
|
134 |
except Exception as e:
|
135 |
-
logger.error(f"NSFW detection error: {e}")
|
|
|
136 |
return self._fallback_nsfw_detection(prompt)
|
137 |
|
138 |
def _fallback_nsfw_detection(self, prompt: str = "") -> Tuple[bool, float]:
|
139 |
-
"""Fallback NSFW detection based on prompt analysis"""
|
140 |
nsfw_keywords = [
|
141 |
'nude', 'naked', 'nsfw', 'explicit', 'sexual', 'erotic', 'porn',
|
142 |
'adult', 'xxx', 'sex', 'breast', 'nipple', 'genital', 'provocative'
|
@@ -148,13 +150,15 @@ class PipelineManager:
|
|
148 |
logger.warning(f"π¨ NSFW content detected (prompt-based: '{keyword}' found)")
|
149 |
return True, random.uniform(0.7, 0.95)
|
150 |
|
151 |
-
# Random chance for demonstration (
|
152 |
if random.random() < 0.02: # 2% chance for demo
|
153 |
logger.warning("π¨ NSFW content detected (random demo detection)")
|
154 |
return True, random.uniform(0.6, 0.8)
|
155 |
|
156 |
return False, random.uniform(0.1, 0.3)
|
157 |
-
|
|
|
|
|
158 |
if self.model_loaded:
|
159 |
return True
|
160 |
|
@@ -176,14 +180,14 @@ class PipelineManager:
|
|
176 |
torch_dtype=DTYPE,
|
177 |
use_safetensors=True,
|
178 |
variant="fp16" if DEVICE == "cuda" else None,
|
179 |
-
safety_checker=None, # Disable for faster loading
|
180 |
requires_safety_checker=False
|
181 |
)
|
182 |
|
183 |
-
#
|
184 |
self._optimize_pipeline(self.txt2img_pipe)
|
185 |
|
186 |
-
# Create img2img pipeline sharing components
|
187 |
self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
|
188 |
vae=self.txt2img_pipe.vae,
|
189 |
text_encoder=self.txt2img_pipe.text_encoder,
|
@@ -192,10 +196,11 @@ class PipelineManager:
|
|
192 |
tokenizer_2=self.txt2img_pipe.tokenizer_2,
|
193 |
unet=self.txt2img_pipe.unet,
|
194 |
scheduler=self.txt2img_pipe.scheduler,
|
195 |
-
safety_checker=None,
|
196 |
requires_safety_checker=False
|
197 |
)
|
198 |
|
|
|
199 |
self._optimize_pipeline(self.img2img_pipe)
|
200 |
|
201 |
self.model_loaded = True
|
@@ -208,22 +213,23 @@ class PipelineManager:
|
|
208 |
return False
|
209 |
|
210 |
def _optimize_pipeline(self, pipeline):
|
211 |
-
"""Apply memory optimizations to pipeline"""
|
212 |
pipeline.enable_attention_slicing()
|
213 |
pipeline.enable_vae_slicing()
|
214 |
|
215 |
if DEVICE == "cuda":
|
216 |
-
# Use sequential CPU offloading for better memory management
|
217 |
pipeline.enable_sequential_cpu_offload()
|
218 |
-
# Enable memory efficient attention if available
|
219 |
try:
|
220 |
pipeline.enable_xformers_memory_efficient_attention()
|
221 |
-
except:
|
222 |
logger.info("xformers not available, using default attention")
|
223 |
else:
|
|
|
224 |
pipeline = pipeline.to(DEVICE)
|
225 |
|
226 |
-
# Global pipeline manager
|
227 |
pipe_manager = PipelineManager()
|
228 |
|
229 |
# Enhanced prompt templates
|
@@ -246,11 +252,13 @@ EXAMPLE_PROMPTS = [
|
|
246 |
]
|
247 |
|
248 |
def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
|
249 |
-
"""
|
|
|
|
|
250 |
if not prompt.strip():
|
251 |
return ""
|
252 |
|
253 |
-
# Don't add quality tags if they're already present
|
254 |
if any(tag in prompt.lower() for tag in ["score_", "masterpiece", "best quality"]):
|
255 |
return prompt
|
256 |
|
@@ -259,22 +267,28 @@ def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
|
|
259 |
return prompt
|
260 |
|
261 |
def validate_and_fix_dimensions(width: int, height: int) -> Tuple[int, int]:
|
262 |
-
"""
|
|
|
|
|
|
|
263 |
# Round to nearest multiple of 64
|
264 |
width = max(512, min(1024, ((width + 31) // 64) * 64))
|
265 |
height = max(512, min(1024, ((height + 31) // 64) * 64))
|
266 |
|
267 |
# Ensure reasonable aspect ratios (prevent extremely wide/tall images)
|
268 |
aspect_ratio = width / height
|
269 |
-
if aspect_ratio > 2.0: # Too wide
|
270 |
height = width // 2
|
271 |
-
elif aspect_ratio < 0.5: # Too tall
|
272 |
width = height // 2
|
273 |
|
274 |
return width, height
|
275 |
|
276 |
def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
|
277 |
-
"""
|
|
|
|
|
|
|
278 |
temp_path = tempfile.mktemp(suffix=".png", prefix="cyberrealistic_")
|
279 |
|
280 |
meta = PngImagePlugin.PngInfo()
|
@@ -282,7 +296,7 @@ def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
|
|
282 |
if value is not None:
|
283 |
meta.add_text(key, str(value))
|
284 |
|
285 |
-
# Add generation timestamp
|
286 |
meta.add_text("Generated", datetime.now().strftime("%Y-%m-%d %H:%M:%S UTC"))
|
287 |
meta.add_text("Model", f"{MODEL_REPO}/{MODEL_FILENAME}")
|
288 |
|
@@ -290,7 +304,9 @@ def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
|
|
290 |
return temp_path
|
291 |
|
292 |
def format_generation_info(params: Dict[str, Any], generation_time: float) -> str:
|
293 |
-
"""
|
|
|
|
|
294 |
info_lines = [
|
295 |
f"β
Generated in {generation_time:.1f}s",
|
296 |
f"π Resolution: {params.get('width', 'N/A')}Γ{params.get('height', 'N/A')}",
|
@@ -305,20 +321,23 @@ def format_generation_info(params: Dict[str, Any], generation_time: float) -> st
|
|
305 |
|
306 |
return "\n".join(info_lines)
|
307 |
|
308 |
-
@spaces.GPU(duration=120) # Increased duration for model loading
|
309 |
def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_scale: float,
|
310 |
width: int, height: int, seed: int, add_quality: bool) -> Tuple:
|
311 |
-
"""
|
|
|
|
|
|
|
312 |
|
313 |
if not prompt.strip():
|
314 |
-
return None, None, "β Please enter a prompt"
|
315 |
|
316 |
-
# Lazy load models
|
317 |
if not pipe_manager.load_models():
|
318 |
return None, None, "β Failed to load model. Please try again."
|
319 |
|
320 |
try:
|
321 |
-
pipe_manager.clear_memory()
|
322 |
|
323 |
# Process parameters
|
324 |
width, height = validate_and_fix_dimensions(width, height)
|
@@ -328,12 +347,12 @@ def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_sca
|
|
328 |
enhanced_prompt = enhance_prompt(prompt, add_quality)
|
329 |
generator = torch.Generator(device=DEVICE).manual_seed(seed)
|
330 |
|
331 |
-
# Generation parameters
|
332 |
gen_params = {
|
333 |
"prompt": enhanced_prompt,
|
334 |
"negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
|
335 |
-
"num_inference_steps": min(max(steps, 10), 50), # Clamp steps
|
336 |
-
"guidance_scale": max(1.0, min(guidance_scale, 20.0)), # Clamp guidance
|
337 |
"width": width,
|
338 |
"height": height,
|
339 |
"generator": generator,
|
@@ -348,11 +367,11 @@ def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_sca
|
|
348 |
|
349 |
generation_time = time.time() - start_time
|
350 |
|
351 |
-
# NSFW Detection
|
352 |
is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
|
353 |
|
354 |
if is_nsfw_result:
|
355 |
-
#
|
356 |
blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
|
357 |
warning_msg = f"β οΈ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
|
358 |
|
@@ -376,7 +395,7 @@ def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_sca
|
|
376 |
|
377 |
return blurred_image, png_path, info_text
|
378 |
|
379 |
-
#
|
380 |
metadata = {
|
381 |
"prompt": enhanced_prompt,
|
382 |
"negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
|
@@ -402,38 +421,42 @@ def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_sca
|
|
402 |
logger.error(f"Generation error: {e}")
|
403 |
return None, None, f"β Generation failed: {str(e)}"
|
404 |
finally:
|
405 |
-
pipe_manager.clear_memory()
|
406 |
|
407 |
-
@spaces.GPU(duration=120)
|
408 |
def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str,
|
409 |
steps: int, guidance_scale: float, strength: float, seed: int,
|
410 |
add_quality: bool) -> Tuple:
|
411 |
-
"""
|
|
|
|
|
|
|
412 |
|
413 |
if input_image is None:
|
414 |
-
return None, None, "β Please upload an input image"
|
415 |
|
416 |
if not prompt.strip():
|
417 |
-
return None, None, "β Please enter a prompt"
|
418 |
|
|
|
419 |
if not pipe_manager.load_models():
|
420 |
return None, None, "β Failed to load model. Please try again."
|
421 |
|
422 |
try:
|
423 |
-
pipe_manager.clear_memory()
|
424 |
|
425 |
-
# Process input image
|
426 |
if input_image.mode != 'RGB':
|
427 |
input_image = input_image.convert('RGB')
|
428 |
|
429 |
-
# Smart resizing maintaining aspect ratio
|
430 |
original_size = input_image.size
|
431 |
max_dimension = 1024
|
432 |
|
433 |
if max(original_size) > max_dimension:
|
434 |
input_image.thumbnail((max_dimension, max_dimension), Image.Resampling.LANCZOS)
|
435 |
|
436 |
-
# Ensure SDXL compatible dimensions
|
437 |
w, h = validate_and_fix_dimensions(*input_image.size)
|
438 |
input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)
|
439 |
|
@@ -444,14 +467,14 @@ def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str
|
|
444 |
enhanced_prompt = enhance_prompt(prompt, add_quality)
|
445 |
generator = torch.Generator(device=DEVICE).manual_seed(seed)
|
446 |
|
447 |
-
# Generation parameters
|
448 |
gen_params = {
|
449 |
"prompt": enhanced_prompt,
|
450 |
"negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
|
451 |
"image": input_image,
|
452 |
-
"num_inference_steps": min(max(steps, 10), 50),
|
453 |
-
"guidance_scale": max(1.0, min(guidance_scale, 20.0)),
|
454 |
-
"strength": max(0.1, min(strength, 1.0)),
|
455 |
"generator": generator,
|
456 |
"output_type": "pil"
|
457 |
}
|
@@ -464,11 +487,11 @@ def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str
|
|
464 |
|
465 |
generation_time = time.time() - start_time
|
466 |
|
467 |
-
# NSFW Detection
|
468 |
is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
|
469 |
|
470 |
if is_nsfw_result:
|
471 |
-
#
|
472 |
blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
|
473 |
warning_msg = f"β οΈ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
|
474 |
|
@@ -492,7 +515,7 @@ def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str
|
|
492 |
|
493 |
return blurred_image, png_path, info_text
|
494 |
|
495 |
-
#
|
496 |
metadata = {
|
497 |
"prompt": enhanced_prompt,
|
498 |
"negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
|
@@ -518,15 +541,18 @@ def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str
|
|
518 |
logger.error(f"Generation error: {e}")
|
519 |
return None, None, f"β Generation failed: {str(e)}"
|
520 |
finally:
|
521 |
-
pipe_manager.clear_memory()
|
522 |
|
523 |
def get_random_prompt():
|
524 |
-
"""
|
525 |
return random.choice(EXAMPLE_PROMPTS)
|
526 |
|
527 |
# Enhanced Gradio interface
|
528 |
def create_interface():
|
529 |
-
"""
|
|
|
|
|
|
|
530 |
|
531 |
with gr.Blocks(
|
532 |
title="CyberRealistic Pony - SDXL Generator",
|
@@ -729,42 +755,17 @@ def create_interface():
|
|
729 |
|
730 |
return demo
|
731 |
|
732 |
-
# Initialize and launch
|
733 |
if __name__ == "__main__":
|
734 |
logger.info(f"π Initializing CyberRealistic Pony Generator on {DEVICE}")
|
735 |
logger.info(f"π± PyTorch version: {torch.__version__}")
|
736 |
logger.info(f"π‘οΈ NSFW Content Filter: Enabled")
|
737 |
|
738 |
demo = create_interface()
|
739 |
-
demo.queue(max_size=20) # Enable queuing for better
|
740 |
demo.launch(
|
741 |
server_name="0.0.0.0",
|
742 |
server_port=7860,
|
743 |
show_error=True,
|
744 |
-
share=False # Set to True if you want a public link
|
745 |
)
|
746 |
-
|
747 |
-
# Example prompt buttons
|
748 |
-
txt_example_btn.click(fn=get_random_prompt, outputs=[txt_prompt])
|
749 |
-
img_example_btn.click(fn=get_random_prompt, outputs=[img_prompt])
|
750 |
-
|
751 |
-
# Clear buttons
|
752 |
-
txt_clear_btn.click(lambda: "", outputs=[txt_prompt])
|
753 |
-
img_clear_btn.click(lambda: "", outputs=[img_prompt])
|
754 |
-
|
755 |
-
return demo
|
756 |
-
|
757 |
-
# Initialize and launch
|
758 |
-
if __name__ == "__main__":
|
759 |
-
logger.info(f"π Initializing CyberRealistic Pony Generator on {DEVICE}")
|
760 |
-
logger.info(f"π± PyTorch version: {torch.__version__}")
|
761 |
-
logger.info(f"π‘οΈ NSFW Content Filter: Enabled")
|
762 |
-
|
763 |
-
demo = create_interface()
|
764 |
-
demo.queue(max_size=20) # Enable queuing for better UX
|
765 |
-
demo.launch(
|
766 |
-
server_name="0.0.0.0",
|
767 |
-
server_port=7860,
|
768 |
-
show_error=True,
|
769 |
-
share=False # Set to True if you want a public link
|
770 |
-
)
|
|
|
39 |
self.nsfw_detector_loaded = False
|
40 |
|
41 |
def clear_memory(self):
|
42 |
+
"""Aggressive memory cleanup to free up GPU/CPU memory."""
|
43 |
if torch.cuda.is_available():
|
44 |
torch.cuda.empty_cache()
|
45 |
torch.cuda.synchronize()
|
46 |
gc.collect()
|
47 |
|
48 |
def load_nsfw_detector(self) -> bool:
|
49 |
+
"""Load NSFW detection model (CLIP) with error handling."""
|
50 |
if self.nsfw_detector_loaded:
|
51 |
return True
|
52 |
|
|
|
70 |
def is_nsfw(self, image: Image.Image, prompt: str = "") -> Tuple[bool, float]:
|
71 |
"""
|
72 |
Detects NSFW content using CLIP-based zero-shot classification.
|
73 |
+
Falls back to prompt-based detection if CLIP model fails or is not loaded.
|
74 |
"""
|
75 |
try:
|
76 |
# Load NSFW detector if not already loaded
|
77 |
if not self.nsfw_detector_loaded:
|
78 |
if not self.load_nsfw_detector():
|
79 |
+
# If NSFW detector cannot be loaded, fall back to prompt-based
|
80 |
return self._fallback_nsfw_detection(prompt)
|
81 |
|
82 |
# CLIP-based NSFW detection
|
|
|
133 |
return is_nsfw_result, confidence
|
134 |
|
135 |
except Exception as e:
|
136 |
+
logger.error(f"NSFW detection error (CLIP model failed): {e}")
|
137 |
+
# Fallback to prompt-based detection if CLIP model encounters an error
|
138 |
return self._fallback_nsfw_detection(prompt)
|
139 |
|
140 |
def _fallback_nsfw_detection(self, prompt: str = "") -> Tuple[bool, float]:
|
141 |
+
"""Fallback NSFW detection based on prompt keyword analysis."""
|
142 |
nsfw_keywords = [
|
143 |
'nude', 'naked', 'nsfw', 'explicit', 'sexual', 'erotic', 'porn',
|
144 |
'adult', 'xxx', 'sex', 'breast', 'nipple', 'genital', 'provocative'
|
|
|
150 |
logger.warning(f"π¨ NSFW content detected (prompt-based: '{keyword}' found)")
|
151 |
return True, random.uniform(0.7, 0.95)
|
152 |
|
153 |
+
# Random chance for demonstration (consider removing in production)
|
154 |
if random.random() < 0.02: # 2% chance for demo
|
155 |
logger.warning("π¨ NSFW content detected (random demo detection)")
|
156 |
return True, random.uniform(0.6, 0.8)
|
157 |
|
158 |
return False, random.uniform(0.1, 0.3)
|
159 |
+
|
160 |
+
def load_models(self) -> bool:
|
161 |
+
"""Load Stable Diffusion XL models (txt2img and img2img) with enhanced error handling and memory optimization."""
|
162 |
if self.model_loaded:
|
163 |
return True
|
164 |
|
|
|
180 |
torch_dtype=DTYPE,
|
181 |
use_safetensors=True,
|
182 |
variant="fp16" if DEVICE == "cuda" else None,
|
183 |
+
safety_checker=None, # Disable for faster loading, using custom NSFW check
|
184 |
requires_safety_checker=False
|
185 |
)
|
186 |
|
187 |
+
# Apply memory optimizations to txt2img pipeline
|
188 |
self._optimize_pipeline(self.txt2img_pipe)
|
189 |
|
190 |
+
# Create img2img pipeline sharing components to save memory
|
191 |
self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
|
192 |
vae=self.txt2img_pipe.vae,
|
193 |
text_encoder=self.txt2img_pipe.text_encoder,
|
|
|
196 |
tokenizer_2=self.txt2img_pipe.tokenizer_2,
|
197 |
unet=self.txt2img_pipe.unet,
|
198 |
scheduler=self.txt2img_pipe.scheduler,
|
199 |
+
safety_checker=None, # Disable for faster loading, using custom NSFW check
|
200 |
requires_safety_checker=False
|
201 |
)
|
202 |
|
203 |
+
# Apply memory optimizations to img2img pipeline
|
204 |
self._optimize_pipeline(self.img2img_pipe)
|
205 |
|
206 |
self.model_loaded = True
|
|
|
213 |
return False
|
214 |
|
215 |
def _optimize_pipeline(self, pipeline):
|
216 |
+
"""Apply memory optimizations to a given diffusion pipeline."""
|
217 |
pipeline.enable_attention_slicing()
|
218 |
pipeline.enable_vae_slicing()
|
219 |
|
220 |
if DEVICE == "cuda":
|
221 |
+
# Use sequential CPU offloading for better memory management on GPU
|
222 |
pipeline.enable_sequential_cpu_offload()
|
223 |
+
# Enable memory efficient attention if xformers is available
|
224 |
try:
|
225 |
pipeline.enable_xformers_memory_efficient_attention()
|
226 |
+
except Exception: # Catch any error if xformers is not installed/configured
|
227 |
logger.info("xformers not available, using default attention")
|
228 |
else:
|
229 |
+
# Move pipeline to CPU if CUDA is not available
|
230 |
pipeline = pipeline.to(DEVICE)
|
231 |
|
232 |
+
# Global pipeline manager instance
|
233 |
pipe_manager = PipelineManager()
|
234 |
|
235 |
# Enhanced prompt templates
|
|
|
252 |
]
|
253 |
|
254 |
def enhance_prompt(prompt: str, add_quality: bool = True) -> str:
|
255 |
+
"""
|
256 |
+
Enhances the given prompt with quality tags unless they are already present.
|
257 |
+
"""
|
258 |
if not prompt.strip():
|
259 |
return ""
|
260 |
|
261 |
+
# Don't add quality tags if they're already present in the prompt (case-insensitive)
|
262 |
if any(tag in prompt.lower() for tag in ["score_", "masterpiece", "best quality"]):
|
263 |
return prompt
|
264 |
|
|
|
267 |
return prompt
|
268 |
|
269 |
def validate_and_fix_dimensions(width: int, height: int) -> Tuple[int, int]:
|
270 |
+
"""
|
271 |
+
Ensures SDXL-compatible dimensions (multiples of 64) and reasonable aspect ratios.
|
272 |
+
Clamps dimensions between 512 and 1024.
|
273 |
+
"""
|
274 |
# Round to nearest multiple of 64
|
275 |
width = max(512, min(1024, ((width + 31) // 64) * 64))
|
276 |
height = max(512, min(1024, ((height + 31) // 64) * 64))
|
277 |
|
278 |
# Ensure reasonable aspect ratios (prevent extremely wide/tall images)
|
279 |
aspect_ratio = width / height
|
280 |
+
if aspect_ratio > 2.0: # Too wide, adjust height
|
281 |
height = width // 2
|
282 |
+
elif aspect_ratio < 0.5: # Too tall, adjust width
|
283 |
width = height // 2
|
284 |
|
285 |
return width, height
|
286 |
|
287 |
def create_metadata_png(image: Image.Image, params: Dict[str, Any]) -> str:
|
288 |
+
"""
|
289 |
+
Creates a temporary PNG file with embedded metadata from the generation parameters.
|
290 |
+
Returns the path to the created PNG file.
|
291 |
+
"""
|
292 |
temp_path = tempfile.mktemp(suffix=".png", prefix="cyberrealistic_")
|
293 |
|
294 |
meta = PngImagePlugin.PngInfo()
|
|
|
296 |
if value is not None:
|
297 |
meta.add_text(key, str(value))
|
298 |
|
299 |
+
# Add generation timestamp and model info
|
300 |
meta.add_text("Generated", datetime.now().strftime("%Y-%m-%d %H:%M:%S UTC"))
|
301 |
meta.add_text("Model", f"{MODEL_REPO}/{MODEL_FILENAME}")
|
302 |
|
|
|
304 |
return temp_path
|
305 |
|
306 |
def format_generation_info(params: Dict[str, Any], generation_time: float) -> str:
|
307 |
+
"""
|
308 |
+
Formats the generation information into a human-readable string for display.
|
309 |
+
"""
|
310 |
info_lines = [
|
311 |
f"β
Generated in {generation_time:.1f}s",
|
312 |
f"π Resolution: {params.get('width', 'N/A')}Γ{params.get('height', 'N/A')}",
|
|
|
321 |
|
322 |
return "\n".join(info_lines)
|
323 |
|
324 |
+
@spaces.GPU(duration=120) # Increased duration for model loading and generation
|
325 |
def generate_txt2img(prompt: str, negative_prompt: str, steps: int, guidance_scale: float,
|
326 |
width: int, height: int, seed: int, add_quality: bool) -> Tuple:
|
327 |
+
"""
|
328 |
+
Handles text-to-image generation, including parameter processing, model inference,
|
329 |
+
NSFW detection, and metadata creation.
|
330 |
+
"""
|
331 |
|
332 |
if not prompt.strip():
|
333 |
+
return None, None, "β Please enter a prompt."
|
334 |
|
335 |
+
# Lazy load models if not already loaded
|
336 |
if not pipe_manager.load_models():
|
337 |
return None, None, "β Failed to load model. Please try again."
|
338 |
|
339 |
try:
|
340 |
+
pipe_manager.clear_memory() # Clear memory before generation
|
341 |
|
342 |
# Process parameters
|
343 |
width, height = validate_and_fix_dimensions(width, height)
|
|
|
347 |
enhanced_prompt = enhance_prompt(prompt, add_quality)
|
348 |
generator = torch.Generator(device=DEVICE).manual_seed(seed)
|
349 |
|
350 |
+
# Generation parameters dictionary
|
351 |
gen_params = {
|
352 |
"prompt": enhanced_prompt,
|
353 |
"negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
|
354 |
+
"num_inference_steps": min(max(steps, 10), 50), # Clamp steps to a reasonable range
|
355 |
+
"guidance_scale": max(1.0, min(guidance_scale, 20.0)), # Clamp guidance scale
|
356 |
"width": width,
|
357 |
"height": height,
|
358 |
"generator": generator,
|
|
|
367 |
|
368 |
generation_time = time.time() - start_time
|
369 |
|
370 |
+
# Perform NSFW Detection on the generated image
|
371 |
is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
|
372 |
|
373 |
if is_nsfw_result:
|
374 |
+
# If NSFW, blur the image and return a warning message
|
375 |
blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
|
376 |
warning_msg = f"β οΈ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
|
377 |
|
|
|
395 |
|
396 |
return blurred_image, png_path, info_text
|
397 |
|
398 |
+
# If not NSFW, prepare metadata and save the original image
|
399 |
metadata = {
|
400 |
"prompt": enhanced_prompt,
|
401 |
"negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
|
|
|
421 |
logger.error(f"Generation error: {e}")
|
422 |
return None, None, f"β Generation failed: {str(e)}"
|
423 |
finally:
|
424 |
+
pipe_manager.clear_memory() # Ensure memory is cleared even if an error occurs
|
425 |
|
426 |
+
@spaces.GPU(duration=120) # Increased duration for model loading and generation
|
427 |
def generate_img2img(input_image: Image.Image, prompt: str, negative_prompt: str,
|
428 |
steps: int, guidance_scale: float, strength: float, seed: int,
|
429 |
add_quality: bool) -> Tuple:
|
430 |
+
"""
|
431 |
+
Handles image-to-image generation, including image preprocessing, parameter processing,
|
432 |
+
model inference, NSFW detection, and metadata creation.
|
433 |
+
"""
|
434 |
|
435 |
if input_image is None:
|
436 |
+
return None, None, "β Please upload an input image."
|
437 |
|
438 |
if not prompt.strip():
|
439 |
+
return None, None, "β Please enter a prompt."
|
440 |
|
441 |
+
# Lazy load models if not already loaded
|
442 |
if not pipe_manager.load_models():
|
443 |
return None, None, "β Failed to load model. Please try again."
|
444 |
|
445 |
try:
|
446 |
+
pipe_manager.clear_memory() # Clear memory before generation
|
447 |
|
448 |
+
# Process input image: convert to RGB if necessary
|
449 |
if input_image.mode != 'RGB':
|
450 |
input_image = input_image.convert('RGB')
|
451 |
|
452 |
+
# Smart resizing maintaining aspect ratio to fit within max_dimension
|
453 |
original_size = input_image.size
|
454 |
max_dimension = 1024
|
455 |
|
456 |
if max(original_size) > max_dimension:
|
457 |
input_image.thumbnail((max_dimension, max_dimension), Image.Resampling.LANCZOS)
|
458 |
|
459 |
+
# Ensure SDXL compatible dimensions (multiples of 64)
|
460 |
w, h = validate_and_fix_dimensions(*input_image.size)
|
461 |
input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)
|
462 |
|
|
|
467 |
enhanced_prompt = enhance_prompt(prompt, add_quality)
|
468 |
generator = torch.Generator(device=DEVICE).manual_seed(seed)
|
469 |
|
470 |
+
# Generation parameters dictionary
|
471 |
gen_params = {
|
472 |
"prompt": enhanced_prompt,
|
473 |
"negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
|
474 |
"image": input_image,
|
475 |
+
"num_inference_steps": min(max(steps, 10), 50), # Clamp steps
|
476 |
+
"guidance_scale": max(1.0, min(guidance_scale, 20.0)), # Clamp guidance scale
|
477 |
+
"strength": max(0.1, min(strength, 1.0)), # Clamp strength
|
478 |
"generator": generator,
|
479 |
"output_type": "pil"
|
480 |
}
|
|
|
487 |
|
488 |
generation_time = time.time() - start_time
|
489 |
|
490 |
+
# Perform NSFW Detection on the transformed image
|
491 |
is_nsfw_result, nsfw_confidence = pipe_manager.is_nsfw(result.images[0], enhanced_prompt)
|
492 |
|
493 |
if is_nsfw_result:
|
494 |
+
# If NSFW, blur the image and return a warning message
|
495 |
blurred_image = result.images[0].filter(ImageFilter.GaussianBlur(radius=20))
|
496 |
warning_msg = f"β οΈ Content flagged as potentially inappropriate (confidence: {nsfw_confidence:.2f}). Image has been blurred."
|
497 |
|
|
|
515 |
|
516 |
return blurred_image, png_path, info_text
|
517 |
|
518 |
+
# If not NSFW, prepare metadata and save the original image
|
519 |
metadata = {
|
520 |
"prompt": enhanced_prompt,
|
521 |
"negative_prompt": negative_prompt or DEFAULT_NEGATIVE,
|
|
|
541 |
logger.error(f"Generation error: {e}")
|
542 |
return None, None, f"β Generation failed: {str(e)}"
|
543 |
finally:
|
544 |
+
pipe_manager.clear_memory() # Ensure memory is cleared even if an error occurs
|
545 |
|
546 |
def get_random_prompt():
|
547 |
+
"""Returns a random example prompt from a predefined list."""
|
548 |
return random.choice(EXAMPLE_PROMPTS)
|
549 |
|
550 |
# Enhanced Gradio interface
|
551 |
def create_interface():
|
552 |
+
"""
|
553 |
+
Creates and returns the Gradio Blocks interface for the CyberRealistic Pony Generator.
|
554 |
+
This includes tabs for Text-to-Image and Image-to-Image, along with controls and outputs.
|
555 |
+
"""
|
556 |
|
557 |
with gr.Blocks(
|
558 |
title="CyberRealistic Pony - SDXL Generator",
|
|
|
755 |
|
756 |
return demo
|
757 |
|
758 |
+
# Initialize and launch the Gradio application
|
759 |
if __name__ == "__main__":
|
760 |
logger.info(f"π Initializing CyberRealistic Pony Generator on {DEVICE}")
|
761 |
logger.info(f"π± PyTorch version: {torch.__version__}")
|
762 |
logger.info(f"π‘οΈ NSFW Content Filter: Enabled")
|
763 |
|
764 |
demo = create_interface()
|
765 |
+
demo.queue(max_size=20) # Enable queuing for better user experience
|
766 |
demo.launch(
|
767 |
server_name="0.0.0.0",
|
768 |
server_port=7860,
|
769 |
show_error=True,
|
770 |
+
share=False # Set to True if you want a public link (e.g., for Hugging Face Spaces)
|
771 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|