Spaces: Running on Zero
import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
from functools import lru_cache
from PIL import Image
# Cache pipeline loading so the models are only loaded once per process
@lru_cache(maxsize=1)
def load_pipeline():
    # Load base model
    base_model = "black-forest-labs/FLUX.1-dev"
    pipe = DiffusionPipeline.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16
    )

    # Load LoRA weights
    lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
    pipe.load_lora_weights(lora_repo)

    # Load safety checker and its image preprocessor
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    # Optional optimization: requires xformers and is usually unnecessary for FLUX,
    # since PyTorch's built-in scaled-dot-product attention is used by default
    # pipe.enable_xformers_memory_efficient_attention()

    pipe = pipe.to("cuda")
    return pipe, safety_checker, feature_extractor

pipe, safety_checker, feature_extractor = load_pipeline()
def generate_image(
    prompt,
    seed=42,
    width=1024,
    height=1024,
    guidance_scale=6,
    steps=28,
    progress=gr.Progress()
):
    try:
        progress(0, desc="Initializing...")
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

        # Auto-add the LoRA trigger words if they are missing
        if "super realism" not in prompt.lower():
            prompt = f"Super Realism, {prompt}"

        # Progress callback, invoked at the end of every denoising step
        def update_progress(pipeline, step, timestep, callback_kwargs):
            progress((step + 1) / steps, desc="Generating image...")
            return callback_kwargs

        # Generate image
        with torch.inference_mode():
            image = pipe(
                prompt=prompt,
                width=int(width),
                height=int(height),
                guidance_scale=guidance_scale,
                num_inference_steps=int(steps),
                generator=generator,
                callback_on_step_end=update_progress
            ).images[0]
        # Safety check
        progress(1, desc="Safety checking...")
        safety_input = feature_extractor(image, return_tensors="pt")
        np_image = np.array(image)
        # The safety checker returns (filtered_images, has_nsfw_concepts)
        _, has_nsfw = safety_checker(
            images=[np_image],
            clip_input=safety_input.pixel_values
        )
        if has_nsfw[0]:
            return Image.new("RGB", (512, 512)), "NSFW content detected"

        return image, "Generation successful"
    except Exception as e:
        return Image.new("RGB", (512, 512)), f"Error: {str(e)}"
# Create the Gradio interface; the queue below limits concurrent requests
with gr.Blocks() as app:
    gr.Markdown("# Flux Super Realism Generator")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A portrait of a person")
            seed = gr.Slider(0, 1000, value=42, label="Seed")
            width = gr.Slider(512, 2048, value=1024, label="Width")
            height = gr.Slider(512, 2048, value=1024, label="Height")
            guidance = gr.Slider(1, 20, value=6, label="Guidance Scale")
            steps = gr.Slider(10, 100, value=28, label="Steps")
            submit = gr.Button("Generate")
        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            status = gr.Textbox(label="Status")

    submit.click(
        generate_image,
        inputs=[prompt, seed, width, height, guidance, steps],
        outputs=[output_image, status]
    )
# Limit load: run one generation at a time and queue at most 3 waiting requests
# (on Gradio 3.x use queue(concurrency_count=1, max_size=3) instead)
app.queue(default_concurrency_limit=1, max_size=3).launch()

# To reduce GPU memory usage (at the cost of speed), call one of these
# instead of pipe.to("cuda") in load_pipeline():
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()
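
Since the Space targets ZeroGPU ("Running on Zero"), the GPU-bound function normally also needs the @spaces.GPU decorator from the spaces package so that a GPU is attached for the duration of each call. A minimal sketch of that variant, assuming the rest of the file stays as above (the duration argument is optional):

import spaces

@spaces.GPU(duration=120)  # attach a GPU for up to ~120 s per call
def generate_image(prompt, seed=42, width=1024, height=1024,
                   guidance_scale=6, steps=28, progress=gr.Progress()):
    ...  # same body as above; CUDA work runs only while the decorator holds a GPU

On ZeroGPU hardware, CUDA is only usable inside functions decorated this way, so the torch.Generator(device="cuda") call and the pipe(...) call should stay within the decorated function.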