File size: 4,335 Bytes
c34a1b5
8d7b14b
c34a1b5
8d7b14b
c34a1b5
 
 
a37a20c
b994d40
c34a1b5
8d7b14b
2872348
a37a20c
2872348
a37a20c
2872348
8d7b14b
ad0688d
 
2872348
a37a20c
 
ad0688d
c34a1b5
 
 
92b1b58
c34a1b5
a37a20c
 
c34a1b5
 
 
796e120
c34a1b5
4570550
a37a20c
2872348
a37a20c
 
b994d40
796e120
8d7b14b
796e120
8d7b14b
c34a1b5
 
 
a37a20c
2872348
c34a1b5
 
 
 
 
 
4570550
 
c34a1b5
b994d40
c34a1b5
 
 
 
4570550
c34a1b5
 
 
 
 
b994d40
4570550
 
c34a1b5
b994d40
a37a20c
796e120
c34a1b5
4570550
a37a20c
c34a1b5
 
 
4570550
a37a20c
c34a1b5
a37a20c
 
 
c34a1b5
 
 
a37a20c
8d7b14b
c34a1b5
 
 
 
 
4570550
 
a37a20c
2872348
 
4570550
 
c34a1b5
 
 
 
 
 
 
 
4570550
c34a1b5
8d7b14b
c34a1b5
a37a20c
b46eefe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from functools import lru_cache
from PIL import Image
from transformers import CLIPImageProcessor

@lru_cache(maxsize=1)
def load_pipeline():
    """Load the FLUX pipeline, the safety checker and the CLIP image processor.

    The result is memoized with ``lru_cache(maxsize=1)``, so the models are
    loaded at most once per process no matter how often this is called.

    Returns:
        tuple: ``(pipe, safety_checker, image_processor)`` where ``pipe`` is
        the diffusion pipeline with the Super-Realism LoRA weights loaded,
        ``safety_checker`` is the Stable Diffusion safety checker and
        ``image_processor`` is the CLIP image processor used to prepare the
        safety checker's input.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use FP16 when CUDA is available; stay in FP32 on CPU.
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

    # NOTE: no ``revision`` is passed. ``revision`` selects a git branch/tag of
    # the model repository; requesting an "fp16" revision fails when the
    # repository does not publish such a branch. The precision is controlled
    # by ``torch_dtype`` alone.
    base_model = "black-forest-labs/FLUX.1-dev"
    pipe = DiffusionPipeline.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )

    # Load LoRA weights
    lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
    pipe.load_lora_weights(lora_repo)

    # Load safety checker and image processor.
    # If memory remains an issue, you can disable the safety checker below.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    if device.type == "cuda":
        # Use attention slicing for further memory savings.
        pipe.enable_attention_slicing()
        # Offload layers to CPU when not in use.
        pipe.enable_sequential_cpu_offload()

    return pipe, safety_checker, image_processor

# Load the models once at import time; generate_image reads these module-level
# names on every request instead of reloading anything.
pipe, safety_checker, image_processor = load_pipeline()

def generate_image(
    prompt,
    seed=42,
    width=512,   # Keep resolution low by default
    height=512,
    guidance_scale=6,
    steps=28,
    progress=gr.Progress()
):
    """Generate one image for *prompt* and screen it with the safety checker.

    Args:
        prompt: Text prompt. ``"Super Realism, "`` is prepended when the
            phrase "super realism" is not already present (case-insensitive).
        seed: Seed for the torch random generator; coerced to ``int``.
        width: Output width in pixels; coerced to ``int``.
        height: Output height in pixels; coerced to ``int``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps; coerced to ``int``.
        progress: Gradio progress tracker used to report status.

    Returns:
        tuple: ``(image, status)``. On success ``image`` is the generated PIL
        image. When the safety checker flags the image, or when any exception
        is raised, ``image`` is a blank RGB placeholder of the requested size
        and ``status`` explains why.
    """
    # UI sliders may deliver floats. Coerce the size up front so the blank
    # placeholder can always be built, even from inside the error handler.
    width, height = int(width), int(height)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        progress(0, desc="Initializing...")
        # manual_seed and num_inference_steps both require integers.
        generator = torch.Generator(device=device).manual_seed(int(seed))
        steps = int(steps)

        # Auto-add trigger words if not present
        if "super realism" not in prompt.lower():
            prompt = f"Super Realism, {prompt}"

        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                generator=generator,
            )
            image = result.images[0]

        progress(1, desc="Safety checking...")
        # Process image for safety checking
        safety_input = image_processor(image, return_tensors="pt")
        np_image = np.array(image)
        _, nsfw_detected = safety_checker(
            images=[np_image],
            clip_input=safety_input.pixel_values
        )

        if nsfw_detected[0]:
            return Image.new("RGB", (width, height)), "NSFW content detected"

        return image, "Generation successful"

    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        return Image.new("RGB", (width, height)), f"Error: {str(e)}"

    finally:
        # Clear the CUDA cache on every exit path (success, NSFW, error), not
        # only after a successful generation.
        if device.type == "cuda":
            torch.cuda.empty_cache()

# Build the UI: input controls on the left, generated image and status on the right.
with gr.Blocks() as app:
    gr.Markdown("# Flux Super Realism Generator")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
            # step=1 keeps the seed and the step count integral; without an
            # explicit step the slider picks its own increment from the range.
            seed_input = gr.Slider(0, 1000, value=42, step=1, label="Seed")
            # Limit the resolution sliders to help avoid memory overuse.
            width_input = gr.Slider(256, 1024, value=512, step=64, label="Width")
            height_input = gr.Slider(256, 1024, value=512, step=64, label="Height")
            guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
            steps_input = gr.Slider(10, 100, value=28, step=1, label="Steps")
            submit = gr.Button("Generate")

        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            status = gr.Textbox(label="Status")

    submit.click(
        generate_image,
        inputs=[prompt_input, seed_input, width_input, height_input, guidance_input, steps_input],
        outputs=[output_image, status]
    )

# Launch after the Blocks context has closed so the layout is fully built.
# Queue settings cap the number of pending requests at 3.
app.queue(max_size=3).launch()