import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from functools import lru_cache
from PIL import Image

from transformers import CLIPImageProcessor  # replaces the deprecated CLIPFeatureExtractor

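# Cache the loaded pipeline so repeated calls return the same objects instead of reloading weights.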
@lru_cache(maxsize=1)
def load_pipeline():
    # Decide on torch_dtype based on device; use fp16 on CUDA to lower memory usage.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

    # Load the base model in the selected precision
    base_model = "black-forest-labs/FLUX.1-dev"
    pipe = DiffusionPipeline.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        # low_cpu_mem_usage helps reduce CPU RAM usage during loading
        low_cpu_mem_usage=True
    )
    
    # Load LoRA weights
    lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
    pipe.load_lora_weights(lora_repo)
    
    # Load safety checker and image processor
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
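    # The checker stays on CPU in fp32; it runs once per generated image, so speed is not critical.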
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    
    # If using CUDA, apply memory optimizations:
    if device.type == "cuda":
        # Attention slicing splits up attention computations to save memory.
        pipe.enable_attention_slicing()
        # Instead of moving the entire model to GPU, offload parts to CPU when not needed.
        # This is particularly useful on a 15GB GPU.
        pipe.enable_model_cpu_offload()
        # Note: xformers memory efficient attention is omitted here because
        # model offload works best when not all weights are kept on GPU.
    
    return pipe, safety_checker, image_processor

pipe, safety_checker, image_processor = load_pipeline()

def generate_image(
    prompt,
    seed=42,
    width=512,   # 512x512 default keeps GPU memory use modest
    height=512,
    guidance_scale=6,
    steps=28,
    progress=gr.Progress()
):
    try:
        progress(0, desc="Initializing...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Gradio sliders may deliver floats; manual_seed requires an int
        generator = torch.Generator(device=device).manual_seed(int(seed))
        
        # Auto-add trigger words if not present
        if "super realism" not in prompt.lower():
            prompt = f"Super Realism, {prompt}"
        
        # Optional: recent diffusers pipelines accept a callback_on_step_end hook
        # for step-level progress reporting, e.g.:
        # def update_progress(pipeline, step, timestep, callback_kwargs):
        #     progress((step + 1) / steps, desc="Generating image...")
        #     return callback_kwargs
        
        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                width=int(width),
                height=int(height),
                guidance_scale=guidance_scale,
                num_inference_steps=int(steps),
                generator=generator,
            )
            image = result.images[0]
        
        progress(1, desc="Safety checking...")
        # Preprocess the image for the safety checker
        safety_input = image_processor(image, return_tensors="pt")
        np_image = np.array(image)
        
        # The checker returns (filtered images, per-image NSFW flags); we only need the flags
        _, nsfw_detected = safety_checker(
            images=[np_image], 
            clip_input=safety_input.pixel_values
        )
        
        if nsfw_detected[0]:
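            # Return a blank placeholder rather than the flagged image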
            return Image.new("RGB", (512, 512)), "NSFW content detected"
        
        return image, "Generation successful"
    
    except Exception as e:
        return Image.new("RGB", (512, 512)), f"Error: {str(e)}"

with gr.Blocks() as app:
    gr.Markdown("# Flux Super Realism Generator")
    
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
            seed_input = gr.Slider(0, 1000, value=42, step=1, label="Seed")
            # Limit resolution sliders to help avoid GPU memory overuse on a 15GB A100
            width_input = gr.Slider(256, 1024, value=512, step=64, label="Width")
            height_input = gr.Slider(256, 1024, value=512, step=64, label="Height")
            guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
            steps_input = gr.Slider(10, 100, value=28, step=1, label="Steps")
            submit = gr.Button("Generate")
        
        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            status = gr.Textbox(label="Status")
    
    submit.click(
        generate_image,
        inputs=[prompt_input, seed_input, width_input, height_input, guidance_input, steps_input],
        outputs=[output_image, status]
    )
    
# Queue requests: Gradio processes one event at a time by default; cap waiting requests at 3
app.queue(max_size=3)
app.launch()

# For even tighter memory budgets, sequential CPU offload moves weights layer by layer
# (much slower, but minimal GPU memory; use instead of enable_model_cpu_offload):
# pipe.enable_sequential_cpu_offload()