import gradio as gr
import torch
import os
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import load_file
from spaces import GPU  # ZeroGPU decorator; remove this import (and @GPU below) if not in an HF Space
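
# Assumed dependencies (a sketch; the original does not pin versions):
#   pip install gradio torch diffusers transformers safetensors sentencepiece spaces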
# 1. Define model ID and HF_TOKEN (at the VERY beginning)
model_id = "stabilityai/stable-diffusion-3.5-large" # Or your preferred model ID
hf_token = os.getenv("HF_TOKEN") # For private models (set in HF Space settings)
# 2. Initialize pipeline (to None initially)
pipeline = None
# 3. Load Stable Diffusion and LoRA (before Gradio)
try:
    if hf_token:  # pass the token only when it exists (needed for gated/private models)
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            token=hf_token,
            cache_dir="./model_cache",  # local weight cache
        )
    else:
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            cache_dir="./model_cache",  # local weight cache
        )
    pipeline.to("cuda")  # float16 weights need a CUDA device for inference

    lora_filename = "lora_trained_model.safetensors"  # EXACT filename of your LoRA
    lora_path = os.path.join("./", lora_filename)
    if os.path.exists(lora_path):
        lora_weights = load_file(lora_path)
        text_encoder = pipeline.text_encoder
        # strict=False: only the keys that match the text encoder are applied
        text_encoder.load_state_dict(lora_weights, strict=False)
        print(f"LoRA loaded successfully from: {lora_path}")
    else:
        print(f"Error: LoRA file not found at: {lora_path}")
        exit()  # stop if the LoRA file is missing
    print("Stable Diffusion model loaded successfully!")
except Exception as e:
    print(f"Error loading model or LoRA: {e}")
    exit()  # stop if model loading fails
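
# Note: recent diffusers versions also expose a built-in LoRA loader on the
# pipeline. If the checkpoint is in the standard diffusers/PEFT LoRA format,
# the manual load_state_dict above could be replaced with (a sketch, assuming
# a compatible file; same filename assumption as above):
#
#     pipeline.load_lora_weights("./", weight_name=lora_filename)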
# 4. Image generation function (now decorated)
@GPU(duration=65)  # only needed when running in an HF Space
def generate_image(prompt):
    global pipeline
    if pipeline is None:
        raise gr.Error("Model not loaded!")  # shown as an error toast in the UI
    try:
        image = pipeline(prompt).images[0]  # run the diffusion pipeline
        print("Image generated successfully!")  # debug output
        return image
    except Exception as e:
        error_message = f"Error during image generation: {e}"
        print(error_message)  # log the error to the console
        raise gr.Error(error_message)  # surface the error in the Gradio UI
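
# The pipeline call accepts the usual sampling parameters as well. A minimal
# sketch (the values below are illustrative, not tuned for SD 3.5):
#
#     generator = torch.Generator("cuda").manual_seed(42)  # reproducible seed
#     image = pipeline(
#         prompt,
#         num_inference_steps=28,
#         guidance_scale=4.5,
#         generator=generator,
#     ).images[0]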
# 5. Gradio interface
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=image_output,
    )

demo.launch()
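
# On a shared Space it may help to enable Gradio's queue so concurrent
# requests wait for the GPU instead of failing. A sketch that would replace
# the launch() call above (max_size is illustrative):
#
#     demo.queue(max_size=10).launch()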