date3k2 committed
Commit 122bf22 · verified · Parent: 064f39e

Upload hf_mamba_classification.py

Files changed (1)
  1. hf_mamba_classification.py +210 -0
hf_mamba_classification.py ADDED
@@ -0,0 +1,210 @@
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.models.mamba.modeling_mamba import (
+    MambaPreTrainedModel,
+    MambaModel,
+    MambaCache,
+    MAMBA_INPUTS_DOCSTRING,
+    MAMBA_START_DOCSTRING,
+)
+from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+from typing import List, Optional, Tuple, Union
+from transformers.utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    add_code_sample_docstrings,
+)
+from dataclasses import dataclass
+
+
+_CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m-hf"
+_CONFIG_FOR_DOC = "MambaConfig"
+
+
+@dataclass
+class MambaSequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of sentence classification models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        cache_params (list of `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
+            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+            avoid providing the old `input_ids`.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    # cache_params: Optional[MambaCache] = None,
+    cache_params: Optional[List[torch.FloatTensor]] = None
+    # cache_params: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+class MambaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        # self.activation = ACT2FN[config.hidden_act]
+        # self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        self.out_proj.weight.data.normal_(mean=0.0, std=config.initializer_range)
+
+        self.config = config
+
+    def forward(self, features, **kwargs):
+        # x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        # x = self.dropout(x)
+        # x = self.dense(x)
+        # x = self.activation(x)
+        # x = self.dropout(x)
+        x = features
+        x = self.out_proj(x)
+        return x
+
+
+@add_start_docstrings(
+    """Mamba Model backbone with a sequence classification/regression head on top (a linear layer on top of
+    the pooled output) e.g. for GLUE tasks.""",
+    MAMBA_START_DOCSTRING,
+)
+class MambaForSequenceClassification(MambaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        # self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.backbone = MambaModel(config)
+        # self.classifier = MambaClassificationHead(config)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+        # self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
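+        # Freeze the Mamba backbone so only the newly added classification head receives gradient updates.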
+        for param in self.base_model.parameters():
+            param.requires_grad = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=MambaSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_params: Optional[MambaCache] = None,
+        use_cache: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, MambaSequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in `[0, ..., config.num_labels - 1]`.
+            If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+            if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        # use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # if inputs_embeds is None:
+        #     inputs_embeds = self.backbone.embeddings(input_ids)
+
+        # if self.backbone.gradient_checkpointing and self.training and use_cache:
+        #     use_cache = False
+
+        # if cache_params is None and use_cache:
+        #     cache_params = MambaCache(
+        #         self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
+        #     )
+
+        mamba_outputs = self.backbone(
+            input_ids,
+            cache_params=cache_params,
+            use_cache=use_cache,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = mamba_outputs[0]
+        logits = self.classifier(hidden_states)
+
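+        # Pool the per-token logits at the last non-padding token of each sequence
+        # (or simply the last token when no padding token is defined).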
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+        assert (
+            self.config.pad_token_id is not None or batch_size == 1
+        ), "Cannot handle batch sizes > 1 if no padding token is defined."
+
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
+            else:
+                sequence_lengths = -1
+                print(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds`."
+                )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
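+        # Choose the loss from `config.problem_type`, inferring it from `num_labels` and the
+        # label dtype when unset (regression / single-label / multi-label classification).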
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+
+        # if use_cache:
+        #     cache_params.seqlen_offset += inputs_embeds.shape[1]
+
+        if not return_dict:
+            output = (pooled_logits,) + mamba_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MambaSequenceClassifierOutput(
+            loss=loss,
+            logits=pooled_logits,
+            cache_params=mamba_outputs.cache_params,
+            hidden_states=mamba_outputs.hidden_states,
+        )
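
For reference, a minimal usage sketch (not part of the committed file). It assumes the uploaded module is importable as `hf_mamba_classification`, uses the `state-spaces/mamba-130m-hf` checkpoint referenced in the file, and picks `num_labels=2` purely for illustration; the classification head is newly initialized on load.

```python
import torch
from transformers import AutoTokenizer

# Assumption: hf_mamba_classification.py is on the Python path.
from hf_mamba_classification import MambaForSequenceClassification

checkpoint = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# num_labels=2 is illustrative; the classifier weights are randomly initialized.
model = MambaForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

# Batch size 1 avoids the need for a padding token; the last token's logits are pooled.
inputs = tokenizer("A single example sentence.", return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"], labels=torch.tensor([1]))
print(outputs.loss, outputs.logits.shape)  # cross-entropy loss, logits of shape (1, 2)
```

Because the backbone parameters are frozen in `__init__`, training this model updates only the linear classifier unless `requires_grad` is re-enabled on the backbone.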