Upload sCT
- config.json +1 -1
- config.py +2 -2
- sct.py +0 -7
config.json
CHANGED
@@ -18,7 +18,7 @@
   "layer_norm_eps": 1e-05,
   "mask_token_id": 5,
   "max_positions": 20480,
-  "model_type": "
+  "model_type": "sCT",
   "num_cells": 50,
   "num_downsamples": 8,
   "num_hidden_layers_head": 1,
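The "model_type" field is what the transformers auto classes use to work out which configuration and model code a checkpoint belongs to, so leaving it empty would break AutoConfig/AutoModel resolution. A minimal loading sketch, assuming a hypothetical repo id and that the checkpoint ships the custom sCT code referenced by auto_map entries in config.json:

from transformers import AutoConfig, AutoModel

# Sketch only: "my-org/sCT-checkpoint" is a hypothetical repo id.
# trust_remote_code=True lets transformers import the custom sCT classes
# shipped alongside config.json (via its auto_map entries).
config = AutoConfig.from_pretrained("my-org/sCT-checkpoint", trust_remote_code=True)
model = AutoModel.from_pretrained("my-org/sCT-checkpoint", trust_remote_code=True)
print(config.model_type)  # -> "sCT"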
config.py
CHANGED
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass
 from typing import Tuple
 
 from transformers import PretrainedConfig
@@ -6,7 +6,7 @@ from transformers import PretrainedConfig
 
 @dataclass
 class sCTConfig(PretrainedConfig): # noqa: N801
-    model_type = "
+    model_type = "sCT"
 
     def __init__(self, **kwargs): # type: ignore
         super().__init__()
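The class-level model_type = "sCT" has to match the string written to config.json; it is also the key used if the custom classes are registered with the auto classes locally instead of via remote code. A hedged sketch of that registration, assuming config.py and sct.py are importable from the working directory and that sct.py defines the sCT model class changed below:

from transformers import AutoConfig, AutoModel

from config import sCTConfig  # assumes config.py is on the Python path
from sct import sCT           # assumes sct.py is on the Python path

# Map the "sCT" model_type string to the custom config class, then the config
# class to the model class, so AutoConfig / AutoModel can resolve the checkpoint.
AutoConfig.register("sCT", sCTConfig)
AutoModel.register(sCTConfig, sCT)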
sct.py
CHANGED
@@ -672,9 +672,7 @@ class sCT(PreTrainedModel): # noqa: N801
         for _idx, conv_block in enumerate(self.conv_tower):
             x, res = conv_block(x)
             residuals.append(res)
-        outs["residuals"] = residuals
         residuals = residuals[::-1]
-        conv_block_out = x
         x = x.permute(0, 2, 1)
 
         for layer_idx, transformer in enumerate(self.transformer_layers):
@@ -686,16 +684,11 @@ class sCT(PreTrainedModel): # noqa: N801
             for map_number in self._attention_maps_per_layer_to_save[layer_idx + 1]:
                 dkey = f"attention_map_layer_{layer_idx + 1}_number_{map_number}"
                 outs[dkey] = output["attention_weights"][:, map_number + 1]
-        transformer_output = x
         x = x.permute(0, 2, 1)
         for deconv_block, res in zip(self.deconv_tower, residuals):
             x = deconv_block(x, res)
-        deconv_block_out = x
         x = x.permute(0, 2, 1)
         logits = self.lm_head(x)
         outs["logits"] = logits
-        outs["transformer_output"] = transformer_output
-        outs["conv_out"] = conv_block_out
-        outs["deconv_out"] = deconv_block_out
 
         return outs
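After this cleanup the forward pass returns only the logits plus any attention maps that were explicitly requested; the intermediate tensors (residuals, conv/deconv outputs, transformer output) are no longer kept in the output dict. A usage sketch, assuming an already-loaded sCT instance and that the forward call takes a tensor of token ids as its first argument (both names hypothetical):

import torch

# Sketch only: "model" is an instantiated sCT model and "input_ids" a
# LongTensor of token ids shaped (batch, sequence_length).
with torch.no_grad():
    outs = model(input_ids)

logits = outs["logits"]  # always present after this change
# Attention maps are only present for the layers/heads configured to be saved,
# e.g. layer 1, map 0:
attn_map = outs.get("attention_map_layer_1_number_0")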