import gc
import os
import threading
import time
from typing import Optional, Tuple

import gradio as gr
import requests  # noqa: F401  (not used in this module)
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from PIL import Image

# Global pipeline variables, populated lazily by load_models().
txt2img_pipe = None
img2img_pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# Serializes load_models(). The background preload thread (started at the
# bottom of this file) and a user's first click can otherwise both observe an
# unloaded pipeline and load the same multi-GB checkpoint twice concurrently.
_model_lock = threading.Lock()

# Hugging Face model configuration
MODEL_REPO = "ajsbsd/CyberRealistic-Pony"
MODEL_FILENAME = "cyberrealisticPony_v110.safetensors"
LOCAL_MODEL_PATH = "./models/cyberrealisticPony_v110.safetensors"


def clear_memory():
    """Release cached CUDA memory (when CUDA is available) and run the garbage collector."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def download_model():
    """Return a local path to the model checkpoint, downloading it if needed.

    First checks ``LOCAL_MODEL_PATH``; otherwise downloads ``MODEL_FILENAME``
    from ``MODEL_REPO`` into ``./models``. If that fails, retries via the
    default Hugging Face cache.

    Returns:
        Path to the ``.safetensors`` file, or ``None`` if both attempts failed.
    """
    try:
        # Create models directory if it doesn't exist
        os.makedirs("./models", exist_ok=True)

        # Check if model already exists locally
        if os.path.exists(LOCAL_MODEL_PATH):
            print(f"Model already exists at {LOCAL_MODEL_PATH}")
            return LOCAL_MODEL_PATH

        print(f"Downloading model from {MODEL_REPO}/{MODEL_FILENAME}...")
        print("This may take a while on first run...")

        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME,
            local_dir="./models",
            local_dir_use_symlinks=False,
            resume_download=True,
        )
        print(f"Model downloaded successfully to {model_path}")
        return model_path

    except Exception as e:
        print(f"Error downloading model: {e}")
        print("Attempting to use cached version or fallback...")

        # Try to use the Hugging Face cache directly.
        try:
            cached_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME,
                resume_download=True,
            )
            print(f"Using cached model at {cached_path}")
            return cached_path
        except Exception as cache_error:
            print(f"Cache fallback failed: {cache_error}")
            return None


def _load_pipeline(pipeline_cls, model_path: str, label: str):
    """Load one SDXL pipeline class from a single-file checkpoint and place it on ``device``.

    Args:
        pipeline_cls: A diffusers pipeline class exposing ``from_single_file``.
        model_path: Path to the ``.safetensors`` checkpoint.
        label: Human-readable name used in log messages ("Text2Img"/"Img2Img").

    Returns:
        The loaded pipeline. On CUDA, model CPU offload is tried first; if it
        fails, the whole pipeline is moved to the GPU instead.

    Raises:
        Whatever ``from_single_file`` raises; the caller handles it.
    """
    print(f"Loading CyberRealistic Pony {label} model...")
    # `variant` is intentionally not passed: it selects between weight files
    # inside a multi-folder repo and has no meaning for a single checkpoint.
    pipe = pipeline_cls.from_single_file(
        model_path,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        use_safetensors=True,
    )

    # Memory optimizations
    pipe.enable_attention_slicing()
    if device == "cuda":
        try:
            pipe.enable_model_cpu_offload()
            print(f"{label} CPU offload enabled")
        except Exception as e:
            # Offload is optional; fall back to keeping everything on the GPU.
            print(f"{label} CPU offload failed: {e}")
            pipe = pipe.to(device)
    else:
        pipe = pipe.to(device)

    print(f"{label} model loaded successfully!")
    return pipe


def load_models():
    """Load (once) the Text2Img and Img2Img pipelines into the module globals.

    Thread-safe: concurrent callers are serialized by ``_model_lock``, so the
    checkpoint is loaded at most once even while the background preload runs.

    Returns:
        ``(txt2img_pipe, img2img_pipe)``. ``(None, None)`` if the checkpoint
        could not be located or the Text2Img pipeline failed to load.
        ``(txt2img_pipe, None)`` if only the Img2Img pipeline failed.
    """
    global txt2img_pipe, img2img_pipe

    with _model_lock:
        # Another caller may have finished loading while we waited on the lock.
        if txt2img_pipe is not None and img2img_pipe is not None:
            return txt2img_pipe, img2img_pipe

        model_path = download_model()
        if model_path is None:
            print("Failed to download or locate model file")
            return None, None
        if not os.path.exists(model_path):
            print(f"Model file not found after download: {model_path}")
            return None, None

        if txt2img_pipe is None:
            try:
                txt2img_pipe = _load_pipeline(StableDiffusionXLPipeline, model_path, "Text2Img")
            except Exception as e:
                print(f"Error loading Text2Img model: {e}")
                return None, None

        if img2img_pipe is None:
            try:
                img2img_pipe = _load_pipeline(StableDiffusionXLImg2ImgPipeline, model_path, "Img2Img")
            except Exception as e:
                print(f"Error loading Img2Img model: {e}")
                return txt2img_pipe, None

        return txt2img_pipe, img2img_pipe


def enhance_prompt(prompt: str, add_quality_tags: bool = True) -> str:
    """Prepend Pony-style quality tags to ``prompt``.

    The prompt is returned unchanged when it is blank, when tagging is
    disabled, or when it already starts with a ``score_`` tag. Leading
    whitespace is ignored for that check.
    """
    if not prompt.strip():
        return prompt

    if not add_quality_tags or prompt.lstrip().startswith("score_"):
        return prompt

    quality_tags = "score_9, score_8_up, score_7_up, masterpiece, best quality, highly detailed"
    return f"{quality_tags}, {prompt}"


def validate_dimensions(width: int, height: int) -> Tuple[int, int]:
    """Round each side up to a multiple of 64 and clamp it to [512, 1536].

    Returns plain ``int`` values even if Gradio delivers floats.
    """
    # SDXL works best with dimensions divisible by 64
    width = int(((width + 63) // 64) * 64)
    height = int(((height + 63) // 64) * 64)

    # Ensure reasonable limits (both bounds are multiples of 64)
    width = max(512, min(1536, width))
    height = max(512, min(1536, height))

    return width, height


def _make_generator(seed) -> Optional[torch.Generator]:
    """Return a ``torch.Generator`` seeded with ``seed``, or ``None`` for a random seed.

    ``-1`` means random, matching the UI label. ``None`` (an emptied Number
    field) is treated the same way instead of failing in ``int(None)``.
    """
    if seed is None or int(seed) == -1:
        return None
    return torch.Generator(device=device).manual_seed(int(seed))


def generate_txt2img(prompt, negative_prompt, num_steps, guidance_scale, width, height, seed, add_quality_tags):
    """Generate an image from a text prompt.

    Returns:
        ``(PIL.Image | None, status_message)``. Errors are reported through
        the status message rather than raised, so the UI can show them.
    """
    global txt2img_pipe

    if not (prompt or "").strip():
        return None, "Please enter a prompt"

    # Load models if not already loaded (waits for an in-progress preload).
    if txt2img_pipe is None:
        txt2img_pipe, _ = load_models()

    if txt2img_pipe is None:
        return None, "Failed to load Text2Img model. Please check your internet connection and try again."

    try:
        clear_memory()

        width, height = validate_dimensions(width, height)
        generator = _make_generator(seed)
        enhanced_prompt = enhance_prompt(prompt, add_quality_tags)

        print(f"Generating with prompt: {enhanced_prompt[:100]}...")
        start_time = time.perf_counter()

        with torch.no_grad():
            result = txt2img_pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt or "",
                num_inference_steps=int(num_steps),
                guidance_scale=float(guidance_scale),
                width=width,
                height=height,
                generator=generator,
            )

        generation_time = time.perf_counter() - start_time
        status = f"Text2Img: Generated successfully in {generation_time:.1f}s (Size: {width}x{height})"
        return result.images[0], status

    except Exception as e:
        error_msg = f"Text2Img generation failed: {str(e)}"
        print(error_msg)
        return None, error_msg
    finally:
        clear_memory()


def generate_img2img(input_image, prompt, negative_prompt, num_steps, guidance_scale, strength, seed, add_quality_tags):
    """Transform ``input_image`` guided by a text prompt.

    PIL inputs are converted to RGB, shrunk to fit within 1024x1024 (aspect
    ratio kept), then resized so each side is a multiple of 64 in [512, 1536].

    Returns:
        ``(PIL.Image | None, status_message)``. Errors are reported through
        the status message rather than raised.
    """
    global img2img_pipe

    if input_image is None:
        return None, "Please upload an input image for Img2Img"

    if not (prompt or "").strip():
        return None, "Please enter a prompt"

    # Load models if not already loaded (waits for an in-progress preload).
    if img2img_pipe is None:
        _, img2img_pipe = load_models()

    if img2img_pipe is None:
        return None, "Failed to load Img2Img model. Please check your internet connection and try again."

    try:
        clear_memory()

        generator = _make_generator(seed)
        enhanced_prompt = enhance_prompt(prompt, add_quality_tags)

        if isinstance(input_image, Image.Image):
            if input_image.mode != "RGB":
                input_image = input_image.convert("RGB")

            # Shrink (never enlarge) to fit within max_size, keeping aspect ratio.
            max_size = 1024
            input_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            # Snap both sides to SDXL-friendly multiples of 64.
            w, h = validate_dimensions(*input_image.size)
            input_image = input_image.resize((w, h), Image.Resampling.LANCZOS)

        print(f"Generating with prompt: {enhanced_prompt[:100]}...")
        start_time = time.perf_counter()

        with torch.no_grad():
            result = img2img_pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt or "",
                image=input_image,
                num_inference_steps=int(num_steps),
                guidance_scale=float(guidance_scale),
                strength=float(strength),
                generator=generator,
            )

        generation_time = time.perf_counter() - start_time
        status = f"Img2Img: Generated successfully in {generation_time:.1f}s (Strength: {strength})"
        return result.images[0], status

    except Exception as e:
        error_msg = f"Img2Img generation failed: {str(e)}"
        print(error_msg)
        return None, error_msg
    finally:
        clear_memory()


# Default negative prompt.
# NOTE: the stock diffusers SDXL pipelines do not interpret A1111-style
# "(term:weight)" emphasis (the brackets and numbers are tokenized literally),
# and the CLIP text encoders truncate at 77 tokens, so only the leading part of
# this list actually reaches the model. Shorten it, or use a prompt-weighting
# library, if the later terms matter.
# "(dark skin:1.1)" was removed: it steers generated people toward lighter skin
# tones rather than addressing image quality.
DEFAULT_NEGATIVE = (
    "(low quality:1.4), (worst quality:1.4), (bad quality:1.3), (normal quality:1.2), "
    "lowres, jpeg artifacts, blurry, noisy, ugly, deformed, disfigured, malformed, "
    "poorly drawn, bad art, amateur, render, 3D, cgi, "
    "(text, signature, watermark, username, copyright:1.5), "
    "(extra limbs:1.5), (missing limbs:1.5), (extra fingers:1.5), (missing fingers:1.5), "
    "(mutated hands:1.5), (bad hands:1.4), (poorly drawn hands:1.3), (ugly hands:1.2), "
    "(bad anatomy:1.4), (deformed body:1.3), (unnatural body:1.2), "
    "(cross-eyed:1.3), (skewed eyes:1.3), (imperfect eyes:1.2), (ugly eyes:1.2), "
    "(asymmetrical face:1.2), (unnatural face:1.2), "
    "(blush:1.1), (shadow on skin:1.1), (shaded skin:1.1), "
    "abstract, simplified, unrealistic, impressionistic, cartoon, anime, drawing, sketch, "
    "illustration, painting, censored, grayscale, monochrome, out of frame, cropped, distorted."
)

# Create Gradio interface with enhanced styling
with gr.Blocks(
    title="CyberRealistic Pony Image Generator",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .tab-nav button {
        font-size: 16px;
        font-weight: bold;
    }
    """,
) as demo:
    gr.Markdown("""
# 🎨 CyberRealistic Pony Image Generator (Hugging Face Edition)

Generate high-quality images using the CyberRealistic Pony SDXL model from Hugging Face.

**Features:**
- 🎨 Text-to-Image generation
- 🖼️ Image-to-Image transformation
- 🎯 Automatic quality tag enhancement
- ⚡ Memory optimized for better performance
- 🤗 Auto-downloads model from Hugging Face

**Note:** On first run, the model will be downloaded from Hugging Face (this may take a few minutes).
""")

    with gr.Tabs():
        # Text2Image Tab
        with gr.TabItem("🎨 Text to Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls for Text2Img
                    txt2img_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your image description...",
                        value="beautiful landscape with mountains and lake at sunset",
                        lines=3,
                    )
                    txt2img_negative = gr.Textbox(
                        label="Negative Prompt",
                        value=DEFAULT_NEGATIVE,
                        lines=3,
                    )
                    txt2img_quality_tags = gr.Checkbox(
                        label="Add Quality Tags",
                        value=True,
                    )
                    with gr.Row():
                        txt2img_steps = gr.Slider(
                            minimum=10, maximum=50, value=25, step=1,
                            label="Inference Steps",
                        )
                        txt2img_guidance = gr.Slider(
                            minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                            label="Guidance Scale",
                        )
                    with gr.Row():
                        txt2img_width = gr.Slider(
                            minimum=512, maximum=1536, value=1024, step=64,
                            label="Width",
                        )
                        txt2img_height = gr.Slider(
                            minimum=512, maximum=1536, value=1024, step=64,
                            label="Height",
                        )
                    txt2img_seed = gr.Number(
                        label="Seed (-1 for random)",
                        value=-1,
                        precision=0,
                    )
                    txt2img_btn = gr.Button("🎨 Generate Image", variant="primary")

                with gr.Column(scale=2):
                    # Output for Text2Img
                    txt2img_output = gr.Image(
                        label="Generated Image",
                        type="pil",
                        height=600,
                    )
                    txt2img_status = gr.Textbox(label="Status", interactive=False)

        # Image2Image Tab
        with gr.TabItem("🖼️ Image to Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls for Img2Img
                    img2img_input = gr.Image(
                        label="Input Image",
                        type="pil",
                        height=300,
                    )
                    img2img_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe how to modify the image...",
                        value="in the style of a digital painting, vibrant colors",
                        lines=3,
                    )
                    img2img_negative = gr.Textbox(
                        label="Negative Prompt",
                        value=DEFAULT_NEGATIVE,
                        lines=3,
                    )
                    img2img_quality_tags = gr.Checkbox(
                        label="Add Quality Tags",
                        value=True,
                    )
                    with gr.Row():
                        img2img_steps = gr.Slider(
                            minimum=10, maximum=50, value=25, step=1,
                            label="Inference Steps",
                        )
                        img2img_guidance = gr.Slider(
                            minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                            label="Guidance Scale",
                        )
                    img2img_strength = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.75, step=0.05,
                        label="Denoising Strength (Lower = more like input, Higher = more creative)",
                    )
                    img2img_seed = gr.Number(
                        label="Seed (-1 for random)",
                        value=-1,
                        precision=0,
                    )
                    img2img_btn = gr.Button("🖼️ Transform Image", variant="primary")

                with gr.Column(scale=2):
                    # Output for Img2Img
                    img2img_output = gr.Image(
                        label="Generated Image",
                        type="pil",
                        height=600,
                    )
                    img2img_status = gr.Textbox(label="Status", interactive=False)

    # Event handlers
    txt2img_btn.click(
        fn=generate_txt2img,
        inputs=[txt2img_prompt, txt2img_negative, txt2img_steps, txt2img_guidance,
                txt2img_width, txt2img_height, txt2img_seed, txt2img_quality_tags],
        outputs=[txt2img_output, txt2img_status],
    )

    img2img_btn.click(
        fn=generate_img2img,
        # Uses the Img2Img tab's own steps slider. It previously passed
        # txt2img_steps, so this tab's slider was silently ignored.
        inputs=[img2img_input, img2img_prompt, img2img_negative, img2img_steps,
                img2img_guidance, img2img_strength, img2img_seed, img2img_quality_tags],
        outputs=[img2img_output, img2img_status],
    )

# Startup banner
print("Initializing CyberRealistic Pony Generator (Hugging Face Edition)...")
print(f"Device: {device}")
print(f"Model Repository: {MODEL_REPO}")
print(f"Model File: {MODEL_FILENAME}")


def preload_models():
    """Load the models in the background so the UI can start immediately.

    load_models() holds _model_lock, so a user request arriving mid-preload
    waits for this load instead of starting a second one.
    """
    try:
        print("Starting background model loading...")
        load_models()
        print("Background model loading completed!")
    except Exception as e:
        print(f"Background model loading failed: {e}")


# Start background loading (daemon so it never blocks interpreter exit).
loading_thread = threading.Thread(target=preload_models, daemon=True)
loading_thread.start()

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )