lhallee commited on
Commit
f5b811e
·
verified ·
1 Parent(s): dbd5e23

Update modeling_esm_plusplus.py

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +4 -2
modeling_esm_plusplus.py CHANGED
@@ -49,6 +49,7 @@ class ESMplusplusConfig(PretrainedConfig):
49
  num_labels: int = 2,
50
  problem_type: str | None = None,
51
  dropout: float = 0.0,
 
52
  **kwargs,
53
  ):
54
  super().__init__(**kwargs)
@@ -59,6 +60,7 @@ class ESMplusplusConfig(PretrainedConfig):
59
  self.num_labels = num_labels
60
  self.problem_type = problem_type
61
  self.dropout = dropout
 
62
 
63
 
64
  ### Rotary Embeddings
@@ -792,7 +794,7 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
792
  )
793
 
794
 
795
- class ESMplusplusForSequenceClassification(PreTrainedESMplusplusModel):
796
  """
797
  ESM++ model for sequence classification.
798
  Extends the base ESM++ model with a classification head.
@@ -873,7 +875,7 @@ class ESMplusplusForSequenceClassification(PreTrainedESMplusplusModel):
873
  )
874
 
875
 
876
- class ESMplusplusForTokenClassification(PreTrainedESMplusplusModel):
877
  """
878
  ESM++ model for token classification.
879
  Extends the base ESM++ model with a token classification head.
 
49
  num_labels: int = 2,
50
  problem_type: str | None = None,
51
  dropout: float = 0.0,
52
+ initializer_range: float = 0.02,
53
  **kwargs,
54
  ):
55
  super().__init__(**kwargs)
 
60
  self.num_labels = num_labels
61
  self.problem_type = problem_type
62
  self.dropout = dropout
63
+ self.initializer_range = initializer_range
64
 
65
 
66
  ### Rotary Embeddings
 
794
  )
795
 
796
 
797
+ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
798
  """
799
  ESM++ model for sequence classification.
800
  Extends the base ESM++ model with a classification head.
 
875
  )
876
 
877
 
878
+ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
879
  """
880
  ESM++ model for token classification.
881
  Extends the base ESM++ model with a token classification head.