File size: 4,306 Bytes
c34a1b5
8d7b14b
c34a1b5
8d7b14b
c34a1b5
 
 
4570550
c34a1b5
4570550
 
c34a1b5
5507320
c34a1b5
8d7b14b
c34a1b5
8d7b14b
ad0688d
 
c34a1b5
ad0688d
c34a1b5
 
 
92b1b58
c34a1b5
4570550
c34a1b5
 
 
 
 
 
 
4570550
 
 
 
 
c34a1b5
 
8d7b14b
c34a1b5
8d7b14b
c34a1b5
 
 
 
 
 
 
 
 
 
 
4570550
 
c34a1b5
4570550
c34a1b5
 
 
4570550
 
c34a1b5
 
 
4570550
c34a1b5
 
 
 
 
 
 
4570550
 
c34a1b5
 
4570550
c34a1b5
 
4570550
 
 
c34a1b5
 
 
 
4570550
c34a1b5
 
 
 
 
 
8d7b14b
c34a1b5
 
 
 
 
4570550
 
 
 
 
 
c34a1b5
 
 
 
 
 
 
 
4570550
c34a1b5
8d7b14b
c34a1b5
4570550
b46eefe
4570550
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import logging
from functools import lru_cache

import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from transformers import CLIPFeatureExtractor



@lru_cache(maxsize=1)
def load_pipeline():
    """Load the FLUX.1-dev pipeline with the Super-Realism LoRA and a safety checker.

    The result is memoized with ``lru_cache(maxsize=1)``, so repeated calls reuse
    the same model objects instead of reloading the weights.

    Returns:
        tuple: ``(pipe, safety_checker, feature_extractor)``. ``pipe`` is moved to
        CUDA when available (CPU otherwise). The safety checker and feature
        extractor are not moved, so they stay on their default (CPU) device.
    """
    # Base model, loaded in bfloat16 to halve weight memory.
    base_model = "black-forest-labs/FLUX.1-dev"
    pipe = DiffusionPipeline.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
    )

    # Style LoRA; its trigger phrase ("Super Realism") is prepended to prompts
    # by the generation handler.
    lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
    pipe.load_lora_weights(lora_repo)

    # NSFW filter applied to the decoded PIL image after generation.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    # NOTE: no enable_xformers_memory_efficient_attention() here. diffusers'
    # FLUX attention processors already use torch's scaled_dot_product_attention.
    # The generic xformers processor that call installs does not implement
    # FLUX's joint (text + image) attention. The call also raises when
    # xformers is not installed, which would abort startup.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = pipe.to(device)

    return pipe, safety_checker, feature_extractor


# Load everything once at import time so every request shares the same models.
pipe, safety_checker, feature_extractor = load_pipeline()

def generate_image(
    prompt,
    seed=42,
    width=1024,
    height=1024,
    guidance_scale=6,
    steps=28,
    progress=gr.Progress(),
):
    """Generate one image from ``prompt`` and screen it with the safety checker.

    Args:
        prompt: Text prompt. ``"Super Realism, "`` is prepended when the LoRA
            trigger phrase is not already present (case-insensitive).
        seed: RNG seed; cast to ``int`` because Gradio sliders may send floats.
        width: Output width in pixels; cast to ``int``.
        height: Output height in pixels; cast to ``int``.
        guidance_scale: Guidance strength forwarded to the pipeline.
        steps: Number of denoising steps; cast to ``int``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        tuple: ``(image, status)``. On success ``image`` is the generated PIL
        image. If NSFW content is detected, or any exception occurs, a black
        512x512 placeholder is returned and ``status`` says why.
    """
    try:
        progress(0, desc="Initializing...")

        # Slider values can arrive as floats; torch.Generator.manual_seed and the
        # pipeline's size/step arguments need integers.
        seed = int(seed)
        width = int(width)
        height = int(height)
        steps = int(steps)
        guidance_scale = float(guidance_scale)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        generator = torch.Generator(device=device).manual_seed(seed)

        # Auto-add the LoRA trigger words if not present.
        prompt = prompt or ""
        if "super realism" not in prompt.lower():
            prompt = f"Super Realism, {prompt}"

        def update_progress(pipeline, step_index, timestep, callback_kwargs):
            """Per-step hook; must hand callback_kwargs back to the pipeline."""
            progress((step_index + 1) / steps, desc="Generating image...")
            return callback_kwargs

        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                generator=generator,
                # FluxPipeline has no legacy ``callback`` kwarg; it invokes
                # callback_on_step_end(pipe, i, t, kwargs) and reads the dict back.
                callback_on_step_end=update_progress,
            )
        image = result.images[0]

        progress(1, desc="Safety checking...")
        # CLIP-preprocess the image for the safety checker.
        safety_input = feature_extractor(image, return_tensors="pt")
        np_image = np.array(image)

        # The checker returns (images, has_nsfw_concepts); only the flags are used.
        _, nsfw_detected = safety_checker(
            images=[np_image],
            clip_input=safety_input.pixel_values,
        )
        if nsfw_detected[0]:
            return Image.new("RGB", (512, 512)), "NSFW content detected"

        return image, "Generation successful"

    except Exception as e:
        # UI boundary: show the error to the user, but keep the traceback in logs
        # instead of silently discarding it.
        logging.getLogger(__name__).exception("Image generation failed")
        return Image.new("RGB", (512, 512)), f"Error: {str(e)}"

with gr.Blocks() as app:
    gr.Markdown("# Flux Super Realism Generator")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
            # Explicit step=1 so seed and step count are whole numbers.
            seed_input = gr.Slider(0, 1000, value=42, step=1, label="Seed")
            # step=16 keeps sizes on multiples of 16, which FLUX expects.
            width_input = gr.Slider(512, 2048, value=1024, step=16, label="Width")
            height_input = gr.Slider(512, 2048, value=1024, step=16, label="Height")
            guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
            steps_input = gr.Slider(10, 100, value=28, step=1, label="Steps")
            submit = gr.Button("Generate")

        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            status = gr.Textbox(label="Status")

    submit.click(
        generate_image,
        inputs=[prompt_input, seed_input, width_input, height_input, guidance_input, steps_input],
        outputs=[output_image, status],
    )

# Low-VRAM alternative (not multi-GPU support): use at most ONE of these, and
# only instead of moving the pipeline to the GPU in load_pipeline().
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()

if __name__ == "__main__":
    # Launch only after the Blocks context has closed and the layout is final.
    # Rate limiting: 1 request at a time, with a max queue size of 3
    app.queue(max_size=3).launch()