import os
import random
import torch
import numpy as np
import gradio as gr
import spaces
from diffusers import FluxPipeline
from translatepy import Translator
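
# NOTE: "black-forest-labs/FLUX.1-dev" is a gated checkpoint on the Hugging Face
# Hub; loading it assumes the runtime is authenticated (e.g. an HF_TOKEN secret
# on the Space) and that the `hf_transfer` package is installed for the
# fast-download path enabled below.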

# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
class Config:
    MODEL_ID = "black-forest-labs/FLUX.1-dev"
    DEFAULT_LORA = "nftnik/BR_ohwx_V1"
    DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
    MAX_SEED = int(np.iinfo(np.int32).max)
    CSS = "footer { visibility: hidden; }"
    DEFAULT_WIDTH = 896
    DEFAULT_HEIGHT = 1152
    DEFAULT_GUIDANCE_SCALE = 3.5
    DEFAULT_STEPS = 35
    DEFAULT_LORA_SCALE = 1.0
    DEFAULT_TRIGGER_WORD = "ohwx"
    # Memory optimization configs
    ENABLE_MEMORY_EFFICIENT_ATTENTION = True
    ENABLE_SEQUENTIAL_CPU_OFFLOAD = True
    ENABLE_ATTENTION_SLICING = "max"

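# The three memory flags above map onto diffusers calls made at pipeline
# initialization: xformers memory-efficient attention, attention slicing
# ("max" runs one slice at a time for maximum savings), and sequential CPU
# offload, which keeps submodules on the CPU and moves them to the GPU one at
# a time (lowest VRAM use, slowest generation).
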

# -----------------------------------------------------------------------------
# FluxGenerator class to handle image generation
# -----------------------------------------------------------------------------
class FluxGenerator:
    def __init__(self):
        # Environment setup
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        self.translator = Translator()
        self.device = self._get_optimal_device()
        print(f"Using {self.device.upper()}")
        
        # Initialize pipeline
        self.pipe = None
        self._offload_enabled = False  # Set for real during pipeline initialization
        self._initialize_pipeline()
        
    def _get_optimal_device(self):
        """Determine the optimal device based on available resources"""
        if torch.cuda.is_available():
            # Check GPU memory
            try:
                gpu_memory = torch.cuda.get_device_properties(0).total_memory
                if gpu_memory > 10 * 1024 * 1024 * 1024:  # More than 10GB
                    return "cuda"
                else:
                    print("Limited GPU memory detected; using CUDA with memory optimizations")
                    return "cuda"  # Still CUDA, but rely on the memory optimizations applied at init
            except Exception as e:
                print(f"Error checking GPU memory ({e}), falling back to CPU")
                return "cpu"
        else:
            return "cpu"
            
    def _initialize_pipeline(self):
        """Initialize the Flux pipeline with memory optimizations"""
        try:
            print("Loading Flux model...")
            # Use more memory-efficient settings
            pipe_kwargs = {
                "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
            }
            
            # Initialize the pipeline
            self.pipe = FluxPipeline.from_pretrained(
                Config.MODEL_ID,
                **pipe_kwargs
            )
            
            # Apply memory optimizations
            if Config.ENABLE_MEMORY_EFFICIENT_ATTENTION and self.device == "cuda":
                print("Enabling memory efficient attention")
                self.pipe.enable_xformers_memory_efficient_attention()
            
            if Config.ENABLE_ATTENTION_SLICING:
                print("Enabling attention slicing")
                self.pipe.enable_attention_slicing(Config.ENABLE_ATTENTION_SLICING)
            
            if Config.ENABLE_SEQUENTIAL_CPU_OFFLOAD and self.device == "cuda":
                print("Enabling sequential CPU offload")
                self.pipe.enable_sequential_cpu_offload()
                self._offload_enabled = True
            else:
                # Only move to device if not using CPU offload
                self.pipe = self.pipe.to(self.device)
            
            # Load default LoRA
            print(f"Loading default LoRA: {Config.DEFAULT_LORA}")
            self.pipe.load_lora_weights(Config.DEFAULT_LORA, weight_name=Config.DEFAULT_WEIGHT_NAME)
            
            print("Model initialization complete")
            return self.pipe
            
        except Exception as e:
            error_msg = f"Error initializing pipeline: {str(e)}"
            print(error_msg)
            raise

    def load_lora(self, lora_path):
        """Load a new LoRA model"""
        try:
            print(f"Unloading previous LoRA weights...")
            self.pipe.unload_lora_weights()
            
            if not lora_path:
                print("No LoRA path provided, skipping LoRA loading")
                return gr.update(value="")
                
            print(f"Loading LoRA from {lora_path}...")
            self.pipe.load_lora_weights(lora_path)
            print("LoRA loaded successfully")
            return gr.update(label="LoRA Loaded Successfully")
            
        except Exception as e:
            error_msg = f"Failed to load LoRA from {lora_path}: {str(e)}"
            print(error_msg)
            raise gr.Error(error_msg)

    def _clear_memory(self):
        """Clear CUDA memory cache"""
        if self.device == "cuda":
            try:
                print("Clearing CUDA memory cache...")
                torch.cuda.empty_cache()
                if hasattr(torch, "clear_autocast_cache"):
                    torch.clear_autocast_cache()
            except Exception as e:
                print(f"Warning: Failed to clear CUDA memory: {str(e)}")

    @spaces.GPU()
    def generate(self, prompt, lora_word, lora_scale=Config.DEFAULT_LORA_SCALE,
                 width=Config.DEFAULT_WIDTH, height=Config.DEFAULT_HEIGHT,
                 guidance_scale=Config.DEFAULT_GUIDANCE_SCALE, steps=Config.DEFAULT_STEPS,
                 seed=-1, num_images=1):
        """Generate images from a prompt with memory optimizations"""
        try:
            print(f"Generating image for prompt: '{prompt}'")
            
            # Clear memory before generation
            self._clear_memory()
            
            # Ensure we're using the right device; a sequentially-offloaded
            # pipeline must not be moved, as .to() would break its offload hooks
            if not self._offload_enabled:
                print(f"Moving model to {self.device}")
                self.pipe.to(self.device)
            
            # Handle seed
            seed = random.randint(0, Config.MAX_SEED) if seed == -1 else int(seed)
            print(f"Using seed: {seed}")
            generator = torch.Generator(device=self.device).manual_seed(seed)
            
            # Translate prompt if not in English
            print("Translating prompt if needed...")
            prompt_english = str(self.translator.translate(prompt, "English"))
            full_prompt = f"{prompt_english} {lora_word}"
            print(f"Full prompt: '{full_prompt}'")
            
            # Lower resolution if on limited memory
            if self.device == "cuda" and torch.cuda.get_device_properties(0).total_memory < 8 * 1024 * 1024 * 1024:
                original_width, original_height = width, height
                # Scale down to ~85% if memory is tight, rounded to a multiple of 16 as FLUX expects
                width = int(width * 0.85) // 16 * 16
                height = int(height * 0.85) // 16 * 16
                print(f"Limited memory detected. Scaling down resolution from {original_width}x{original_height} to {width}x{height}")
            
            # Generate with autocast for memory efficiency
            print(f"Starting generation with {steps} steps, guidance scale {guidance_scale}")
            with torch.autocast(device_type="cuda", enabled=self.device == "cuda"):
                result = self.pipe(
                    prompt=full_prompt,
                    height=height,
                    width=width,
                    guidance_scale=guidance_scale,
                    output_type="pil",
                    num_inference_steps=steps,
                    num_images_per_prompt=num_images,
                    generator=generator,
                    joint_attention_kwargs={"scale": lora_scale},
                )
            
            print("Generation complete, returning images")
            self._clear_memory()  # Clear memory after generation
            return result.images, seed
            
        except Exception as e:
            error_msg = f"Image generation failed: {str(e)}"
            print(error_msg)
            # Clear memory after error
            self._clear_memory()
            raise gr.Error(error_msg)

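# A minimal sketch of driving FluxGenerator programmatically, outside the
# Gradio UI (assumes a machine that can load the model; `images` is a list of
# PIL images, returned alongside the seed actually used):
#
#     generator = FluxGenerator()
#     images, used_seed = generator.generate(
#         "portrait in a neon-lit alley", "ohwx", seed=42, num_images=1
#     )
#     images[0].save("avatar.png")
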

# -----------------------------------------------------------------------------
# UI Builder class
# -----------------------------------------------------------------------------
class FluxUI:
    def __init__(self, generator):
        self.generator = generator
        self.example_prompts = [
            ["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.", "ohwx", 0.9],
            ["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.", "ohwx", 0.9],
            ["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.", "ohwx", 0.9],
            ["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.", "ohwx", 0.9]
        ]
        
    def build(self):
        """Build and return the Gradio interface"""
        with gr.Blocks(css=Config.CSS) as demo:
            gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
            
            # Status indicator
            processing_status = gr.Markdown("**🟒 Ready**", visible=True)
            
            with gr.Row():
                with gr.Column(scale=4):
                    gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
                    prompt_input = gr.Textbox(
                        label="Enter Your Prompt",
                        lines=2,
                        placeholder="Enter prompt for your avatar..."
                    )
                    generate_btn = gr.Button(value="Generate", variant="primary")
                    
                with gr.Accordion("Advanced Options", open=True):
                    with gr.Row():
                        with gr.Column():
                            width_slider = gr.Slider(
                                label="Width",
                                minimum=512,
                                maximum=1920,
                                step=8,
                                value=Config.DEFAULT_WIDTH
                            )
                            height_slider = gr.Slider(
                                label="Height",
                                minimum=512,
                                maximum=1920,
                                step=8,
                                value=Config.DEFAULT_HEIGHT
                            )
                        with gr.Column():
                            guidance_slider = gr.Slider(
                                label="Guidance Scale",
                                minimum=3.5,
                                maximum=7,
                                step=0.1,
                                value=Config.DEFAULT_GUIDANCE_SCALE
                            )
                            steps_slider = gr.Slider(
                                label="Steps",
                                minimum=1,
                                maximum=100,
                                step=1,
                                value=Config.DEFAULT_STEPS
                            )
                    
                    with gr.Row():
                        with gr.Column():
                            seed_slider = gr.Slider(
                                label="Seed (-1 for random)",
                                minimum=-1,
                                maximum=Config.MAX_SEED,
                                step=1,
                                value=-1
                            )
                            nums_slider = gr.Slider(
                                label="Image Count",
                                minimum=1,
                                maximum=2,
                                step=1,
                                value=1
                            )
                        with gr.Column():
                            lora_scale_slider = gr.Slider(
                                label="LoRA Scale",
                                minimum=0.1,
                                maximum=2.0,
                                step=0.1,
                                value=Config.DEFAULT_LORA_SCALE
                            )
                            
                    with gr.Row():
                        with gr.Column():
                            lora_add_text = gr.Textbox(
                                label="Flux LoRA Path",
                                lines=1,
                                value=Config.DEFAULT_LORA
                            )
                        with gr.Column():
                            lora_word_text = gr.Textbox(
                                label="Flux LoRA Trigger Word",
                                lines=1,
                                value=Config.DEFAULT_TRIGGER_WORD
                            )
                    
                    load_lora_btn = gr.Button(value="Load Custom LoRA", variant="secondary")
                    
                    # Memory optimization checkbox
                    with gr.Row():
                        memory_efficient = gr.Checkbox(
                            label="Enable Memory Optimizations",
                            value=True,
                            info="Reduces memory usage but may increase generation time"
                        )
            
            # Examples section
            gr.Examples(
                examples=self.example_prompts,
                inputs=[prompt_input, lora_word_text, lora_scale_slider],
                cache_examples=False,
                examples_per_page=4
            )
            
            # Wire up the event handlers
            # Status update functions
            def update_status_processing():
                return "**⏳ Processing...**"
                
            def update_status_done():
                return "**βœ… Done!**"
                
            def update_memory_settings(enable_memory_opt):
                # These flags are applied when the pipeline is initialized;
                # toggling the checkbox only affects settings that are
                # consulted again at generation time
                Config.ENABLE_MEMORY_EFFICIENT_ATTENTION = enable_memory_opt
                Config.ENABLE_SEQUENTIAL_CPU_OFFLOAD = enable_memory_opt
                Config.ENABLE_ATTENTION_SLICING = "max" if enable_memory_opt else None
                return gr.update()

            # Generate button click workflow
            generate_btn.click(
                fn=update_status_processing,
                inputs=[],
                outputs=[processing_status]
            ).then(
                fn=self.generator.generate,
                inputs=[
                    prompt_input, lora_word_text, lora_scale_slider,
                    width_slider, height_slider, guidance_slider,
                    steps_slider, seed_slider, nums_slider
                ],
                outputs=[gallery, seed_slider]
            ).then(
                fn=update_status_done,
                inputs=[],
                outputs=[processing_status]
            )
            
            # Load LoRA button click workflow
            load_lora_btn.click(
                fn=self.generator.load_lora,
                inputs=[lora_add_text],
                outputs=[lora_add_text]
            )
            
            # Memory optimization checkbox event
            memory_efficient.change(
                fn=update_memory_settings,
                inputs=[memory_efficient],
                outputs=[]
            )
            
        return demo


# -----------------------------------------------------------------------------
# Main application
# -----------------------------------------------------------------------------
def main():
    try:
        # Create a generator with memory optimizations
        generator = FluxGenerator()
        
        # Build and launch UI
        ui = FluxUI(generator)
        demo = ui.build()
        
        # Launch with low cache size to prevent memory issues
        demo.queue(max_size=1).launch(share=False)
        
    except Exception as e:
        print(f"Application startup failed: {str(e)}")
        # Show error in UI if possible
        with gr.Blocks() as error_demo:
            gr.Markdown(f"# Error Starting Application\n\n{str(e)}\n\nPlease check the logs for more details.")
            gr.Markdown("This might be due to memory limitations or GPU compatibility issues.")
            error_demo.launch()
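

# Standard entry point so the script actually runs when launched directly
# (e.g. `python app.py` on a Space); without it, main() is never called.
if __name__ == "__main__":
    main()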