import torch
import torch.nn as nn

from utmosv2.dataset._utils import get_dataset_num
from utmosv2.model import MultiSpecExtModel, MultiSpecModelV2, SSLExtModel


class SSLMultiSpecExtModelV1(nn.Module):
    """Late-fusion model: concatenates SSL and multi-spectrogram features
    (plus the dataset indicator) and maps them through a single linear head.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.ssl = SSLExtModel(cfg)
        self.spec_long = MultiSpecModelV2(cfg)
        # Load the pretrained checkpoint of each branch for the current fold/seed.
        self.ssl.load_state_dict(
            torch.load(
                f"outputs/{cfg.model.ssl_spec.ssl_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
            )
        )
        self.spec_long.load_state_dict(
            torch.load(
                f"outputs/{cfg.model.ssl_spec.spec_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
            )
        )
        # Optionally freeze both branches so that only the fusion head is trained.
        if cfg.model.ssl_spec.freeze:
            for param in self.ssl.parameters():
                param.requires_grad = False
            for param in self.spec_long.parameters():
                param.requires_grad = False
        # Replace each branch's classification head with an identity so the
        # branches output raw feature vectors, then fuse those features together
        # with the one-hot dataset indicator in a single linear layer.
        ssl_input = self.ssl.fc.in_features
        spec_long_input = self.spec_long.fc.in_features
        self.ssl.fc = nn.Identity()
        self.spec_long.fc = nn.Identity()
        self.num_dataset = get_dataset_num(cfg)
        self.fc = nn.Linear(
            ssl_input + spec_long_input + self.num_dataset,
            cfg.model.ssl_spec.num_classes,
        )

    def forward(self, x1, x2, d):
        # The SSL branch expects a dataset vector; pass zeros here because the
        # real dataset indicator `d` is concatenated at the fusion stage instead.
        x1 = self.ssl(x1, torch.zeros(x1.shape[0], self.num_dataset).to(x1.device))
        x2 = self.spec_long(x2)
        x = torch.cat([x1, x2, d], dim=1)
        x = self.fc(x)
        return x


class SSLMultiSpecExtModelV2(nn.Module):
    """Same fusion scheme as V1, but the spectrogram branch is MultiSpecExtModel
    and the pretrained branch checkpoints are loaded only in the training phase.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.ssl = SSLExtModel(cfg)
        self.spec_long = MultiSpecExtModel(cfg)
        # Load each branch's pretrained checkpoint only when a weight name is
        # configured and the model is being trained.
        if cfg.model.ssl_spec.ssl_weight is not None and cfg.phase == "train":
            self.ssl.load_state_dict(
                torch.load(
                    f"outputs/{cfg.model.ssl_spec.ssl_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
                )
            )
        if cfg.model.ssl_spec.spec_weight is not None and cfg.phase == "train":
            self.spec_long.load_state_dict(
                torch.load(
                    f"outputs/{cfg.model.ssl_spec.spec_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
                )
            )
        # Optionally freeze both branches so that only the fusion head is trained.
        if cfg.model.ssl_spec.freeze:
            for param in self.ssl.parameters():
                param.requires_grad = False
            for param in self.spec_long.parameters():
                param.requires_grad = False
        # Strip each branch's head and fuse the raw features with the dataset indicator.
        ssl_input = self.ssl.fc.in_features
        spec_long_input = self.spec_long.fc.in_features
        self.ssl.fc = nn.Identity()
        self.spec_long.fc = nn.Identity()
        self.num_dataset = get_dataset_num(cfg)
        self.fc = nn.Linear(
            ssl_input + spec_long_input + self.num_dataset,
            cfg.model.ssl_spec.num_classes,
        )

    def forward(self, x1, x2, d):
        # Both branches receive a zero dataset vector; the real dataset
        # indicator `d` is concatenated just before the fusion head.
        x1 = self.ssl(x1, torch.zeros(x1.shape[0], self.num_dataset).to(x1.device))
        x2 = self.spec_long(
            x2, torch.zeros(x1.shape[0], self.num_dataset).to(x1.device)
        )
        x = torch.cat([x1, x2, d], dim=1)
        x = self.fc(x)
        return x
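

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original module): the
# head-stripping + late-fusion pattern used above, demonstrated with tiny
# stand-in branches instead of SSLExtModel / MultiSpecExtModel, since building
# a full UTMOSv2 config is out of scope here. `DummyBranch`, the feature sizes,
# and the batch shapes below are all hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class DummyBranch(nn.Module):
        """Stand-in for a pretrained branch exposing a replaceable `fc` head."""

        def __init__(self, in_dim, feat_dim, num_classes):
            super().__init__()
            self.backbone = nn.Linear(in_dim, feat_dim)
            self.fc = nn.Linear(feat_dim, num_classes)

        def forward(self, x):
            return self.fc(self.backbone(x))

    batch, num_dataset, num_classes = 4, 3, 1
    ssl = DummyBranch(16, 8, num_classes)
    spec = DummyBranch(32, 12, num_classes)

    # Same steps as in the classes above: remember feature sizes, replace the
    # heads with identities, and build one linear layer over the fused vector.
    ssl_feat, spec_feat = ssl.fc.in_features, spec.fc.in_features
    ssl.fc, spec.fc = nn.Identity(), nn.Identity()
    head = nn.Linear(ssl_feat + spec_feat + num_dataset, num_classes)

    x1, x2 = torch.randn(batch, 16), torch.randn(batch, 32)
    d = torch.eye(num_dataset)[torch.zeros(batch, dtype=torch.long)]  # one-hot dataset id
    out = head(torch.cat([ssl(x1), spec(x2), d], dim=1))
    print(out.shape)  # torch.Size([4, 1])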