Update modeling_quiet.py
modeling_quiet.py (+8 -8)
@@ -55,7 +55,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-from
+from .configuration_quiet import QuietConfig
 
 
 if is_flash_attn_2_available():
@@ -67,7 +67,7 @@ if is_flash_attn_2_available():
 
 logger = logging.get_logger(__name__)
 
-_CONFIG_FOR_DOC = "
+_CONFIG_FOR_DOC = "QuietConfig"
 
 from reportlab.pdfgen import canvas
 from reportlab.lib.pagesizes import letter
@@ -270,7 +270,7 @@ class QuietAttention(nn.Module):
     and "Generating Long Sequences with Sparse Transformers".
     """
 
-    def __init__(self, config:
+    def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -818,7 +818,7 @@ QUIET_ATTENTION_CLASSES = {
 
 
 class QuietDecoderLayer(nn.Module):
-    def __init__(self, config:
+    def __init__(self, config: QuietConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
@@ -896,7 +896,7 @@ QUIET_START_DOCSTRING = r"""
     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
     and behavior.
     Parameters:
-        config ([`
+        config ([`QuietConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -908,7 +908,7 @@ QUIET_START_DOCSTRING = r"""
     QUIET_START_DOCSTRING,
 )
 class QuietPreTrainedModel(PreTrainedModel):
-    config_class =
+    config_class = QuietConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["QuietDecoderLayer"]
@@ -995,10 +995,10 @@ class QuietModel(QuietPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`QuietDecoderLayer`]
     Args:
-        config:
+        config: QuietConfig
    """
 
-    def __init__(self, config:
+    def __init__(self, config: QuietConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
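The new relative import expects a sibling configuration_quiet.py in the same package; that file is not part of this commit, so the sketch below is only a plausible reconstruction of it, not the actual code. Beyond the class name, every default is an assumption; the attributes mirror the ones this diff's context lines actually read (hidden_size, num_hidden_layers, vocab_size, pad_token_id), following the standard transformers custom-configuration pattern.

# Hypothetical sketch of configuration_quiet.py; not part of this commit.
# All defaults below are assumptions for illustration only.
from transformers import PretrainedConfig


class QuietConfig(PretrainedConfig):
    model_type = "quiet"  # assumed registry key

    def __init__(
        self,
        vocab_size=32000,        # assumed default
        hidden_size=4096,        # assumed default
        num_hidden_layers=32,    # assumed default
        pad_token_id=None,
        **kwargs,
    ):
        # Attributes read by modeling_quiet.py per the diff's context lines
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        super().__init__(pad_token_id=pad_token_id, **kwargs)

Setting config_class = QuietConfig on QuietPreTrainedModel is what lets from_pretrained instantiate and validate the correct configuration for every model subclass, and _CONFIG_FOR_DOC = "QuietConfig" supplies the class name that utilities such as the replace_return_docstrings import at the top of the file splice into the generated docstrings.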