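"""Gradio demo: Stable Diffusion 3.5 Large with a custom LoRA.

Intended for a Hugging Face Space (hence the `spaces` GPU decorator), but it
also runs locally on a CUDA machine (or, slowly, on CPU) if that import is
removed.
"""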
import gradio as gr
import torch
import os
import random
import numpy as np
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from spaces import GPU # Remove if not in HF Space
# 1. Model and LoRA Loading (Before Gradio)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
token = os.getenv("HF_TOKEN")
model_repo_id = "stabilityai/stable-diffusion-3.5-large"
try:
    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype, token=token)  # token=None falls back to cached credentials
    pipe = pipe.to(device)
    lora_filename = "lora_trained_model.safetensors"  # EXACT filename of your LoRA
    lora_path = os.path.join("./", lora_filename)
    if os.path.exists(lora_path):
        # load_lora_weights applies the LoRA deltas across the pipeline;
        # loading the state dict first keeps the safetensors read explicit.
        pipe.load_lora_weights(load_file(lora_path))
        print(f"LoRA loaded successfully from: {lora_path}")
    else:
        print(f"Error: LoRA file not found at: {lora_path}")
        exit()  # Stop if the LoRA is not found
    print("Stable Diffusion model and LoRA loaded successfully!")
except Exception as e:
    print(f"Error loading model or LoRA: {e}")
    exit()
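# Note: access to stabilityai/stable-diffusion-3.5-large is gated, so HF_TOKEN
# (or cached `huggingface-cli login` credentials) must grant access. SD 3.5
# Large is an ~8B-parameter model; if fp16 inference runs out of VRAM, one
# optional mitigation (not used here) is diffusers' CPU offloading:
# pipe.enable_model_cpu_offload()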
MAX_SEED = np.iinfo(np.int32).max  # Largest 32-bit integer seed
MAX_IMAGE_SIZE = 1024
@GPU(duration=65)  # Only needed on a HF Space
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
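    """Run one text-to-image pass and return (image, seed).

    The seed is returned alongside the image so the UI can show the value
    actually used when "Randomize seed" is checked.
    """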
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)  # Draw a fresh seed for this run
    generator = torch.Generator(device=device).manual_seed(seed)  # Generator must live on the pipeline's device
    try:
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
        return image, seed  # Return the seed so the UI reflects the value actually used
    except Exception as e:
        print(f"Error during image generation: {e}")  # Log for server-side debugging
        raise gr.Error(f"Image generation failed: {e}")  # Surface the failure in the UI instead of returning a string to gr.Image
examples = [
"A capybara wearing a suit holding a sign that reads Hello World",
]
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# [Stable Diffusion 3.5 Large (8B)](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)")
        gr.Markdown("[Learn more](https://stability.ai/news/introducing-stable-diffusion-3-5) about the Stable Diffusion 3.5 series. Try on [Stability AI API](https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post), or [download model](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) to run locally with ComfyUI or diffusers.")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )
        gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True, cache_mode="lazy")
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )
if __name__ == "__main__":
    demo.launch()
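# On a Hugging Face Space this file is launched automatically; locally,
# `python app.py` starts the server and prints the URL to open.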