Update app.py
Browse files
app.py
CHANGED
@@ -17,29 +17,27 @@ model_id = "stabilityai/stable-diffusion-3.5-large"
|
|
17 |
pipe = StableDiffusion3Pipeline.from_pretrained(model_id)
|
18 |
|
19 |
# Check if GPU is available, then move the model to the appropriate device
|
20 |
-
|
|
|
21 |
|
22 |
# Define the path to the LoRA model
|
23 |
lora_model_path = "./lora_model.pth" # Assuming the file is saved locally
|
24 |
|
25 |
# Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
|
26 |
def load_lora_model(pipe, lora_model_path):
|
27 |
-
# Set device to 'cuda' if available, otherwise 'cpu'
|
28 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
29 |
-
|
30 |
# When loading the LoRA weights
|
31 |
lora_weights = torch.load(lora_model_path, map_location=device, weights_only=True)
|
|
|
|
|
|
|
32 |
|
33 |
-
#
|
34 |
-
print(dir(pipe)) # This will list all attributes and methods of the `pipe` object
|
35 |
-
|
36 |
-
# Apply weights to the UNet submodule
|
37 |
try:
|
38 |
-
for name, param in pipe.
|
39 |
if name in lora_weights:
|
40 |
param.data += lora_weights[name]
|
41 |
except AttributeError:
|
42 |
-
print("The model doesn't have '
|
43 |
# Add alternative handling or exit
|
44 |
|
45 |
return pipe
|
@@ -47,7 +45,7 @@ def load_lora_model(pipe, lora_model_path):
|
|
47 |
# Load and apply the LoRA model weights
|
48 |
pipe = load_lora_model(pipe, lora_model_path)
|
49 |
|
50 |
-
# Use the @
|
51 |
@spaces.gpu
|
52 |
def generate(prompt, seed=None):
|
53 |
generator = torch.manual_seed(seed) if seed is not None else None
|
|
|
17 |
# Instantiate the Stable Diffusion 3 pipeline from the configured model id.
pipe = StableDiffusion3Pipeline.from_pretrained(model_id)

# Choose the compute device up front: prefer CUDA when a GPU is present,
# otherwise fall back to CPU, and move the whole pipeline there.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe.to(device)

# Location of the locally saved LoRA weight file.
lora_model_path = "./lora_model.pth"  # Assuming the file is saved locally
|
25 |
|
26 |
# Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
def load_lora_model(pipe, lora_model_path):
    """Load LoRA deltas from ``lora_model_path`` and add them in-place to the
    matching parameters of ``pipe.transformer``.

    Parameters
    ----------
    pipe : object
        Pipeline exposing a ``transformer`` sub-module (SD3 uses
        ``transformer`` where older pipelines had ``unet``).
    lora_model_path : str
        Path to a ``torch.save``-ed dict of delta tensors keyed by
        parameter name.

    Returns
    -------
    The same ``pipe`` object, with matching parameters updated in place.
    """
    # Compute the target device locally instead of relying on a module-level
    # global, so the function also works when imported elsewhere.
    # weights_only=True avoids unpickling arbitrary objects from the file.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    lora_weights = torch.load(lora_model_path, map_location=device, weights_only=True)

    try:
        named_params = list(pipe.transformer.named_parameters())
    except AttributeError:
        # Pipeline layout differs (no `transformer` sub-module): report and
        # return the pipeline unchanged rather than crashing.
        print("The model doesn't have 'transformer' attributes. Please check the model structure.")
        return pipe

    # Apply the deltas without autograd tracking, and coerce each delta onto
    # the parameter's own device/dtype so a CPU checkpoint works on a GPU
    # (or half-precision) model.
    with torch.no_grad():
        for name, param in named_params:
            if name in lora_weights:
                param.add_(lora_weights[name].to(device=param.device, dtype=param.dtype))

    return pipe
|
|
|
45 |
# Load and apply the LoRA model weights (updates the pipeline's parameters
# in place and returns the same object).
pipe = load_lora_model(pipe, lora_model_path)
|
47 |
|
48 |
+
# Use the @spaces.gpu decorator to ensure compatibility with GPU or CPU as needed
|
49 |
@spaces.gpu
|
50 |
def generate(prompt, seed=None):
|
51 |
generator = torch.manual_seed(seed) if seed is not None else None
|