import os
import random

import torch
import numpy as np
import gradio as gr
import spaces
from diffusers import FluxPipeline
from translatepy import Translator
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
class Config:
    MODEL_ID = "black-forest-labs/FLUX.1-dev"
    DEFAULT_LORA = "nftnik/BR_ohwx_V1"
    DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
    MAX_SEED = int(np.iinfo(np.int32).max)
    CSS = "footer { visibility: hidden; }"
    DEFAULT_WIDTH = 896
    DEFAULT_HEIGHT = 1152
    DEFAULT_GUIDANCE_SCALE = 3.5
    DEFAULT_STEPS = 35
    DEFAULT_LORA_SCALE = 1.0
    DEFAULT_TRIGGER_WORD = "ohwx"
    # Memory optimization settings
    ENABLE_MEMORY_EFFICIENT_ATTENTION = True
    ENABLE_SEQUENTIAL_CPU_OFFLOAD = True
    ENABLE_ATTENTION_SLICING = "max"
# -----------------------------------------------------------------------------
# FluxGenerator: handles model loading and image generation
# -----------------------------------------------------------------------------
class FluxGenerator:
    def __init__(self):
        # Environment setup
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        self.translator = Translator()
        self.device = self._get_optimal_device()
        print(f"Using {self.device.upper()}")

        # Initialize pipeline
        self.pipe = None
        self._initialize_pipeline()
    def _get_optimal_device(self):
        """Determine the optimal device based on available resources."""
        if torch.cuda.is_available():
            # Check available GPU memory
            try:
                gpu_memory = torch.cuda.get_device_properties(0).total_memory
                if gpu_memory > 10 * 1024 * 1024 * 1024:  # More than 10 GB
                    return "cuda"
                else:
                    print("Limited GPU memory detected; keeping CUDA and relying on memory optimizations")
                    return "cuda"
            except Exception:
                print("Error checking GPU memory, falling back to CPU")
                return "cpu"
        else:
            return "cpu"
    def _initialize_pipeline(self):
        """Initialize the Flux pipeline with memory optimizations."""
        try:
            print("Loading Flux model...")

            # Use more memory-efficient settings on GPU
            pipe_kwargs = {
                "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
            }

            # Initialize the pipeline
            self.pipe = FluxPipeline.from_pretrained(
                Config.MODEL_ID,
                **pipe_kwargs,
            )

            # Apply memory optimizations
            if Config.ENABLE_MEMORY_EFFICIENT_ATTENTION and self.device == "cuda":
                print("Enabling memory efficient attention")
                self.pipe.enable_xformers_memory_efficient_attention()

            if Config.ENABLE_ATTENTION_SLICING:
                print("Enabling attention slicing")
                self.pipe.enable_attention_slicing(Config.ENABLE_ATTENTION_SLICING)

            if Config.ENABLE_SEQUENTIAL_CPU_OFFLOAD and self.device == "cuda":
                print("Enabling sequential CPU offload")
                self.pipe.enable_sequential_cpu_offload()
            else:
                # Only move to device if not using CPU offload
                self.pipe = self.pipe.to(self.device)

            # Load default LoRA
            print(f"Loading default LoRA: {Config.DEFAULT_LORA}")
            self.pipe.load_lora_weights(Config.DEFAULT_LORA, weight_name=Config.DEFAULT_WEIGHT_NAME)

            print("Model initialization complete")
            return self.pipe
        except Exception as e:
            error_msg = f"Error initializing pipeline: {str(e)}"
            print(error_msg)
            raise
    def load_lora(self, lora_path):
        """Load a new LoRA model."""
        try:
            print("Unloading previous LoRA weights...")
            self.pipe.unload_lora_weights()

            if not lora_path:
                print("No LoRA path provided, skipping LoRA loading")
                return gr.update(value="")

            print(f"Loading LoRA from {lora_path}...")
            self.pipe.load_lora_weights(lora_path)
            print("LoRA loaded successfully")
            return gr.update(label="LoRA Loaded Successfully")
        except Exception as e:
            error_msg = f"Failed to load LoRA from {lora_path}: {str(e)}"
            print(error_msg)
            raise gr.Error(error_msg)
    def _clear_memory(self):
        """Clear CUDA memory caches."""
        if self.device == "cuda":
            try:
                print("Clearing CUDA memory cache...")
                torch.cuda.empty_cache()
                # Also drop any cached autocast state, where the torch build provides it
                if hasattr(torch, "clear_autocast_cache"):
                    torch.clear_autocast_cache()
            except Exception as e:
                print(f"Warning: Failed to clear CUDA memory: {str(e)}")
    def generate(self, prompt, lora_word, lora_scale=Config.DEFAULT_LORA_SCALE,
                 width=Config.DEFAULT_WIDTH, height=Config.DEFAULT_HEIGHT,
                 guidance_scale=Config.DEFAULT_GUIDANCE_SCALE, steps=Config.DEFAULT_STEPS,
                 seed=-1, num_images=1):
        """Generate images from a prompt with memory optimizations."""
        try:
            print(f"Generating image for prompt: '{prompt}'")

            # Clear memory before generation
            self._clear_memory()

            # Ensure we're on the right device (skip when CPU offload manages placement)
            if not Config.ENABLE_SEQUENTIAL_CPU_OFFLOAD:
                print(f"Moving model to {self.device}")
                self.pipe.to(self.device)

            # Handle seed
            seed = random.randint(0, Config.MAX_SEED) if seed == -1 else int(seed)
            print(f"Using seed: {seed}")
            generator = torch.Generator(device=self.device).manual_seed(seed)

            # Translate the prompt to English if needed
            print("Translating prompt if needed...")
            prompt_english = str(self.translator.translate(prompt, "English"))
            full_prompt = f"{prompt_english} {lora_word}"
            print(f"Full prompt: '{full_prompt}'")

            # Lower the resolution if GPU memory is limited
            if self.device == "cuda" and torch.cuda.get_device_properties(0).total_memory < 8 * 1024 * 1024 * 1024:
                original_width, original_height = width, height
                # Scale down to 85%, rounded to a multiple of 16 so the pipeline
                # does not have to resize the dimensions again
                width = int(width * 0.85) // 16 * 16
                height = int(height * 0.85) // 16 * 16
                print(f"Limited memory detected. Scaling down resolution from {original_width}x{original_height} to {width}x{height}")

            # Generate with autocast for memory efficiency
            print(f"Starting generation with {steps} steps, guidance scale {guidance_scale}")
            with torch.cuda.amp.autocast(enabled=self.device == "cuda"):
                result = self.pipe(
                    prompt=full_prompt,
                    height=height,
                    width=width,
                    guidance_scale=guidance_scale,
                    output_type="pil",
                    num_inference_steps=steps,
                    num_images_per_prompt=num_images,
                    generator=generator,
                    joint_attention_kwargs={"scale": lora_scale},
                )

            print("Generation complete, returning images")
            self._clear_memory()  # Clear memory after generation
            return result.images, seed
        except Exception as e:
            error_msg = f"Image generation failed: {str(e)}"
            print(error_msg)
            self._clear_memory()  # Clear memory after error
            raise gr.Error(error_msg)
# -----------------------------------------------------------------------------
# UI builder class
# -----------------------------------------------------------------------------
class FluxUI:
    def __init__(self, generator):
        self.generator = generator
        self.example_prompts = [
            ["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.", "ohwx", 0.9],
            ["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.", "ohwx", 0.9],
            ["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.", "ohwx", 0.9],
            ["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.", "ohwx", 0.9],
        ]
    def build(self):
        """Build and return the Gradio interface."""
        with gr.Blocks(css=Config.CSS) as demo:
            gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")

            # Status indicator
            processing_status = gr.Markdown("**🟢 Ready**", visible=True)

            with gr.Row():
                with gr.Column(scale=4):
                    gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
                    prompt_input = gr.Textbox(
                        label="Enter Your Prompt",
                        lines=2,
                        placeholder="Enter prompt for your avatar..."
                    )
                    generate_btn = gr.Button(value="Generate", variant="primary")

            with gr.Accordion("Advanced Options", open=True):
                with gr.Row():
                    with gr.Column():
                        width_slider = gr.Slider(
                            label="Width",
                            minimum=512,
                            maximum=1920,
                            step=8,
                            value=Config.DEFAULT_WIDTH
                        )
                        height_slider = gr.Slider(
                            label="Height",
                            minimum=512,
                            maximum=1920,
                            step=8,
                            value=Config.DEFAULT_HEIGHT
                        )
                    with gr.Column():
                        guidance_slider = gr.Slider(
                            label="Guidance Scale",
                            minimum=3.5,
                            maximum=7,
                            step=0.1,
                            value=Config.DEFAULT_GUIDANCE_SCALE
                        )
                        steps_slider = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=Config.DEFAULT_STEPS
                        )
                with gr.Row():
                    with gr.Column():
                        seed_slider = gr.Slider(
                            label="Seed (-1 for random)",
                            minimum=-1,
                            maximum=Config.MAX_SEED,
                            step=1,
                            value=-1
                        )
                        nums_slider = gr.Slider(
                            label="Image Count",
                            minimum=1,
                            maximum=2,
                            step=1,
                            value=1
                        )
                    with gr.Column():
                        lora_scale_slider = gr.Slider(
                            label="LoRA Scale",
                            minimum=0.1,
                            maximum=2.0,
                            step=0.1,
                            value=Config.DEFAULT_LORA_SCALE
                        )
                with gr.Row():
                    with gr.Column():
                        lora_add_text = gr.Textbox(
                            label="Flux LoRA Path",
                            lines=1,
                            value=Config.DEFAULT_LORA
                        )
                    with gr.Column():
                        lora_word_text = gr.Textbox(
                            label="Flux LoRA Trigger Word",
                            lines=1,
                            value=Config.DEFAULT_TRIGGER_WORD
                        )
                load_lora_btn = gr.Button(value="Load Custom LoRA", variant="secondary")

                # Memory optimization checkbox
                with gr.Row():
                    memory_efficient = gr.Checkbox(
                        label="Enable Memory Optimizations",
                        value=True,
                        info="Reduces memory usage but may increase generation time"
                    )

            # Examples section
            gr.Examples(
                examples=self.example_prompts,
                inputs=[prompt_input, lora_word_text, lora_scale_slider],
                cache_examples=False,
                examples_per_page=4
            )
            # Wire up the event handlers
            # Status update helpers
            def update_status_processing():
                return "**⏳ Processing...**"

            def update_status_done():
                return "**✅ Done!**"

            def update_memory_settings(enable_memory_opt):
                # Config is a class object; mutating its attributes needs no `global`
                Config.ENABLE_MEMORY_EFFICIENT_ATTENTION = enable_memory_opt
                Config.ENABLE_SEQUENTIAL_CPU_OFFLOAD = enable_memory_opt
                Config.ENABLE_ATTENTION_SLICING = "max" if enable_memory_opt else None

            # Generate button click workflow
            generate_btn.click(
                fn=update_status_processing,
                inputs=[],
                outputs=[processing_status]
            ).then(
                fn=self.generator.generate,
                inputs=[
                    prompt_input, lora_word_text, lora_scale_slider,
                    width_slider, height_slider, guidance_slider,
                    steps_slider, seed_slider, nums_slider
                ],
                outputs=[gallery, seed_slider]
            ).then(
                fn=update_status_done,
                inputs=[],
                outputs=[processing_status]
            )

            # Load LoRA button click workflow
            load_lora_btn.click(
                fn=self.generator.load_lora,
                inputs=[lora_add_text],
                outputs=[lora_add_text]
            )

            # Memory optimization checkbox event
            memory_efficient.change(
                fn=update_memory_settings,
                inputs=[memory_efficient],
                outputs=[]
            )

        return demo
# -----------------------------------------------------------------------------
# Main application
# -----------------------------------------------------------------------------
def main():
    try:
        # Create a generator with memory optimizations
        generator = FluxGenerator()

        # Build and launch the UI
        ui = FluxUI(generator)
        demo = ui.build()

        # Launch with a small request queue to limit memory pressure
        demo.queue(max_size=1).launch(share=False)
    except Exception as e:
        print(f"Application startup failed: {str(e)}")

        # Show the error in a minimal UI if possible
        with gr.Blocks() as error_demo:
            gr.Markdown(f"# Error Starting Application\n\n{str(e)}\n\nPlease check the logs for more details.")
            gr.Markdown("This might be due to memory limitations or GPU compatibility issues.")
        error_demo.launch()
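
# Entry point (added as an assumption): the listing defines main() but never calls it,
# so nothing would launch when the file is run directly. A standard script guard is
# sketched here; on Hugging Face Spaces, app.py is executed as a script, so this is
# sufficient to start the Gradio app.
if __name__ == "__main__":
    main()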