# app.py for the Flux Super Realism Generator Space (Hugging Face Spaces, running on Zero)
import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from functools import lru_cache
from PIL import Image
from transformers import CLIPImageProcessor  # replaces the deprecated CLIPFeatureExtractor
import os
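# Note (assumption based on the Hub listing): black-forest-labs/FLUX.1-dev is a gated
# model, so the Space generally needs the license accepted and an access token
# (e.g. an HF_TOKEN secret) available for from_pretrained() to download the weights.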
@lru_cache(maxsize=1)
def load_pipeline():
    # Determine the device and set torch_dtype accordingly
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

    # Load the base model with the appropriate dtype
    base_model = "black-forest-labs/FLUX.1-dev"
    pipe = DiffusionPipeline.from_pretrained(
        base_model,
        torch_dtype=torch_dtype
    )

    # Load the LoRA weights on top of the base model
    lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
    pipe.load_lora_weights(lora_repo)

    # Load the safety checker and its image processor
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Enable GPU-only optimizations if a GPU is available
    if device.type == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print("Warning: could not enable xformers memory-efficient attention:", e)

    pipe = pipe.to(device)
    return pipe, safety_checker, image_processor


pipe, safety_checker, image_processor = load_pipeline()
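# Optional startup sanity check (a sketch, disabled by default): setting
# RUN_SANITY_CHECK=1 in the Space variables renders one image with the loaded
# pipeline before the UI starts. The variable name, prompt, and output file are
# illustrative only and not part of the original app.
if os.environ.get("RUN_SANITY_CHECK") == "1":
    sample = pipe(
        "Super Realism, a portrait of a person",
        width=1024,
        height=1024,
        guidance_scale=6,
        num_inference_steps=28,
    ).images[0]
    sample.save("sanity_check.png")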
def generate_image(
    prompt,
    seed=42,
    width=1024,
    height=1024,
    guidance_scale=6,
    steps=28,
    progress=gr.Progress()
):
    try:
        progress(0, desc="Initializing...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Gradio sliders deliver floats; the generator and pipeline expect ints
        seed, width, height, steps = int(seed), int(width), int(height), int(steps)
        generator = torch.Generator(device=device).manual_seed(seed)

        # Ensure the LoRA trigger word is present
        if "super realism" not in prompt.lower():
            prompt = f"Super Realism, {prompt}"

        # Step-end callback to report denoising progress; FLUX pipelines take
        # callback_on_step_end (not the legacy `callback` argument) and expect
        # the callback to return callback_kwargs
        def update_progress(pipeline, step, timestep, callback_kwargs):
            progress((step + 1) / steps, desc="Generating image...")
            return callback_kwargs

        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                generator=generator,
                callback_on_step_end=update_progress
            )
            image = result.images[0]

        progress(1, desc="Safety checking...")
        # Preprocess the image for the safety checker
        safety_input = image_processor(image, return_tensors="pt")
        np_image = np.array(image)

        # Run the safety checker; it returns the images and a per-image NSFW flag
        _, nsfw_detected = safety_checker(
            images=[np_image],
            clip_input=safety_input.pixel_values
        )
        if nsfw_detected[0]:
            return Image.new("RGB", (512, 512)), "NSFW content detected"

        return image, "Generation successful"
    except Exception as e:
        return Image.new("RGB", (512, 512)), f"Error: {str(e)}"
with gr.Blocks() as app:
    gr.Markdown("# Flux Super Realism Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
            seed_input = gr.Slider(0, 1000, value=42, step=1, label="Seed")
            # FLUX works with dimensions that are multiples of 16
            width_input = gr.Slider(512, 2048, value=1024, step=16, label="Width")
            height_input = gr.Slider(512, 2048, value=1024, step=16, label="Height")
            guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
            steps_input = gr.Slider(10, 100, value=28, step=1, label="Steps")
            submit = gr.Button("Generate")
        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            status = gr.Textbox(label="Status")

    submit.click(
        generate_image,
        inputs=[prompt_input, seed_input, width_input, height_input, guidance_input, steps_input],
        outputs=[output_image, status]
    )

# Queue without GPU-specific arguments
app.queue(max_size=3).launch()
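# Dependencies assumed by this file (inferred from the imports and calls above, not an
# authoritative requirements.txt): gradio, torch, numpy, diffusers, transformers,
# peft (used by load_lora_weights), Pillow, sentencepiece (for the FLUX text encoders),
# and optionally xformers for the memory-efficient attention path.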