Upload Flowformer
Browse files- model_flowformer.py +2 -2
model_flowformer.py
CHANGED
|
@@ -96,7 +96,7 @@ class Flowformer(PreTrainedModel):
|
|
| 96 |
self.dec = nn.Sequential(*dec_layers)
|
| 97 |
|
| 98 |
def markers(self):
|
| 99 |
-
return self.
|
| 100 |
|
| 101 |
def forward(self, tensor: torch.Tensor, labels: torch.Tensor=None, markers: list=None):
|
| 102 |
r"""
|
|
@@ -110,7 +110,7 @@ class Flowformer(PreTrainedModel):
|
|
| 110 |
"""
|
| 111 |
B, L, M = tensor.shape
|
| 112 |
if markers is not None:
|
| 113 |
-
assert len(markers) == M, "
|
| 114 |
|
| 115 |
zeros = torch.zeros((B, L, len(self.markers())), device=tensor.device)
|
| 116 |
valid_markers = [m for m in markers if m in set(self.markers()).intersection(markers)]
|
|
|
|
| 96 |
self.dec = nn.Sequential(*dec_layers)
|
| 97 |
|
| 98 |
def markers(self):
|
| 99 |
+
return self._markers
|
| 100 |
|
| 101 |
def forward(self, tensor: torch.Tensor, labels: torch.Tensor=None, markers: list=None):
|
| 102 |
r"""
|
|
|
|
| 110 |
"""
|
| 111 |
B, L, M = tensor.shape
|
| 112 |
if markers is not None:
|
| 113 |
+
assert len(markers) == M, "last dimension of input must be equal to number of markers"
|
| 114 |
|
| 115 |
zeros = torch.zeros((B, L, len(self.markers())), device=tensor.device)
|
| 116 |
valid_markers = [m for m in markers if m in set(self.markers()).intersection(markers)]
|