File size: 2,646 Bytes
7b1a432 509d782 7b1a432 509d782 7b1a432 509d782 7b1a432 509d782 7b1a432 509d782 7b1a432 509d782 7b1a432 509d782 7b1a432 509d782 7b1a432 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import os
import weakref

import gradio as gr
import torch
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import load_file
from spaces import GPU  # Import GPU if in HF Space, otherwise remove this line
# Hugging Face access token taken from the environment; None when HF_TOKEN
# is unset. load_pipeline passes it along when downloading the model.
hf_token = os.environ.get("HF_TOKEN")

# Hub repository id of the base Stable Diffusion 3.5 model.
model_id = "stabilityai/stable-diffusion-3.5-large"

# Module-level pipeline handle. It starts empty and is filled in by
# load_pipeline; generate_image checks it before doing any work.
pipeline = None
# Function to load the Stable Diffusion pipeline (effectively only once: later
# calls reuse the already-loaded pipeline).
def load_pipeline():
    """Load the SD 3.5 pipeline into the module-level ``pipeline`` global.

    Wired to the "Load Model" button. If a pipeline is already loaded, it is
    reused instead of downloading and initializing the weights again.

    Returns:
        str: A status message for the UI: ``"Model loaded successfully"`` on
        success, or ``"Error loading model: <exception>"`` on failure (the
        error is also printed to the server log).
    """
    global pipeline  # rebinds the module-level handle read by generate_image
    if pipeline is not None:
        # Already loaded by an earlier click; reloading would only repeat the
        # expensive download/initialization.
        return "Model loaded successfully"
    try:
        pipe = StableDiffusion3Pipeline.from_pretrained(
            model_id,
            token=hf_token,  # replaces the deprecated `use_auth_token`
            # NOTE(review): the SD 3.5 model card examples use torch.bfloat16;
            # confirm float16 output quality for this model.
            torch_dtype=torch.float16,
            cache_dir="./model_cache",
        )
        # Memory-saving options; configured before publishing the pipeline so
        # a failure here is reported instead of leaving a half-set-up global.
        pipe.enable_model_cpu_offload()
        pipe.enable_attention_slicing()
    except Exception as e:
        print(f"Error loading model: {e}")
        return f"Error loading model: {e}"  # shown in the UI
    pipeline = pipe
    return "Model loaded successfully"
# Image generation (decorated to request a GPU on HF Spaces).

# LoRA weights file expected next to this app.
LORA_FILENAME = "lora_trained_model.safetensors"

# Text encoders that have already received the LoRA weights. Held weakly so a
# replaced pipeline can still be garbage-collected; membership means "skip".
_lora_encoders = weakref.WeakSet()


def _apply_lora(text_encoder):
    """Copy the weights in LORA_FILENAME into *text_encoder*, at most once.

    Does nothing if the file does not exist or this encoder was already
    patched. Loading uses ``strict=False``, so file keys the encoder does not
    have are skipped; a warning is printed when no key matched at all.
    """
    if text_encoder in _lora_encoders:
        return
    lora_path = os.path.join("./", LORA_FILENAME)
    if not os.path.exists(lora_path):
        return  # LoRA is optional: fall back to the base weights.
    print(f"Loading LoRA from: {lora_path}")
    lora_weights = load_file(lora_path)
    result = text_encoder.load_state_dict(lora_weights, strict=False)
    if lora_weights and len(result.unexpected_keys) == len(lora_weights):
        # With strict=False this would otherwise be a silent no-op.
        print(f"Warning: no keys in {lora_path} match the text encoder; "
              "LoRA weights were not applied")
    _lora_encoders.add(text_encoder)


@GPU(duration=65)  # Use GPU decorator (ONLY if in HF Space)
def generate_image(prompt):
    """Generate one image for *prompt* using the loaded pipeline.

    Applies the optional LoRA weights to the pipeline's text encoder the
    first time they are needed, then runs the pipeline.

    Args:
        prompt: Text prompt from the UI textbox.

    Returns:
        The first entry of ``pipeline(prompt).images``.

    Raises:
        gr.Error: If the model is not loaded, the LoRA cannot be applied, or
            generation fails. Raising (rather than returning a string) lets
            Gradio show the message instead of treating it as an image path.
    """
    if pipeline is None:
        raise gr.Error("Model not loaded. Click 'Load Model' first.")
    try:
        _apply_lora(pipeline.text_encoder)
    except Exception as e:
        raise gr.Error(f"Error loading LoRA: {e}") from e
    try:
        return pipeline(prompt).images[0]
    except Exception as e:
        raise gr.Error(f"Error generating image: {e}") from e
# Create the Gradio interface: prompt box, image output, and two buttons.
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")
    load_model_button = gr.Button("Load Model")
    # Read-only field for the load status. Routing load_pipeline's message
    # here keeps the "Load Model" button's label intact; previously the
    # returned string replaced the button text.
    load_status = gr.Textbox(label="Model Status", interactive=False)
    load_model_button.click(fn=load_pipeline, outputs=load_status)
    generate_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=image_output,
    )

demo.launch()