# Testing2 / app.py
# Author: DonImages
# Last commit: "Update app.py" (509d782, verified)
# Size: 2.65 kB
import gradio as gr
import torch
import os
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import load_file
from spaces import GPU # Import GPU if in HF Space, otherwise remove this line
# Hugging Face access token taken from the environment (None when HF_TOKEN is unset).
hf_token = os.environ.get("HF_TOKEN")

# Hub repository holding the Stable Diffusion 3.5 Large weights.
model_id = "stabilityai/stable-diffusion-3.5-large"

# Shared pipeline handle: empty at import time, filled in by load_pipeline().
pipeline = None
# Function to load the Stable Diffusion pipeline (guarded so it loads only once)
def load_pipeline():
    """Load the SD 3.5 pipeline into the module-level ``pipeline`` global.

    Wired to the "Load Model" button. Fetches ``model_id`` in float16 using
    ``hf_token``, caching under ``./model_cache``. It then enables model CPU
    offload and attention slicing. The global is only assigned once both
    steps succeed, so ``generate_image`` never sees a half-configured pipeline.

    Returns:
        A status string for the UI: ``"Model loaded successfully"`` (also when
        the model was already loaded), or ``"Error loading model: ..."`` if
        loading or configuring failed.
    """
    global pipeline  # we rebind the module-level handle below

    # Repeated clicks would otherwise re-run from_pretrained every time.
    if pipeline is not None:
        return "Model loaded successfully"

    try:
        loaded = StableDiffusion3Pipeline.from_pretrained(
            model_id,
            token=hf_token,  # `use_auth_token` is deprecated in favour of `token`
            torch_dtype=torch.float16,
            cache_dir="./model_cache",
        )
        # Configure inside the try: a failure here previously escaped the
        # handler and left a half-configured pipeline in the global.
        loaded.enable_model_cpu_offload()
        loaded.enable_attention_slicing()
    except Exception as e:
        print(f"Error loading model: {e}")
        return f"Error loading model: {e}"

    pipeline = loaded
    return "Model loaded successfully"
# Function for image generation (runs under the Spaces GPU decorator)
@GPU(duration=65)  # Use GPU decorator (ONLY if in HF Space)
def generate_image(prompt):
    """Generate one image for ``prompt`` with the loaded pipeline.

    If ``lora_trained_model.safetensors`` exists in the working directory, its
    tensors are copied into ``pipeline.text_encoder`` before generating. Keys
    that don't match the encoder's parameter names are skipped
    (``strict=False``), and the number skipped is printed.

    Args:
        prompt: Text prompt from the Gradio textbox.

    Returns:
        The first image produced by the pipeline (``.images[0]``).

    Raises:
        gr.Error: if the model is not loaded, the LoRA file cannot be applied,
            or generation fails. Gradio shows the message to the user.
            Returning a string instead would be handed to the ``gr.Image``
            output as if it were an image path.
    """
    if pipeline is None:  # check if pipeline is loaded
        raise gr.Error("Model not loaded. Please click 'Load Model' first.")

    # Load and apply LoRA (file is expected to ship with the Space)
    lora_filename = "lora_trained_model.safetensors"
    lora_path = os.path.join("./", lora_filename)
    print(f"Loading LoRA from: {lora_path}")
    if os.path.exists(lora_path):
        try:
            lora_weights = load_file(lora_path)
            result = pipeline.text_encoder.load_state_dict(lora_weights, strict=False)
        except Exception as e:
            raise gr.Error(f"Error loading LoRA: {e}") from e
        # strict=False silently ignores keys that don't name a text_encoder
        # parameter; surface them so a LoRA that applied nothing is visible.
        if result.unexpected_keys:
            print(
                f"LoRA: {len(result.unexpected_keys)} of {len(lora_weights)} keys "
                "did not match text_encoder and were ignored"
            )

    try:
        return pipeline(prompt).images[0]
    except Exception as e:
        raise gr.Error(f"Error generating image: {e}") from e
# Create the Gradio interface
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")
    load_model_button = gr.Button("Load Model")  # Button to load model
    # Read-only field for load_pipeline()'s status string. Sending it to the
    # button instead replaced the "Load Model" label with the message text.
    load_status = gr.Textbox(label="Model Status", interactive=False)

    load_model_button.click(fn=load_pipeline, outputs=load_status)
    generate_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=image_output,
    )

demo.launch()