File size: 2,722 Bytes
feabc9a bafa3e3 d985a28 37c7828 d985a28 c868357 9024a26 d985a28 9024a26 d985a28 9024a26 feabc9a e7f18a8 8fed6dd e7f18a8 8871d09 fc20c03 feabc9a 9024a26 37c7828 fc20c03 8871d09 fc20c03 e7f18a8 fc20c03 e7f18a8 9024a26 fc20c03 90933cd 9024a26 37c7828 5120226 9024a26 d30f159 9b047cb feabc9a 37c7828 d30f159 3819d16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import login
import os
import gradio as gr
# Choose the compute device once at startup: CUDA when torch can see a GPU, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("GPU is available" if device == "cuda" else "GPU is not available, using CPU")
# Authenticate with the Hugging Face Hub before any model download.
# The token is read from the HF_TOKEN environment variable (configured as a Space secret);
# a missing or empty value aborts startup.
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found. Please set it as a repository secret in the Space settings.")
login(token=token)
# Stable Diffusion 3.5 Large pipeline. On GPU the weights are loaded as float16 to save
# memory; on CPU from_pretrained's default dtype is kept.
# NOTE(review): the SD3.5 model card examples load with torch.bfloat16; confirm float16
# produces clean images on the target GPU.
model_id = "stabilityai/stable-diffusion-3.5-large"
pipe = (
    StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    if device == "cuda"
    else StableDiffusion3Pipeline.from_pretrained(model_id)
)
# Move every pipeline component onto the selected device.
pipe.to(device)
# LoRA checkpoint file, resolved relative to the process's current working directory.
lora_model_path = "./lora_model.pth"
def load_lora_model(pipe, lora_model_path):
    """Add LoRA weight deltas from ``lora_model_path`` onto the pipeline's denoiser.

    The checkpoint is expected to be a mapping from parameter name to a tensor.
    Each tensor is added element-wise onto the parameter of the same name.
    NOTE(review): this is an assumption based on how the deltas are applied. Confirm it
    against whatever produced ``lora_model.pth``. Standard LoRA checkpoints store low-rank
    ``lora_A``/``lora_B`` factors, whose names would not match any parameter here.

    Stable Diffusion 3 pipelines expose their denoiser as ``transformer``. UNet-based
    pipelines expose ``unet``. The first of the two that is present (not None) is used.

    Args:
        pipe: A diffusers pipeline, or any object with a ``transformer`` or ``unet``
            attribute holding a ``torch.nn.Module``.
        lora_model_path: Path to a file readable by ``torch.load``.

    Returns:
        The same ``pipe`` object, whose denoiser parameters were updated in place.

    Raises:
        Whatever ``torch.load`` raises if the file is missing or not a safe tensor payload.
    """
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered .pth file cannot execute arbitrary code. Loading to CPU first lets each
    # tensor be moved to the device of the parameter it updates.
    lora_weights = torch.load(lora_model_path, map_location="cpu", weights_only=True)

    denoiser = None
    attr_name = None
    for candidate in ("transformer", "unet"):
        module = getattr(pipe, candidate, None)
        if module is not None:
            denoiser, attr_name = module, candidate
            break
    if denoiser is None:
        print("The model has neither a 'transformer' nor a 'unet' attribute. Please check the model structure.")
        return pipe

    applied = 0
    # Mutating parameters in place must not be recorded by autograd.
    with torch.no_grad():
        for name, param in denoiser.named_parameters():
            if name not in lora_weights:
                continue
            # Match the parameter's device and dtype (e.g. float16 on CUDA) before adding.
            param.add_(lora_weights[name].to(device=param.device, dtype=param.dtype))
            applied += 1
    print(f"Applied {applied} LoRA tensor(s) to pipe.{attr_name}")
    return pipe
# Patch the freshly loaded pipeline with the local LoRA deltas before serving requests.
pipe = load_lora_model(pipe=pipe, lora_model_path=lora_model_path)
def generate_image(prompt, seed=None):
    """Render one 512x512 image for ``prompt`` using the module-level ``pipe``.

    Args:
        prompt: Text prompt passed straight to the pipeline.
        seed: Optional seed for reproducible output. Numeric UI inputs may arrive as
            floats (e.g. ``42.0``), so the value is converted with ``int()``. ``None``
            leaves noise generation to the pipeline.

    Returns:
        The first image of the pipeline's output (``.images[0]``).
    """
    generator = None
    if seed is not None:
        # A fresh CPU generator seeded with the same value produces the same noise that
        # torch.manual_seed(seed) did, without reseeding torch's global RNG on every request.
        generator = torch.Generator(device="cpu").manual_seed(int(seed))
    # A reduced image size keeps memory usage down.
    image = pipe(prompt, height=512, width=512, generator=generator).images[0]
    return image
# Web UI: a prompt text box plus an optional numeric seed, producing a single image.
prompt_box = gr.Textbox(label="Enter your prompt")
seed_box = gr.Number(label="Enter a seed (optional)", value=None)
iface = gr.Interface(fn=generate_image, inputs=[prompt_box, seed_box], outputs="image")
iface.launch()
|