File size: 4,649 Bytes
c34a1b5
8d7b14b
c34a1b5
8d7b14b
c34a1b5
 
 
188c627
a351ff0
5e66759
c34a1b5
5507320
c34a1b5
8d7b14b
5e66759
a351ff0
 
 
8d7b14b
ad0688d
 
5e66759
 
ad0688d
c34a1b5
 
 
92b1b58
c34a1b5
796e120
c34a1b5
 
 
796e120
c34a1b5
5e66759
4570550
a351ff0
 
 
 
5e66759
 
 
 
 
c34a1b5
a351ff0
796e120
8d7b14b
796e120
8d7b14b
c34a1b5
 
 
5e66759
 
c34a1b5
 
 
 
 
 
4570550
 
c34a1b5
5e66759
c34a1b5
 
 
4570550
c34a1b5
 
 
4570550
c34a1b5
 
 
 
 
 
 
4570550
 
c34a1b5
 
796e120
c34a1b5
4570550
5e66759
4570550
5e66759
c34a1b5
 
 
4570550
c34a1b5
 
 
 
 
 
8d7b14b
c34a1b5
 
 
 
 
4570550
 
5e66759
 
 
4570550
 
c34a1b5
 
 
 
 
 
 
 
4570550
c34a1b5
8d7b14b
c34a1b5
5e66759
b46eefe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from functools import lru_cache
from PIL import Image

from transformers import CLIPImageProcessor  # Updated per deprecation warning



@lru_cache(maxsize=1)
def load_pipeline():
    """Load the FLUX pipeline, its LoRA weights and the NSFW safety components.

    The result is memoized with ``lru_cache`` so repeated calls return the
    same objects instead of loading the models again.

    Returns:
        A ``(pipe, safety_checker, image_processor)`` tuple: the diffusion
        pipeline moved to the selected device, the Stable Diffusion safety
        checker, and the CLIP image processor that prepares its input.
    """
    # Use bfloat16 on CUDA; fall back to float32 on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

    base_model = "black-forest-labs/FLUX.1-dev"
    pipe = DiffusionPipeline.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,  # Reduce memory usage during load
    )

    # Load LoRA weights
    lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
    pipe.load_lora_weights(lora_repo)

    # Load safety checker and image processor
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Memory-efficient attention is a best-effort GPU optimization; a failure
    # here must not prevent the app from starting.
    # NOTE: sequential CPU offload is deliberately not enabled on CPU-only
    # hosts. It streams weights from host RAM to an accelerator, so without
    # one there is nothing to offload to.
    if device.type == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print("Warning: Could not enable xformers memory efficient attention:", e)

    pipe = pipe.to(device)
    return pipe, safety_checker, image_processor

# Load the models once at import time; generate_image reads these module-level names.
pipe, safety_checker, image_processor = load_pipeline()

def generate_image(
    prompt,
    seed=42,
    width=512,   # Lowered default resolution
    height=512,  # Lowered default resolution
    guidance_scale=6,
    steps=28,
    progress=gr.Progress()
):
    """Generate one image for ``prompt`` and run it through the safety checker.

    Args:
        prompt: Text prompt. The "Super Realism" trigger phrase is prepended
            when the prompt does not already contain it (case-insensitive).
        seed: Seed for the torch random generator.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps.
        progress: Gradio progress tracker used to report status to the UI.

    Returns:
        A ``(image, status)`` tuple. On success ``image`` is the generated PIL
        image. When the safety checker flags the result, or any exception is
        raised, a blank 512x512 RGB image is returned with an explanatory
        status string instead.
    """
    try:
        progress(0, desc="Initializing...")

        # UI sliders may deliver floats, while the seed, the image size and the
        # step count must be integers for torch and the pipeline.
        seed = int(seed)
        width = int(width)
        height = int(height)
        steps = int(steps)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        generator = torch.Generator(device=device).manual_seed(seed)

        # Auto-add the trigger word if not already present
        if "super realism" not in prompt.lower():
            prompt = f"Super Realism, {prompt}"

        def update_progress(_pipeline, step, _timestep, callback_kwargs):
            # Step-end callback: report progress and hand the kwargs back
            # unchanged, as the pipeline expects a dict in return.
            progress((step + 1) / steps, desc="Generating image...")
            return callback_kwargs

        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                generator=generator,
                # The pipeline has no legacy `callback` argument; the
                # step-end hook is the supported way to observe progress.
                callback_on_step_end=update_progress,
            )
            image = result.images[0]

        progress(1, desc="Safety checking...")
        safety_input = image_processor(image, return_tensors="pt")
        np_image = np.array(image)

        # Run safety checker; it returns a tuple where the second element is nsfw flags
        _, nsfw_detected = safety_checker(
            images=[np_image],
            clip_input=safety_input.pixel_values
        )

        if nsfw_detected[0]:
            return Image.new("RGB", (512, 512)), "NSFW content detected"

        return image, "Generation successful"

    except Exception as e:
        # UI boundary: surface the failure in the status box rather than
        # crashing the request.
        return Image.new("RGB", (512, 512)), f"Error: {str(e)}"

# Build the UI: inputs on the left, generated image and status on the right.
with gr.Blocks() as app:
    gr.Markdown("# Flux Super Realism Generator")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
            # Explicit integer steps: without `step` the slider picks its own
            # increment from the range, which can be coarse or fractional,
            # while the seed and the step count must be whole numbers.
            seed_input = gr.Slider(0, 1000, value=42, step=1, label="Seed")
            # Lower the resolution slider range for less memory-intensive generation.
            # Step of 16 keeps sizes on the granularity the FLUX pipeline
            # expects for height/width.
            width_input = gr.Slider(256, 1024, value=512, step=16, label="Width")
            height_input = gr.Slider(256, 1024, value=512, step=16, label="Height")
            guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
            steps_input = gr.Slider(10, 100, value=28, step=1, label="Steps")
            submit = gr.Button("Generate")

        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            status = gr.Textbox(label="Status")

    # Inputs are passed positionally, in the order of generate_image's parameters.
    submit.click(
        generate_image,
        inputs=[prompt_input, seed_input, width_input, height_input, guidance_input, steps_input],
        outputs=[output_image, status]
    )

    # Use queue without GPU-specific parameters
    app.queue(max_size=3).launch()