File size: 4,362 Bytes
c34a1b5
8d7b14b
c34a1b5
8d7b14b
c34a1b5
 
 
188c627
c34a1b5
a351ff0
 
c34a1b5
5507320
c34a1b5
8d7b14b
a351ff0
 
 
 
 
8d7b14b
ad0688d
 
a351ff0
ad0688d
c34a1b5
 
 
92b1b58
c34a1b5
796e120
c34a1b5
 
 
796e120
c34a1b5
a351ff0
4570550
a351ff0
 
 
 
c34a1b5
a351ff0
796e120
8d7b14b
796e120
8d7b14b
c34a1b5
 
 
 
 
 
 
 
 
 
 
4570550
 
c34a1b5
a351ff0
c34a1b5
 
 
a351ff0
4570550
c34a1b5
 
 
4570550
c34a1b5
 
 
 
 
 
 
4570550
 
c34a1b5
 
a351ff0
796e120
c34a1b5
4570550
a351ff0
4570550
c34a1b5
 
 
 
4570550
c34a1b5
 
 
 
 
 
8d7b14b
c34a1b5
 
 
 
 
4570550
 
 
 
 
 
c34a1b5
 
 
 
 
 
 
 
4570550
c34a1b5
8d7b14b
c34a1b5
a351ff0
b46eefe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from functools import lru_cache
from PIL import Image

from torchvision import transforms
from transformers import CLIPImageProcessor  # Updated per deprecation warning
import os


@lru_cache(maxsize=1)
def load_pipeline():
    """Build the generation pipeline, safety checker and CLIP image processor.

    The result is memoized by ``lru_cache`` so the models are only loaded
    once per process.

    Returns:
        A ``(pipe, safety_checker, image_processor)`` tuple. ``pipe`` has
        already been moved to the selected device (CUDA when available,
        otherwise CPU); the safety checker is returned as loaded.
    """
    on_gpu = torch.cuda.is_available()
    target = torch.device("cuda" if on_gpu else "cpu")

    # bfloat16 on GPU, full precision on CPU.
    weights_dtype = torch.bfloat16 if on_gpu else torch.float32

    # Base FLUX model with the Super-Realism LoRA applied on top.
    generator_pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=weights_dtype,
    )
    generator_pipe.load_lora_weights("strangerzonehf/Flux-Super-Realism-LoRA")

    # NSFW classifier plus the preprocessor that prepares its CLIP input.
    nsfw_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    clip_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    # xformers attention is a best-effort GPU optimization; a failure here
    # is reported but never fatal.
    if on_gpu:
        try:
            generator_pipe.enable_xformers_memory_efficient_attention()
        except Exception as err:
            print("Warning: Could not enable xformers memory efficient attention:", err)

    generator_pipe = generator_pipe.to(target)
    return generator_pipe, nsfw_checker, clip_processor

# Load the models once at import time; generate_image() reads these
# module-level objects on every request.
pipe, safety_checker, image_processor = load_pipeline()

def generate_image(
    prompt,
    seed=42,
    width=1024,
    height=1024,
    guidance_scale=6,
    steps=28,
    progress=gr.Progress()
):
    """Generate one image for ``prompt`` and run it through the safety checker.

    Args:
        prompt: Text prompt. "Super Realism, " is prepended when the trigger
            phrase is not already present (case-insensitive).
        seed: Seed for the torch random generator; coerced to ``int``.
        width: Output width in pixels; coerced to ``int``.
        height: Output height in pixels; coerced to ``int``.
        guidance_scale: Guidance strength forwarded to the pipeline.
        steps: Number of inference steps; coerced to ``int``.
        progress: Gradio progress tracker used to report status.

    Returns:
        A ``(image, status)`` tuple. On success ``image`` is the generated
        PIL image. When the safety checker flags the result, or when any
        exception is raised, a blank 512x512 RGB image is returned together
        with a message describing what happened.
    """
    try:
        progress(0, desc="Initializing...")

        # UI sliders can deliver floats, but torch.Generator.manual_seed and
        # the pipeline's size / step-count arguments require integers.
        seed = int(seed)
        width = int(width)
        height = int(height)
        steps = int(steps)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        generator = torch.Generator(device=device).manual_seed(seed)

        # Ensure the LoRA trigger phrase is present
        if "super realism" not in prompt.lower():
            prompt = f"Super Realism, {prompt}"

        # Step-end hook. The FLUX pipeline has no legacy `callback` argument;
        # it calls `callback_on_step_end(pipeline, step, timestep, kwargs)`
        # and expects the kwargs dict to be returned.
        def update_progress(pipeline, step, timestep, callback_kwargs):
            progress((step + 1) / steps, desc="Generating image...")
            return callback_kwargs

        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                generator=generator,
                callback_on_step_end=update_progress
            )
            image = result.images[0]

        progress(1, desc="Safety checking...")
        # The checker takes the raw image array plus CLIP-preprocessed pixels
        safety_input = image_processor(image, return_tensors="pt")
        np_image = np.array(image)

        _, nsfw_detected = safety_checker(
            images=[np_image],
            clip_input=safety_input.pixel_values
        )

        if nsfw_detected[0]:
            return Image.new("RGB", (512, 512)), "NSFW content detected"

        return image, "Generation successful"

    except Exception as e:
        # UI boundary: surface the failure in the status box instead of
        # letting the request crash.
        return Image.new("RGB", (512, 512)), f"Error: {str(e)}"

# Gradio UI: input controls on the left, generated image and status on the right.
with gr.Blocks() as app:
    gr.Markdown("# Flux Super Realism Generator")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
            # Explicit `step` values: when omitted, Gradio derives the step
            # from the slider range, which can produce fractional step counts
            # and sizes that are not multiples of 16.
            seed_input = gr.Slider(0, 1000, value=42, step=1, label="Seed")
            width_input = gr.Slider(512, 2048, value=1024, step=16, label="Width")
            height_input = gr.Slider(512, 2048, value=1024, step=16, label="Height")
            guidance_input = gr.Slider(1, 20, value=6, step=0.1, label="Guidance Scale")
            steps_input = gr.Slider(10, 100, value=28, step=1, label="Steps")
            submit = gr.Button("Generate")

        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            status = gr.Textbox(label="Status")

    # Input order must match generate_image's positional parameters.
    submit.click(
        generate_image,
        inputs=[prompt_input, seed_input, width_input, height_input, guidance_input, steps_input],
        outputs=[output_image, status]
    )

    # Queue without GPU-specific arguments
    app.queue(max_size=3).launch()