import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPImageProcessor
from functools import lru_cache
from PIL import Image


# Cache pipeline loading to improve performance
@lru_cache(maxsize=1)
def load_pipeline():
    # Load base model
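    # Note: FLUX.1-dev is a gated checkpoint; an authenticated HF token is required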
    base_model = "black-forest-labs/FLUX.1-dev"
    pipe = DiffusionPipeline.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16
    )
    
    # Load LoRA weights
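    # The adapter's trigger phrase is "Super Realism"; generate_image() below
    # prepends it automatically when missing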
    lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
    pipe.load_lora_weights(lora_repo)
    
    # Load safety checker
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    feature_extractor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
    
    # Optimizations: xformers attention is not supported for Flux on every
    # diffusers release, so fail soft rather than crash at startup
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    pipe = pipe.to("cuda")
    
    return pipe, safety_checker, feature_extractor

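# Loaded once at import time; lru_cache returns the same objects afterwards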
pipe, safety_checker, feature_extractor = load_pipeline()

def generate_image(
    prompt,
    seed=42,
    width=1024,
    height=1024,
    guidance_scale=6,
    steps=28,
    progress=gr.Progress()
):
    try:
        progress(0, desc="Initializing...")
        # int() guards against float values coming from the Gradio sliders
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
        
        # Auto-add trigger words
        if "super realism" not in prompt.lower():
            prompt = f"Super Realism, {prompt}"
        
        # Progress callback: current diffusers pipelines take callback_on_step_end
        # (the old `callback` argument has been removed) and expect the
        # callback_kwargs dict to be returned
        def update_progress(pipeline, step, timestep, callback_kwargs):
            progress((step + 1) / steps, desc="Generating image...")
            return callback_kwargs
        
        # Generate image
        with torch.inference_mode():
            image = pipe(
                prompt=prompt,
                width=int(width),
                height=int(height),
                guidance_scale=guidance_scale,
                num_inference_steps=int(steps),
                generator=generator,
                callback_on_step_end=update_progress
            ).images[0]
        
        # Safety check: the checker returns (filtered_images, has_nsfw_concepts),
        # not an object with an .nsfw attribute
        progress(1, desc="Safety checking...")
        safety_input = feature_extractor(image, return_tensors="pt")
        np_image = np.array(image)
        _, has_nsfw = safety_checker(
            images=[np_image],
            clip_input=safety_input.pixel_values
        )
        
        if has_nsfw[0]:
            return Image.new("RGB", (512, 512)), "NSFW content detected"
        
        return image, "Generation successful"
    
    except Exception as e:
        return Image.new("RGB", (512, 512)), f"Error: {str(e)}"

# Create Gradio interface with rate limiting
with gr.Blocks() as app:
    gr.Markdown("# Flux Super Realism Generator")
    
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A portrait of a person")
            seed = gr.Slider(0, 1000, value=42, step=1, label="Seed")
            width = gr.Slider(512, 2048, value=1024, step=64, label="Width")
            height = gr.Slider(512, 2048, value=1024, step=64, label="Height")
            guidance = gr.Slider(1, 20, value=6, step=0.5, label="Guidance Scale")
            steps = gr.Slider(10, 100, value=28, step=1, label="Steps")
            submit = gr.Button("Generate")
        
        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            status = gr.Textbox(label="Status")
    
    submit.click(
        generate_image,
        inputs=[prompt, seed, width, height, guidance, steps],
        outputs=[output_image, status]
    )
    
    # Queue requests: one generation at a time, at most 3 waiting
    # (Gradio 3 called this parameter concurrency_count)
    app.queue(default_concurrency_limit=1, max_size=3).launch()

# For low-VRAM GPUs, replace pipe.to("cuda") with CPU offloading:
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()