# Testing2 / app.py
# Hugging Face Space file by DonImages — commit 6ee05ba ("Update app.py", verified), 1.72 kB.
import torch
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import login
import os
import gradio as gr
# Authenticate against the Hugging Face Hub. The token is read from the
# HF_TOKEN environment variable; abort start-up early if it is missing or empty.
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found. Please set it as a repository secret.")
login(token=token)
# Hub repository id of the base Stable Diffusion 3.5 Large checkpoint.
model_id = "stabilityai/stable-diffusion-3.5-large"

# Build the pipeline with half-precision weights.
# - low_cpu_mem_usage reduces host-RAM pressure while loading.
# - device_map="balanced" asks diffusers to place the pipeline's components
#   across the available devices (see the diffusers device-placement docs).
# NOTE(review): the SD3.5 Large model card loads this checkpoint in bfloat16.
# float16 is kept to preserve current behavior — confirm outputs are not degraded.
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    device_map="balanced",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

# Compute attention in slices, trading some speed for lower peak memory.
pipe.enable_attention_slicing()
# Path to the LoRA checkpoint, resolved relative to the current working directory.
lora_model_path = "./lora_model.pth"


# Load and apply the LoRA weights
def load_lora_model(pipe, lora_model_path):
    """Load a LoRA checkpoint from disk and attach it to *pipe*.

    Args:
        pipe: A diffusers pipeline exposing ``load_lora_weights`` (for
            ``StableDiffusion3Pipeline`` this comes from its LoRA loader mixin).
        lora_model_path: Filesystem path to a ``torch.save``-d LoRA state dict.

    Returns:
        The same *pipe* object, with the LoRA weights loaded into it.

    Raises:
        FileNotFoundError: If *lora_model_path* does not exist.
        TypeError: If the checkpoint does not contain a dict (state dict).
        pickle.UnpicklingError: If the file holds objects other than tensors
            and plain containers (rejected by ``weights_only=True``).
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code when loaded.
    # map_location="cpu" avoids hard-coding CUDA. That hard-coding breaks
    # CPU-only hosts and ignores the multi-device placement used by the
    # pipeline; the loader moves the weights to the right device itself.
    state_dict = torch.load(lora_model_path, map_location="cpu", weights_only=True)
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"Expected a LoRA state dict in {lora_model_path!r}, "
            f"got {type(state_dict).__name__}"
        )
    # StableDiffusion3Pipeline is transformer-based and has no `unet`, so the
    # former `pipe.unet.load_attn_procs(...)` raised AttributeError.
    # load_lora_weights routes the weights to the correct sub-models.
    # NOTE(review): this assumes the checkpoint's keys use a LoRA layout
    # that load_lora_weights recognizes (e.g. diffusers/PEFT naming) — TODO
    # confirm how lora_model.pth was produced. Also confirm LoRA loading
    # works with device_map="balanced" on the installed diffusers version.
    pipe.load_lora_weights(state_dict)
    return pipe
# Attach the LoRA checkpoint to the shared pipeline used by the UI.
pipe = load_lora_model(pipe=pipe, lora_model_path=lora_model_path)
# Generate image function
def generate_image(prompt, steps, scale):
    """Render one image for *prompt* using the module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        steps: Number of inference steps. Coerced to ``int`` because UI
            sliders may deliver numbers as floats (e.g. ``30.0``).
        scale: Guidance scale, coerced to ``float``.

    Returns:
        The first entry of the pipeline result's ``images`` list.
    """
    # num_inference_steps is a count, so never forward a float like 30.0.
    # NOTE(review): Gradio's Slider is documented as passing a float —
    # confirm for the installed version.
    num_steps = int(steps)
    guidance = float(scale)
    with torch.inference_mode():  # skip autograd bookkeeping during generation
        result = pipe(prompt, num_inference_steps=num_steps, guidance_scale=guidance)
    return result.images[0]
# Input controls, in the positional order generate_image expects:
# (prompt, steps, scale).
controls = [
    gr.Textbox(label="Enter your prompt"),
    gr.Slider(
        minimum=10,
        maximum=50,
        value=30,
        step=1,
        label="Number of Inference Steps",
    ),
    gr.Slider(
        minimum=1.0,
        maximum=20.0,
        value=7.5,
        step=0.5,
        label="Guidance Scale",
    ),
]

# Web UI wiring the controls to generate_image and showing its result as an image.
iface = gr.Interface(
    title="Optimized Stable Diffusion with LoRA",
    description="Generate images using Stable Diffusion 3.5 with optimized memory usage.",
    fn=generate_image,
    inputs=controls,
    outputs="image",
)

# Start serving the interface.
iface.launch()