date3k2 committed · Commit e6b0814 · verified · 1 Parent(s): 122bf22

Update hf_mamba_classification.py

Files changed (1): hf_mamba_classification.py (+24, -70)
hf_mamba_classification.py CHANGED
@@ -1,14 +1,13 @@
 import torch
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss
 from transformers.models.mamba.modeling_mamba import (
-    MambaPreTrainedModel,
+    MambaPreTrainedModel,
     MambaModel,
-    MambaCache,
+    MambaCache,
     MAMBA_INPUTS_DOCSTRING,
     MAMBA_START_DOCSTRING,
 )
-from transformers.modeling_outputs import SequenceClassifierOutputWithPast
 from typing import List, Optional, Tuple, Union
 from transformers.utils import (
     ModelOutput,
@@ -45,33 +44,21 @@ class MambaSequenceClassifierOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
-    # cache_params: Optional[MambaCache] = None,
     cache_params: Optional[List[torch.FloatTensor]] = None
-    # cache_params: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-
-
+
+
 class MambaClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
 
     def __init__(self, config):
         super().__init__()
-        # self.activation = ACT2FN[config.hidden_act]
-        # self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.out_proj = nn.Linear(config.hidden_size, config.num_labels, bias=False)
-
-        # module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
         self.out_proj.weight.data.normal_(mean=0.0, std=config.initializer_range)
 
         self.config = config
 
     def forward(self, features, **kwargs):
-        # x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
-        # x = self.dropout(x)
-        # x = self.dense(x)
-        # x = self.activation(x)
-        # x = self.dropout(x)
         x = features
         x = self.out_proj(x)
         return x
@@ -86,19 +73,15 @@ class MambaForSequenceClassification(MambaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        # self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
         self.backbone = MambaModel(config)
-        # self.classifier = MambaClassificationHead(config)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=False)
-        # self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
-
-        for param in self.base_model.parameters():
-            param.requires_grad = False
+        self.classifier = MambaClassificationHead(config)
 
         # Initialize weights and apply final processing
        self.post_init()
 
-    @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_start_docstrings_to_model_forward(
+        MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+    )
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=MambaSequenceClassifierOutput,
@@ -122,19 +105,9 @@ class MambaForSequenceClassification(MambaPreTrainedModel):
             If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        # use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # if inputs_embeds is None:
-        #     inputs_embeds = self.backbone.embeddings(input_ids)
-
-        # if self.backbone.gradient_checkpointing and self.training and use_cache:
-        #     use_cache = False
-
-        # if cache_params is None and use_cache:
-        #     cache_params = MambaCache(
-        #         self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
-        #     )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         mamba_outputs = self.backbone(
             input_ids,
@@ -154,13 +127,15 @@ class MambaForSequenceClassification(MambaPreTrainedModel):
         assert (
             self.config.pad_token_id is not None or batch_size == 1
         ), "Cannot handle batch sizes > 1 if no padding token is defined."
-
+
         if self.config.pad_token_id is None:
             sequence_lengths = -1
         else:
             if input_ids is not None:
                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = (
+                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                )
                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
                 sequence_lengths = sequence_lengths.to(logits.device)
             else:
@@ -170,34 +145,13 @@ class MambaForSequenceClassification(MambaPreTrainedModel):
                     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                 )
 
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-
-        loss = None
-        if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
-
-        # if use_cache:
-        #     cache_params.seqlen_offset += inputs_embeds.shape[1]
-
+        pooled_logits = logits[
+            torch.arange(batch_size, device=logits.device), sequence_lengths
+        ]
+
+        loss_fct = CrossEntropyLoss()
+        loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+
         if not return_dict:
             output = (pooled_logits,) + mamba_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -207,4 +161,4 @@ class MambaForSequenceClassification(MambaPreTrainedModel):
             logits=pooled_logits,
             cache_params=mamba_outputs.cache_params,
             hidden_states=mamba_outputs.hidden_states,
-        )
+        )
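The sequence-length logic kept in the diff above pools the logits at the last non-padding position of each sequence, using argmax over the padding mask plus a modulo instead of reverse indexing so the graph stays ONNX-exportable. A minimal, self-contained sketch of that step follows; the batch contents and the pad_token_id of 0 are made up for illustration, not taken from the commit:

import torch

# Hypothetical batch, assuming pad_token_id = 0 (illustrative only).
pad_token_id = 0
input_ids = torch.tensor([
    [5, 7, 9, 0, 0],  # three real tokens -> last real token at index 2
    [4, 4, 4, 4, 4],  # no padding        -> last token at index 4
])

# Index of the first pad token, minus one, is the last real token.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
# For rows with no padding, argmax returns 0 and the subtraction gives -1;
# the modulo turns that into an explicit positive index (here 4), which is
# the ONNX-friendly form used in the classifier's forward pass.
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])

# Pool per-token logits down to one logit vector per sequence.
batch_size, seq_len, num_labels = 2, 5, 3
logits = torch.randn(batch_size, seq_len, num_labels)
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])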
 
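After this change the head is the MambaClassificationHead module and the loss is computed unconditionally with CrossEntropyLoss, so labels must be supplied on every forward call. A minimal usage sketch, assuming this file is importable as hf_mamba_classification and that the state-spaces/mamba-130m-hf checkpoint and its tokenizer are available; the checkpoint name and the two-label setup are illustrative, not part of the commit:

import torch
from transformers import AutoTokenizer

from hf_mamba_classification import MambaForSequenceClassification  # assumed module name

checkpoint = "state-spaces/mamba-130m-hf"  # assumed base checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# num_labels sizes the freshly initialized classification head; the backbone
# weights are loaded from the pretrained Mamba checkpoint.
model = MambaForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

# The updated forward() always computes CrossEntropyLoss, so labels are required.
inputs = tokenizer("A short example sentence.", return_tensors="pt")
labels = torch.tensor([1])

outputs = model(input_ids=inputs["input_ids"], labels=labels)
print(outputs.loss, outputs.logits.shape)  # scalar loss, torch.Size([1, 2])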