Crystalcareai committed on
Commit
3e8d756
·
verified ·
1 Parent(s): d1ca91a

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +8 -8
modeling_quiet.py CHANGED
@@ -55,7 +55,7 @@ from transformers.utils import (
55
  logging,
56
  replace_return_docstrings,
57
  )
58
- from transformers import AutoConfig
59
 
60
 
61
  if is_flash_attn_2_available():
@@ -67,7 +67,7 @@ if is_flash_attn_2_available():
67
 
68
  logger = logging.get_logger(__name__)
69
 
70
- _CONFIG_FOR_DOC = "AutoConfig"
71
 
72
  from reportlab.pdfgen import canvas
73
  from reportlab.lib.pagesizes import letter
@@ -270,7 +270,7 @@ class QuietAttention(nn.Module):
270
  and "Generating Long Sequences with Sparse Transformers".
271
  """
272
 
273
- def __init__(self, config: AutoConfig, layer_idx: Optional[int] = None):
274
  super().__init__()
275
  self.config = config
276
  self.layer_idx = layer_idx
@@ -818,7 +818,7 @@ QUIET_ATTENTION_CLASSES = {
818
 
819
 
820
  class QuietDecoderLayer(nn.Module):
821
- def __init__(self, config: AutoConfig, layer_idx: int):
822
  super().__init__()
823
  self.hidden_size = config.hidden_size
824
 
@@ -896,7 +896,7 @@ QUIET_START_DOCSTRING = r"""
896
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
897
  and behavior.
898
  Parameters:
899
- config ([`AutoConfig`]):
900
  Model configuration class with all the parameters of the model. Initializing with a config file does not
901
  load the weights associated with the model, only the configuration. Check out the
902
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -908,7 +908,7 @@ QUIET_START_DOCSTRING = r"""
908
  QUIET_START_DOCSTRING,
909
  )
910
  class QuietPreTrainedModel(PreTrainedModel):
911
- config_class = AutoConfig
912
  base_model_prefix = "model"
913
  supports_gradient_checkpointing = True
914
  _no_split_modules = ["QuietDecoderLayer"]
@@ -995,10 +995,10 @@ class QuietModel(QuietPreTrainedModel):
995
  """
996
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`QuietDecoderLayer`]
997
  Args:
998
- config: AutoConfig
999
  """
1000
 
1001
- def __init__(self, config: AutoConfig):
1002
  super().__init__(config)
1003
  self.padding_idx = config.pad_token_id
1004
  self.vocab_size = config.vocab_size
 
55
  logging,
56
  replace_return_docstrings,
57
  )
58
+ from .configuration_quiet import QuietConfig
59
 
60
 
61
  if is_flash_attn_2_available():
 
67
 
68
  logger = logging.get_logger(__name__)
69
 
70
+ _CONFIG_FOR_DOC = "QuietConfig"
71
 
72
  from reportlab.pdfgen import canvas
73
  from reportlab.lib.pagesizes import letter
 
270
  and "Generating Long Sequences with Sparse Transformers".
271
  """
272
 
273
+ def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
274
  super().__init__()
275
  self.config = config
276
  self.layer_idx = layer_idx
 
818
 
819
 
820
  class QuietDecoderLayer(nn.Module):
821
+ def __init__(self, config: QuietConfig, layer_idx: int):
822
  super().__init__()
823
  self.hidden_size = config.hidden_size
824
 
 
896
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
897
  and behavior.
898
  Parameters:
899
+ config ([`QuietConfig`]):
900
  Model configuration class with all the parameters of the model. Initializing with a config file does not
901
  load the weights associated with the model, only the configuration. Check out the
902
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 
908
  QUIET_START_DOCSTRING,
909
  )
910
  class QuietPreTrainedModel(PreTrainedModel):
911
+ config_class = QuietConfig
912
  base_model_prefix = "model"
913
  supports_gradient_checkpointing = True
914
  _no_split_modules = ["QuietDecoderLayer"]
 
995
  """
996
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`QuietDecoderLayer`]
997
  Args:
998
+ config: QuietConfig
999
  """
1000
 
1001
+ def __init__(self, config: QuietConfig):
1002
  super().__init__(config)
1003
  self.padding_idx = config.pad_token_id
1004
  self.vocab_size = config.vocab_size