Spaces:
Running
Running
Antoni Bigata
commited on
Commit
·
4fd1a69
1
Parent(s):
fc0dc6f
requirements
Browse files- WavLM.py +1 -1
- sgm/models/diffusion.py +5 -5
WavLM.py
CHANGED
|
@@ -48,7 +48,7 @@ class WavLM_wrapper(nn.Module):
|
|
| 48 |
)
|
| 49 |
if not os.path.exists(model_path):
|
| 50 |
self.download_model(model_path, model_size)
|
| 51 |
-
checkpoint = torch.load(model_path)
|
| 52 |
cfg = WavLMConfig(checkpoint["cfg"])
|
| 53 |
self.cfg = cfg
|
| 54 |
self.model = WavLM(cfg)
|
|
|
|
| 48 |
)
|
| 49 |
if not os.path.exists(model_path):
|
| 50 |
self.download_model(model_path, model_size)
|
| 51 |
+
checkpoint = torch.load(model_path, weights_only=False)
|
| 52 |
cfg = WavLMConfig(checkpoint["cfg"])
|
| 53 |
self.cfg = cfg
|
| 54 |
self.model = WavLM(cfg)
|
sgm/models/diffusion.py
CHANGED
|
@@ -119,7 +119,7 @@ class DiffusionEngine(pl.LightningModule):
|
|
| 119 |
pattern_to_remove=pattern_to_remove,
|
| 120 |
)
|
| 121 |
if separate_unet_ckpt is not None:
|
| 122 |
-
sd = torch.load(separate_unet_ckpt)["state_dict"]
|
| 123 |
if remove_keys_from_unet_weights is not None:
|
| 124 |
for k in list(sd.keys()):
|
| 125 |
for remove_key in remove_keys_from_unet_weights:
|
|
@@ -190,7 +190,7 @@ class DiffusionEngine(pl.LightningModule):
|
|
| 190 |
|
| 191 |
def load_bad_model_weights(self, path: str) -> None:
|
| 192 |
print(f"Restoring bad model from {path}")
|
| 193 |
-
state_dict = torch.load(path, map_location="cpu")
|
| 194 |
new_dict = {}
|
| 195 |
for k, v in state_dict["module"].items():
|
| 196 |
if "learned_mask" in k:
|
|
@@ -221,13 +221,13 @@ class DiffusionEngine(pl.LightningModule):
|
|
| 221 |
) -> None:
|
| 222 |
print(f"Restoring from {path}")
|
| 223 |
if path.endswith("ckpt"):
|
| 224 |
-
sd = torch.load(path, map_location="cpu")["state_dict"]
|
| 225 |
elif path.endswith("pt"):
|
| 226 |
-
sd = torch.load(path, map_location="cpu")["module"]
|
| 227 |
# Remove leading _forward_module from keys
|
| 228 |
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
|
| 229 |
elif path.endswith("bin"):
|
| 230 |
-
sd = torch.load(path, map_location="cpu")
|
| 231 |
# Remove leading _forward_module from keys
|
| 232 |
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
|
| 233 |
elif path.endswith("safetensors"):
|
|
|
|
| 119 |
pattern_to_remove=pattern_to_remove,
|
| 120 |
)
|
| 121 |
if separate_unet_ckpt is not None:
|
| 122 |
+
sd = torch.load(separate_unet_ckpt, weights_only=False)["state_dict"]
|
| 123 |
if remove_keys_from_unet_weights is not None:
|
| 124 |
for k in list(sd.keys()):
|
| 125 |
for remove_key in remove_keys_from_unet_weights:
|
|
|
|
| 190 |
|
| 191 |
def load_bad_model_weights(self, path: str) -> None:
|
| 192 |
print(f"Restoring bad model from {path}")
|
| 193 |
+
state_dict = torch.load(path, map_location="cpu", weights_only=False)
|
| 194 |
new_dict = {}
|
| 195 |
for k, v in state_dict["module"].items():
|
| 196 |
if "learned_mask" in k:
|
|
|
|
| 221 |
) -> None:
|
| 222 |
print(f"Restoring from {path}")
|
| 223 |
if path.endswith("ckpt"):
|
| 224 |
+
sd = torch.load(path, map_location="cpu", weights_only=False)["state_dict"]
|
| 225 |
elif path.endswith("pt"):
|
| 226 |
+
sd = torch.load(path, map_location="cpu", weights_only=False)["module"]
|
| 227 |
# Remove leading _forward_module from keys
|
| 228 |
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
|
| 229 |
elif path.endswith("bin"):
|
| 230 |
+
sd = torch.load(path, map_location="cpu", weights_only=False)
|
| 231 |
# Remove leading _forward_module from keys
|
| 232 |
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
|
| 233 |
elif path.endswith("safetensors"):
|