import torch
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import login
import os
import gradio as gr
# Check for GPU availability and set device accordingly
if torch.cuda.is_available():
    device = "cuda"
    print("GPU is available")
else:
    device = "cpu"
    print("GPU is not available, using CPU")
# Retrieve the token from the environment variable
token = os.getenv("HF_TOKEN") # Hugging Face token from the secret
if token:
    login(token=token)  # Log in with the retrieved token
else:
    raise ValueError("Hugging Face token not found. Please set it as a repository secret in the Space settings.")
# Load the Stable Diffusion 3.5 model with lower precision (float16) when a GPU is available;
# float16 inference is not supported on CPU, so fall back to float32 there
model_id = "stabilityai/stable-diffusion-3.5-large"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.to(device)  # Use GPU if available, otherwise fall back to CPU
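# Optional alternative (a sketch, not required for this Space): on GPUs with
# limited VRAM, diffusers' model offloading can be used *instead of* pipe.to(device),
# keeping submodules on CPU until they are needed:
# pipe.enable_model_cpu_offload()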
# Define the path to the LoRA model
lora_model_path = "./lora_model.pth"  # Assuming the file is saved locally
# Custom method to load and apply LoRA weight deltas to the Stable Diffusion pipeline
def load_lora_model(pipe, lora_model_path):
    # Load the LoRA weights onto the same device as the pipeline
    lora_weights = torch.load(lora_model_path, map_location=device)
    # Stable Diffusion 3.5 has no UNet; its denoiser is an MMDiT transformer,
    # so the deltas are applied to pipe.transformer's parameters
    for name, param in pipe.transformer.named_parameters():
        if name in lora_weights:
            # Cast to the parameter's dtype so the in-place add works under float16
            param.data += lora_weights[name].to(param.dtype)
    return pipe
# Load and apply the LoRA model weights
pipe = load_lora_model(pipe, lora_model_path)
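# Note: the manual loop above assumes lora_model.pth holds a plain dict of
# weight deltas keyed by parameter name. If the LoRA was instead exported in
# diffusers/PEFT format (e.g. a .safetensors checkpoint), the built-in loader
# would be the simpler route:
# pipe.load_lora_weights("path/to/lora")  # hypothetical path, diffusers-format LoRA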
# Function to generate an image from a text prompt
def generate_image(prompt, seed=None):
    # gr.Number passes a float, so cast the seed before seeding a device-local generator
    generator = torch.Generator(device=device).manual_seed(int(seed)) if seed is not None else None
    # Use a reduced 512x512 image size for lower memory usage
    image = pipe(prompt, height=512, width=512, generator=generator).images[0]
    return image
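# Quick local sanity check (hypothetical prompt and seed, not part of the Space UI):
# generate_image("a watercolor fox in the snow", seed=42)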
# Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Enter your prompt"),  # For the prompt
        gr.Number(label="Enter a seed (optional)", value=None),  # For the seed
    ],
    outputs="image",
)
if __name__ == "__main__":
    iface.launch()