import os
import random
import torch
import numpy as np
import gradio as gr
import spaces
from diffusers import FluxPipeline
from translatepy import Translator
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
class Config:
MODEL_ID = "black-forest-labs/FLUX.1-dev"
DEFAULT_LORA = "nftnik/BR_ohwx_V1"
DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
MAX_SEED = int(np.iinfo(np.int32).max)
CSS = "footer { visibility: hidden; }"
DEFAULT_WIDTH = 896
DEFAULT_HEIGHT = 1152
DEFAULT_GUIDANCE_SCALE = 3.5
DEFAULT_STEPS = 35
DEFAULT_LORA_SCALE = 1.0
DEFAULT_TRIGGER_WORD = "ohwx"
# Memory optimization configs
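    # Rough meaning of the flags below (as applied in _initialize_pipeline):
    # - memory-efficient attention: use xformers attention kernels (CUDA only)
    # - sequential CPU offload: diffusers keeps submodules on CPU and moves them to the GPU one at a time
    # - attention slicing ("max"): compute attention in small slices to reduce peak VRAM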
ENABLE_MEMORY_EFFICIENT_ATTENTION = True
ENABLE_SEQUENTIAL_CPU_OFFLOAD = True
ENABLE_ATTENTION_SLICING = "max"
# -----------------------------------------------------------------------------
# FluxGenerator class to handle image generation
# -----------------------------------------------------------------------------
class FluxGenerator:
def __init__(self):
# Environment setup
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
self.translator = Translator()
self.device = self._get_optimal_device()
print(f"Using {self.device.upper()}")
# Initialize pipeline
self.pipe = None
self._initialize_pipeline()
def _get_optimal_device(self):
"""Determine the optimal device based on available resources"""
if torch.cuda.is_available():
# Check GPU memory
try:
gpu_memory = torch.cuda.get_device_properties(0).total_memory
if gpu_memory > 10 * 1024 * 1024 * 1024: # More than 10GB
return "cuda"
else:
print("Limited GPU memory detected, using CPU with GPU acceleration")
return "cuda" # Still use CUDA but will apply memory optimizations
except:
print("Error checking GPU memory, falling back to CPU")
return "cpu"
else:
return "cpu"
def _initialize_pipeline(self):
"""Initialize the Flux pipeline with memory optimizations"""
try:
print("Loading Flux model...")
# Use more memory-efficient settings
pipe_kwargs = {
"torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
}
# Initialize the pipeline
self.pipe = FluxPipeline.from_pretrained(
Config.MODEL_ID,
**pipe_kwargs
)
# Apply memory optimizations
            if Config.ENABLE_MEMORY_EFFICIENT_ATTENTION and self.device == "cuda":
                try:
                    print("Enabling memory efficient attention")
                    self.pipe.enable_xformers_memory_efficient_attention()
                except Exception as e:
                    print(f"xformers unavailable, skipping memory efficient attention: {e}")
if Config.ENABLE_ATTENTION_SLICING:
print("Enabling attention slicing")
self.pipe.enable_attention_slicing(Config.ENABLE_ATTENTION_SLICING)
if Config.ENABLE_SEQUENTIAL_CPU_OFFLOAD and self.device == "cuda":
print("Enabling sequential CPU offload")
self.pipe.enable_sequential_cpu_offload()
else:
# Only move to device if not using CPU offload
self.pipe = self.pipe.to(self.device)
# Load default LoRA
print(f"Loading default LoRA: {Config.DEFAULT_LORA}")
self.pipe.load_lora_weights(Config.DEFAULT_LORA, weight_name=Config.DEFAULT_WEIGHT_NAME)
print("Model initialization complete")
return self.pipe
except Exception as e:
error_msg = f"Error initializing pipeline: {str(e)}"
print(error_msg)
raise
def load_lora(self, lora_path):
"""Load a new LoRA model"""
try:
print(f"Unloading previous LoRA weights...")
self.pipe.unload_lora_weights()
if not lora_path:
print("No LoRA path provided, skipping LoRA loading")
return gr.update(value="")
print(f"Loading LoRA from {lora_path}...")
self.pipe.load_lora_weights(lora_path)
print("LoRA loaded successfully")
return gr.update(label="LoRA Loaded Successfully")
except Exception as e:
error_msg = f"Failed to load LoRA from {lora_path}: {str(e)}"
print(error_msg)
raise gr.Error(error_msg)
def _clear_memory(self):
"""Clear CUDA memory cache"""
if self.device == "cuda":
try:
print("Clearing CUDA memory cache...")
torch.cuda.empty_cache()
                if hasattr(torch, "clear_autocast_cache"):
                    torch.clear_autocast_cache()
except Exception as e:
print(f"Warning: Failed to clear CUDA memory: {str(e)}")
@spaces.GPU()
def generate(self, prompt, lora_word, lora_scale=Config.DEFAULT_LORA_SCALE,
width=Config.DEFAULT_WIDTH, height=Config.DEFAULT_HEIGHT,
guidance_scale=Config.DEFAULT_GUIDANCE_SCALE, steps=Config.DEFAULT_STEPS,
seed=-1, num_images=1):
"""Generate images from a prompt with memory optimizations"""
try:
print(f"Generating image for prompt: '{prompt}'")
# Clear memory before generation
self._clear_memory()
# Ensure we're using the right device
if not Config.ENABLE_SEQUENTIAL_CPU_OFFLOAD:
print(f"Moving model to {self.device}")
self.pipe.to(self.device)
# Handle seed
seed = random.randint(0, Config.MAX_SEED) if seed == -1 else int(seed)
print(f"Using seed: {seed}")
generator = torch.Generator(device=self.device).manual_seed(seed)
# Translate prompt if not in English
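            # translatepy returns a translation result object; str() extracts the translated text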
print("Translating prompt if needed...")
prompt_english = str(self.translator.translate(prompt, "English"))
full_prompt = f"{prompt_english} {lora_word}"
print(f"Full prompt: '{full_prompt}'")
# Lower resolution if on limited memory
if self.device == "cuda" and torch.cuda.get_device_properties(0).total_memory < 8 * 1024 * 1024 * 1024:
original_width, original_height = width, height
                # Scale down to roughly 85% if memory is tight, rounding to multiples of 16 (Flux-friendly sizes)
                width = int(width * 0.85) // 16 * 16
                height = int(height * 0.85) // 16 * 16
print(f"Limited memory detected. Scaling down resolution from {original_width}x{original_height} to {width}x{height}")
# Generate with autocast for memory efficiency
print(f"Starting generation with {steps} steps, guidance scale {guidance_scale}")
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=self.device == "cuda"):  # bf16 matches the pipeline dtype
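                # joint_attention_kwargs={"scale": ...} scales the loaded LoRA weights at inference time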
result = self.pipe(
prompt=full_prompt,
height=height,
width=width,
guidance_scale=guidance_scale,
output_type="pil",
num_inference_steps=steps,
num_images_per_prompt=num_images,
generator=generator,
joint_attention_kwargs={"scale": lora_scale},
)
print("Generation complete, returning images")
self._clear_memory() # Clear memory after generation
return result.images, seed
except Exception as e:
error_msg = f"Image generation failed: {str(e)}"
print(error_msg)
# Clear memory after error
self._clear_memory()
raise gr.Error(error_msg)
# -----------------------------------------------------------------------------
# UI Builder class
# -----------------------------------------------------------------------------
class FluxUI:
def __init__(self, generator):
self.generator = generator
self.example_prompts = [
["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.", "ohwx", 0.9],
["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.", "ohwx", 0.9],
["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.", "ohwx", 0.9],
["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.", "ohwx", 0.9]
]
def build(self):
"""Build and return the Gradio interface"""
with gr.Blocks(css=Config.CSS) as demo:
gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
# Status indicator
processing_status = gr.Markdown("**🟢 Ready**", visible=True)
with gr.Row():
with gr.Column(scale=4):
gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
prompt_input = gr.Textbox(
label="Enter Your Prompt",
lines=2,
placeholder="Enter prompt for your avatar..."
)
generate_btn = gr.Button(value="Generate", variant="primary")
with gr.Accordion("Advanced Options", open=True):
with gr.Row():
with gr.Column():
width_slider = gr.Slider(
label="Width",
minimum=512,
maximum=1920,
step=8,
value=Config.DEFAULT_WIDTH
)
height_slider = gr.Slider(
label="Height",
minimum=512,
maximum=1920,
step=8,
value=Config.DEFAULT_HEIGHT
)
with gr.Column():
guidance_slider = gr.Slider(
label="Guidance Scale",
minimum=3.5,
maximum=7,
step=0.1,
value=Config.DEFAULT_GUIDANCE_SCALE
)
steps_slider = gr.Slider(
label="Steps",
minimum=1,
maximum=100,
step=1,
value=Config.DEFAULT_STEPS
)
with gr.Row():
with gr.Column():
seed_slider = gr.Slider(
label="Seed (-1 for random)",
minimum=-1,
maximum=Config.MAX_SEED,
step=1,
value=-1
)
nums_slider = gr.Slider(
label="Image Count",
minimum=1,
maximum=2,
step=1,
value=1
)
with gr.Column():
lora_scale_slider = gr.Slider(
label="LoRA Scale",
minimum=0.1,
maximum=2.0,
step=0.1,
value=Config.DEFAULT_LORA_SCALE
)
with gr.Row():
with gr.Column():
lora_add_text = gr.Textbox(
label="Flux LoRA Path",
lines=1,
value=Config.DEFAULT_LORA
)
with gr.Column():
lora_word_text = gr.Textbox(
label="Flux LoRA Trigger Word",
lines=1,
value=Config.DEFAULT_TRIGGER_WORD
)
load_lora_btn = gr.Button(value="Load Custom LoRA", variant="secondary")
# Memory optimization checkbox
with gr.Row():
memory_efficient = gr.Checkbox(
label="Enable Memory Optimizations",
value=True,
info="Reduces memory usage but may increase generation time"
)
# Examples section
gr.Examples(
examples=self.example_prompts,
inputs=[prompt_input, lora_word_text, lora_scale_slider],
cache_examples=False,
examples_per_page=4
)
# Wire up the event handlers
# Status update functions
def update_status_processing():
return "**⏳ Processing...**"
def update_status_done():
return "**✅ Done!**"
            def update_memory_settings(enable_memory_opt):
                Config.ENABLE_MEMORY_EFFICIENT_ATTENTION = enable_memory_opt
                Config.ENABLE_SEQUENTIAL_CPU_OFFLOAD = enable_memory_opt
                Config.ENABLE_ATTENTION_SLICING = "max" if enable_memory_opt else None
# Generate button click workflow
generate_btn.click(
fn=update_status_processing,
inputs=[],
outputs=[processing_status]
).then(
fn=self.generator.generate,
inputs=[
prompt_input, lora_word_text, lora_scale_slider,
width_slider, height_slider, guidance_slider,
steps_slider, seed_slider, nums_slider
],
outputs=[gallery, seed_slider]
).then(
fn=update_status_done,
inputs=[],
outputs=[processing_status]
)
# Load LoRA button click workflow
load_lora_btn.click(
fn=self.generator.load_lora,
inputs=[lora_add_text],
outputs=[lora_add_text]
)
# Memory optimization checkbox event
memory_efficient.change(
fn=update_memory_settings,
inputs=[memory_efficient],
outputs=[]
)
return demo
# -----------------------------------------------------------------------------
# Main application
# -----------------------------------------------------------------------------
def main():
try:
# Create a generator with memory optimizations
generator = FluxGenerator()
# Build and launch UI
ui = FluxUI(generator)
demo = ui.build()
# Launch with low cache size to prevent memory issues
demo.queue(max_size=1).launch(share=False)
except Exception as e:
print(f"Application startup failed: {str(e)}")
# Show error in UI if possible
with gr.Blocks() as error_demo:
gr.Markdown(f"# Error Starting Application\n\n{str(e)}\n\nPlease check the logs for more details.")
gr.Markdown("This might be due to memory limitations or GPU compatibility issues.")
error_demo.launch()
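# Entry point (assumption: app.py is executed directly, as on Hugging Face Spaces)
if __name__ == "__main__":
    main()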