File size: 2,722 Bytes
feabc9a
bafa3e3
d985a28
37c7828
 
d985a28
c868357
 
 
 
 
 
 
 
9024a26
 
d985a28
9024a26
d985a28
9024a26
feabc9a
e7f18a8
8fed6dd
e7f18a8
 
 
 
 
8871d09
fc20c03
 
 
feabc9a
9024a26
37c7828
fc20c03
8871d09
fc20c03
e7f18a8
 
 
fc20c03
e7f18a8
 
 
 
 
 
 
9024a26
fc20c03
90933cd
9024a26
37c7828
5120226
9024a26
d30f159
 
9b047cb
 
feabc9a
 
37c7828
d30f159
 
 
 
 
 
 
 
3819d16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import login
import os
import gradio as gr

# Select the compute device once: CUDA when torch reports a GPU, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("GPU is available" if device == "cuda" else "GPU is not available, using CPU")

# Authenticate with the Hugging Face Hub using the HF_TOKEN environment variable
token = os.environ.get("HF_TOKEN")
if not token:
    # Missing or empty token: abort start-up with an explanatory error
    raise ValueError("Hugging Face token not found. Please set it as a repository secret in the Space settings.")
login(token=token)

# Load the Stable Diffusion 3.5 pipeline; float16 is requested only when running on GPU
model_id = "stabilityai/stable-diffusion-3.5-large"
# On CPU no torch_dtype is passed, so from_pretrained uses its default precision
_load_kwargs = {"torch_dtype": torch.float16} if device == "cuda" else {}
pipe = StableDiffusion3Pipeline.from_pretrained(model_id, **_load_kwargs)

pipe.to(device)  # Move the pipeline onto the selected device

# Path of the LoRA checkpoint, relative to the working directory
lora_model_path = "./lora_model.pth"

# Custom method to load a weight-delta checkpoint and merge it into the pipeline's denoiser
def load_lora_model(pipe, lora_model_path):
    """Add the tensors stored at *lora_model_path* onto matching denoiser weights.

    The checkpoint is read with ``torch.load`` and is treated as a mapping from
    parameter name to tensor. Every parameter of the denoising network whose
    name appears in that mapping has the stored tensor added to it in place;
    parameters without a matching entry are left untouched.

    The denoiser is looked up as ``pipe.unet`` first and ``pipe.transformer``
    second. The fallback matters because ``StableDiffusion3Pipeline`` exposes
    its denoiser as ``transformer``; looking only at ``unet`` meant that no
    weights were ever merged for that pipeline.

    NOTE(review): this performs a plain additive merge keyed on exact parameter
    names. If the checkpoint is a standard LoRA file (low-rank A/B matrices),
    it presumably needs ``pipe.load_lora_weights`` instead - confirm the
    checkpoint format.

    Args:
        pipe: Pipeline object exposing a ``unet`` or ``transformer`` module.
        lora_model_path: Path of the checkpoint file to read.

    Returns:
        The same ``pipe`` object, modified in place.

    Raises:
        RuntimeError: If a stored tensor cannot be added to its parameter
            (for example because the shapes do not broadcast).
        Exception: Anything ``torch.load`` raises for a missing or unreadable
            file is propagated unchanged.
    """
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    # Load onto CPU; each tensor is moved to its parameter's device when applied.
    lora_weights = torch.load(lora_model_path, map_location="cpu")

    denoiser = getattr(pipe, "unet", None)
    if denoiser is None:
        denoiser = getattr(pipe, "transformer", None)
    if denoiser is None:
        # Best effort, as before: report the problem and return the pipeline as is
        print("The model has neither a 'unet' nor a 'transformer' submodule. Please check the model structure.")
        return pipe

    applied = 0
    with torch.no_grad():
        for name, param in denoiser.named_parameters():
            if name not in lora_weights:
                continue
            delta = lora_weights[name]
            # Match device and dtype so half-precision / GPU parameters accept the delta
            param.add_(delta.to(device=param.device, dtype=param.dtype))
            applied += 1

    # Make a silent no-op visible: zero means no checkpoint key matched a parameter
    print(f"Applied {applied} checkpoint tensor(s) to the denoiser.")
    return pipe

# Merge the checkpoint at lora_model_path into the loaded pipeline
pipe = load_lora_model(pipe=pipe, lora_model_path=lora_model_path)

# Function to generate an image from a text prompt
def generate_image(prompt, seed=None):
    """Generate one 512x512 image for *prompt* with the module-level pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Optional seed. When given, it is converted with ``int()`` (the UI
            number field may deliver a float) and used to seed a generator
            dedicated to this call. ``None`` passes no generator, leaving
            noise sampling to the pipeline.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    generator = None
    if seed is not None:
        # Use a private CPU generator instead of torch.manual_seed(): that call
        # reseeds the process-wide RNG and returns the shared default generator,
        # so one request would alter global random state for everything else.
        generator = torch.Generator(device="cpu").manual_seed(int(seed))

    # Fixed 512x512 output keeps memory usage down
    image = pipe(prompt, height=512, width=512, generator=generator).images[0]
    return image

# Gradio UI: a prompt box and an optional seed field, wired to generate_image
_prompt_box = gr.Textbox(label="Enter your prompt")
_seed_box = gr.Number(label="Enter a seed (optional)", value=None)

iface = gr.Interface(
    fn=generate_image,
    inputs=[_prompt_box, _seed_box],
    outputs="image",
)
iface.launch()