import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from functools import lru_cache
from PIL import Image
from transformers import CLIPImageProcessor
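
# Gradio app: text-to-image with FLUX.1-dev plus the Super-Realism LoRA,
# with outputs filtered through the Stable Diffusion safety checker.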

# Cache the loaded models so accidental re-calls reuse the same objects;
# this is what the lru_cache import above is for.
@lru_cache(maxsize=1)
def load_pipeline():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FLUX.1-dev ships bfloat16 weights, so use reduced precision on CUDA and
    # fall back to float32 on CPU. The legacy revision="fp16" branch does not
    # exist for this repo, and requesting it fails at download time.
    torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
    base_model = "black-forest-labs/FLUX.1-dev"
    pipe = DiffusionPipeline.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
    # Load the Super-Realism LoRA weights on top of the base model.
    lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
    pipe.load_lora_weights(lora_repo)
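    # Optional (an assumption, not part of the original app): fusing the LoRA
    # into the base weights can shave per-step overhead, at the cost of not
    # being able to unload or swap it later.
    # pipe.fuse_lora()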
    # FLUX pipelines do not bundle a safety checker, so reuse the Stable
    # Diffusion one. If memory remains an issue, you can disable it below.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    if device.type == "cuda":
        # Use attention slicing for further memory savings.
        pipe.enable_attention_slicing()
        # Offload layers to CPU when not in use.
        pipe.enable_sequential_cpu_offload()
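        # Sequential offload trades a lot of speed for memory. If the GPU has
        # headroom, pipe.enable_model_cpu_offload() is a faster alternative
        # that moves whole sub-models instead of individual layers.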
    return pipe, safety_checker, image_processor


pipe, safety_checker, image_processor = load_pipeline()

def generate_image(
    prompt,
    seed=42,
    width=512,  # Keep resolution low by default.
    height=512,
    guidance_scale=6,
    steps=28,
    progress=gr.Progress(),
):
    try:
        progress(0, desc="Initializing...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Slider values can arrive as floats; the seed, dimensions, and step
        # count must be ints.
        seed, width, height, steps = int(seed), int(width), int(height), int(steps)
        generator = torch.Generator(device=device).manual_seed(seed)
        # Auto-add the LoRA trigger words if not present.
        if "super realism" not in prompt.lower():
            prompt = f"Super Realism, {prompt}"
        with torch.inference_mode():
            # FLUX.1-dev is guidance-distilled: guidance_scale is consumed by
            # a guidance embedder rather than classic CFG. Values around 3.5
            # are typical; 6 pushes prompt adherence harder.
            result = pipe(
                prompt=prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
                generator=generator,
            )
        image = result.images[0]
        progress(1, desc="Safety checking...")
        # Process the image for safety checking. The checker returns
        # (possibly blacked-out images, per-image NSFW flags).
        safety_input = image_processor(image, return_tensors="pt")
        np_image = np.array(image)
        _, nsfw_detected = safety_checker(
            images=[np_image],
            clip_input=safety_input.pixel_values,
        )
        if nsfw_detected[0]:
            return Image.new("RGB", (width, height)), "NSFW content detected"
        # Clear the CUDA cache so intermediate buffers are released.
        if device.type == "cuda":
            torch.cuda.empty_cache()
        return image, "Generation successful"
    except Exception as e:
        # Surface the error in the UI instead of crashing the app.
        return Image.new("RGB", (int(width), int(height))), f"Error: {str(e)}"

with gr.Blocks() as app:
    gr.Markdown("# Flux Super Realism Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="A portrait of a person")
            seed_input = gr.Slider(0, 1000, value=42, step=1, label="Seed")
            # Limit the resolution sliders to help avoid memory overuse.
            width_input = gr.Slider(256, 1024, value=512, step=64, label="Width")
            height_input = gr.Slider(256, 1024, value=512, step=64, label="Height")
            guidance_input = gr.Slider(1, 20, value=6, label="Guidance Scale")
            steps_input = gr.Slider(10, 100, value=28, step=1, label="Steps")
            submit = gr.Button("Generate")
        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")
            status = gr.Textbox(label="Status")
    # Event handlers must be registered inside the Blocks context.
    submit.click(
        generate_image,
        inputs=[prompt_input, seed_input, width_input, height_input, guidance_input, steps_input],
        outputs=[output_image, status],
    )

# Cap the request queue; once max_size is reached, new submissions are
# rejected rather than piling up on a small GPU.
app.queue(max_size=3).launch()
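
# A minimal smoke test, kept commented out (hypothetical usage, not part of
# the original Space): call generate_image directly to debug model loading
# without the UI. Note the gr.Progress() default expects a live Gradio event.
# image, status = generate_image("A portrait of a person", seed=0)
# image.save("smoke_test.png")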