Testing2 / app.py
DonImages's picture
Update app.py
50c10d2 verified
raw
history blame
5.6 kB
import gradio as gr
import torch
import os
import random
import numpy as np
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from spaces import GPU # Remove if not in HF Space
# 1. Model and LoRA loading (runs once at import time, before the Gradio UI).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU; full precision on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
token = os.getenv("HF_TOKEN")  # passed through as-is; may be None
model_repo_id = "stabilityai/stable-diffusion-3.5-large"
try:
    # `token=` is the supported keyword; `use_auth_token=` is deprecated.
    pipe = DiffusionPipeline.from_pretrained(
        model_repo_id, torch_dtype=torch_dtype, token=token
    )
    pipe = pipe.to(device)
    lora_filename = "lora_trained_model.safetensors"  # EXACT filename of your LoRA
    lora_path = os.path.join("./", lora_filename)
    if os.path.exists(lora_path):
        lora_weights = load_file(lora_path)
        text_encoder = pipe.text_encoder
        # strict=False silently skips every tensor whose key is not a
        # text-encoder parameter, so count what was actually applied rather
        # than assuming success.
        incompatible = text_encoder.load_state_dict(lora_weights, strict=False)
        n_applied = len(lora_weights) - len(incompatible.unexpected_keys)
        if n_applied == 0:
            # A LoRA that matches nothing is as good as a missing one; follow
            # the same stop-if-missing policy as below.
            raise RuntimeError(
                f"none of the {len(lora_weights)} tensors in {lora_path} "
                "match a text encoder parameter; LoRA was NOT applied"
            )
        print(f"LoRA loaded successfully from: {lora_path}")
        print(f"Applied {n_applied}/{len(lora_weights)} LoRA tensors to the text encoder")
    else:
        print(f"Error: LoRA file not found at: {lora_path}")
        # SystemExit is not an Exception subclass, so the handler below
        # does not intercept it. Non-zero status signals failure.
        raise SystemExit(1)  # Stop if LoRA is not found
    print("Stable Diffusion model and LoRA loaded successfully!")
except Exception as e:
    print(f"Error loading model or LoRA: {e}")
    raise SystemExit(1) from e
# Inclusive upper bound for randomly drawn seeds and for the seed slider.
MAX_SEED = 99_999_999_999
# Upper bound, in pixels, for the width and height sliders.
MAX_IMAGE_SIZE = 1024
@GPU(duration=65)  # Only if in HF Space
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from ``prompt`` with the module-level ``pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        seed: Seed for the ``torch.Generator``; replaced when
            ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of inference steps passed to the pipeline.
        progress: Gradio progress tracker (``track_tqdm=True`` asks Gradio to
            follow tqdm bars); not referenced directly.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually
        used, so the UI's seed slider shows it.

    Raises:
        gr.Error: If the pipeline call fails; Gradio displays the message.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)  # inclusive on both ends
    # Slider values may arrive as floats (NOTE(review): depends on the Gradio
    # version); the generator and the size/step arguments need ints.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)
    try:
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]
    except Exception as e:
        print(f"Error during image generation: {e}")  # Print error for debugging
        # Returning a string into the gr.Image output would be treated as a
        # file path and fail; gr.Error shows the message in the UI instead.
        raise gr.Error(f"Error: {e}") from e
    return image, seed
# Sample prompts listed under the UI via gr.Examples.
examples = ["A capybara wearing a suit holding a sign that reads Hello World"]

# Centre the main column and cap its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
)
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Header: model title plus links to the announcement, API and weights.
        gr.Markdown(" # [Stable Diffusion 3.5 Large (8B)](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)")
        gr.Markdown(
            "[Learn more](https://stability.ai/news/introducing-stable-diffusion-3-5) "
            "about the Stable Diffusion 3.5 series. "
            "Try on [Stability AI API](https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post), "
            "or [download model](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) "
            "to run locally with ComfyUI or diffusers."
        )

        # Prompt box and Run button share one row.
        with gr.Row():
            prompt = gr.Text(
                container=False,
                label="Prompt",
                max_lines=1,
                placeholder="Enter your prompt",
                show_label=False,
            )
            run_button = gr.Button("Run", variant="primary", scale=0)

        result = gr.Image(show_label=False, label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            # Hidden, so the UI always sends this textbox's (empty) value.
            negative_prompt = gr.Text(
                visible=False,
                label="Negative prompt",
                placeholder="Enter a negative prompt",
                max_lines=1,
            )
            seed = gr.Slider(label="Seed", value=0, minimum=0, maximum=MAX_SEED, step=1)
            randomize_seed = gr.Checkbox(value=True, label="Randomize seed")

            with gr.Row():
                width = gr.Slider(label="Width", value=1024, minimum=512, maximum=MAX_IMAGE_SIZE, step=32)
                height = gr.Slider(label="Height", value=1024, minimum=512, maximum=MAX_IMAGE_SIZE, step=32)

            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance scale", value=4.5, minimum=0.0, maximum=7.5, step=0.1)
                num_inference_steps = gr.Slider(label="Number of inference steps", value=40, minimum=1, maximum=50, step=1)

        # Examples are generated on first click (lazy cache) using infer's defaults
        # for every argument except the prompt.
        gr.Examples(
            fn=infer,
            examples=examples,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples=True,
            cache_mode="lazy",
        )

    # Both the button and pressing Enter in the prompt box run generation.
    # Order of inputs must match infer's positional parameters.
    infer_inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    ]
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=infer_inputs,
        outputs=[result, seed],
    )

if __name__ == "__main__":
    # Start the Gradio server with default launch settings.
    demo.launch()