import torch
import spaces
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import login
import os
import gradio as gr
# Retrieve the token from the environment variable
token = os.getenv("HF_TOKEN")  # Hugging Face token from the Space's repository secret
if token:
    login(token=token)  # Authenticate against the Hugging Face Hub
else:
    raise ValueError("Hugging Face token not found. Please set it as a repository secret in the Space settings.")

# Load the Stable Diffusion 3.5 model with lower precision (float16) if GPU is available;
# fp16 halves memory use and is the standard inference setup for SD3.5 on CUDA.
model_id = "stabilityai/stable-diffusion-3.5-large"
use_cuda = torch.cuda.is_available()
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if use_cuda else torch.float32,
)

# Move the pipeline to the GPU when one is present, otherwise stay on CPU
pipe.to('cuda' if use_cuda else 'cpu')

# Define the path to the LoRA model
lora_model_path = "./lora_model.pth"  # Assuming the file is saved locally
# Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
def load_lora_model(pipe, lora_model_path):
    """Load LoRA weight deltas from ``lora_model_path`` and add them in place.

    Parameters:
        pipe: a diffusers pipeline. Its denoiser module receives the deltas:
            SD3 pipelines expose it as ``transformer`` (MMDiT); older
            Stable Diffusion pipelines expose it as ``unet``.
        lora_model_path: path to a ``torch.save``-d dict mapping parameter
            names to delta tensors.

    Returns:
        The same ``pipe`` object, mutated in place.
    """
    # weights_only=True keeps pickled code inside the checkpoint from executing
    lora_weights = torch.load(lora_model_path, map_location=pipe.device, weights_only=True)

    # StableDiffusion3Pipeline has no `unet` attribute — its denoiser is
    # `transformer`. Try that first, then fall back to `unet` for older models.
    denoiser = getattr(pipe, "transformer", None)
    if denoiser is None:
        denoiser = getattr(pipe, "unet", None)
    if denoiser is None:
        print("The model doesn't have 'unet' attributes. Please check the model structure.")
        return pipe

    # Add each matching delta to its parameter; no autograd tracking needed,
    # and deltas are cast to the parameter's device/dtype (pipe may be fp16).
    with torch.no_grad():
        for name, param in denoiser.named_parameters():
            if name in lora_weights:
                param.add_(lora_weights[name].to(device=param.device, dtype=param.dtype))
    return pipe
# Apply the local LoRA deltas to the freshly loaded pipeline (mutates it in place)
pipe = load_lora_model(pipe, lora_model_path)
# @spaces.GPU requests ZeroGPU hardware for the duration of the call.
# NOTE: the decorator is capitalized `GPU` — `spaces.gpu` does not exist and
# would raise AttributeError at import time.
@spaces.GPU
def generate(prompt, seed=None):
    """Generate one 512x512 image from ``prompt``.

    Parameters:
        prompt: text prompt forwarded to the pipeline.
        seed: optional; when given, seeds torch's global RNG so results are
            reproducible. Coerced to int because Gradio's Number widget may
            deliver a float, which torch.manual_seed rejects.

    Returns:
        A PIL image (first element of the pipeline output's ``images``).
    """
    generator = None
    if seed is not None:
        # torch.manual_seed returns the (now seeded) global Generator
        generator = torch.manual_seed(int(seed))
    image = pipe(prompt, height=512, width=512, generator=generator).images[0]
    return image
# Gradio interface
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Enter your prompt"),  # For the prompt
        # precision=0 makes Gradio deliver the seed as an int rather than a
        # float, which is what torch.manual_seed expects downstream.
        gr.Number(label="Enter a seed (optional)", value=None, precision=0),
    ],
    outputs="image",
)
iface.launch()