shu65 committed
Commit: d406556
Parent(s): 65176a6

fix plamo model

Files changed (1):
  1. modeling_plamo.py +17 -27
modeling_plamo.py CHANGED
@@ -240,6 +240,8 @@ class PlamoCache(torch.nn.Module):
 
     def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
         c = self.cache[layer_idx]
+        if c is None:
+            return key, value
         assert isinstance(c, PlamoAttentionCache)
 
         def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
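Commentary: with this hunk, `append_kv` passes the incoming key/value tensors through unchanged when the layer has no cache entry yet, instead of relying on the caller to pre-seed an empty cache. A minimal self-contained sketch of the same logic with a standalone helper and made-up shapes (`append_kv_sketch` is hypothetical, not the class method):

    import torch

    def append_kv_sketch(cache_entry, key, value):
        # Mirrors the new early return: nothing cached yet -> pass through.
        if cache_entry is None:
            return key, value
        # Otherwise extend the cached tensors along the sequence axis (dim=2).
        return torch.cat([cache_entry[0], key], dim=2), torch.cat([cache_entry[1], value], dim=2)

    k = torch.randn(1, 4, 3, 8)  # (batch, kv_heads, seq_len, head_dim)
    v = torch.randn(1, 4, 3, 8)
    nk, nv = append_kv_sketch(None, k, v)
    assert nk is k and nv is v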
@@ -257,11 +259,17 @@
     def update_attention(
         self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
     ) -> PlamoAttentionCache:
+        full_attn = layer_idx in self.config.full_attention_idx
+        window_size = self.config.attention_window_size
+
         if self.cache[layer_idx] is None:
-            self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
+            if full_attn:
+                self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
+            else:
+                self.cache[layer_idx] = PlamoAttentionCache(
+                    key_states[:, :, -window_size:, :], value_states[:, :, -window_size:, :]
+                )
         else:
-            full_attn = layer_idx in self.config.full_attention_idx
-            window_size = self.config.attention_window_size
             c = self.cache[layer_idx]
             assert isinstance(c, PlamoAttentionCache)
             k, v = self.append_kv(key_states, value_states, layer_idx)
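Commentary: hoisting `full_attn` and `window_size` above the branch lets the first call honor sliding-window layers: layers not listed in `config.full_attention_idx` now cache only the last `attention_window_size` positions of the prefill instead of the whole prompt. A small sketch of that slicing with made-up sizes:

    import torch

    window_size = 4                        # stands in for config.attention_window_size
    key_states = torch.randn(1, 4, 10, 8)  # (batch, kv_heads, seq_len=10, head_dim)

    # Sliding-window layers keep only the trailing window of the prefill.
    windowed = key_states[:, :, -window_size:, :]
    assert windowed.shape[2] == min(window_size, key_states.shape[2])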
@@ -968,15 +976,6 @@ class Attention(torch.nn.Module):
         query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
         key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
 
-        if past_states is not None and past_states[self.layer_idx] is None:
-            bsz, nhead_k, _, c_k = key_states.shape
-            _, nhead_v, _, c_v = value_states.shape
-            past_states.update_attention(
-                torch.zeros((bsz, nhead_k, 0, c_k), dtype=key_states.dtype, device=key_states.device),
-                torch.zeros((bsz, nhead_v, 0, c_v), dtype=value_states.dtype, device=value_states.device),
-                self.layer_idx,
-            )
-
         if past_states is not None:
             # reuse k, v, self_attention
             key_states_new = key_states
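Commentary: the deleted block seeded each layer's cache with zero-length tensors so later code could always concatenate onto an existing entry. With the new early return in `append_kv` and the window-aware initialization in `update_attention`, that placeholder is redundant: concatenating onto an empty tensor yields the new tensor's contents unchanged, as this illustrative check shows.

    import torch

    key_states = torch.randn(2, 4, 5, 16)                     # (batch, kv_heads, seq_len, head_dim)
    empty = torch.zeros(2, 4, 0, 16, dtype=key_states.dtype)  # the old zero-length placeholder

    # Concatenating onto the empty placeholder is a no-op, which is exactly
    # what the new early return achieves without the extra allocation.
    assert torch.equal(torch.cat([empty, key_states], dim=2), key_states)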
@@ -1154,6 +1153,7 @@ class PlamoDecoder(torch.nn.Module):
                 for i in range(config.num_hidden_layers)
             ]
         )
+        self.gradient_checkpointing = False
 
     def forward(self, x: DecoderInput) -> DecoderOutput:
         all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
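Commentary: recent transformers releases expect modules that support gradient checkpointing to expose a `gradient_checkpointing` attribute, which `PreTrainedModel.gradient_checkpointing_enable()` toggles at runtime; initializing it to `False` here replaces the `_set_gradient_checkpointing` override removed further down. A hedged usage sketch, with the model id made up for illustration:

    from transformers import AutoModelForCausalLM

    # Hypothetical checkpoint name; substitute the actual repo id.
    model = AutoModelForCausalLM.from_pretrained("pfnet/plamo-example", trust_remote_code=True)
    model.gradient_checkpointing_enable()  # expected to flip decoder.gradient_checkpointing to True
    model.train()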
@@ -1166,19 +1166,12 @@
                 all_hidden_states += (hidden_states,)
 
             if self.training and x.gradient_checkpointing:
-
-                def create_custom_forward(module):  # type: ignore
-                    def custom_forward(*inputs):  # type: ignore
-                        # None for past_key_value
-                        return module(*inputs, x.output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),  # type: ignore
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     x.attention_mask,
-                    None,
+                    x.past_states,
+                    x.output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
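Commentary: this swaps the hand-rolled closure around `torch.utils.checkpoint.checkpoint` for the `_gradient_checkpointing_func` that newer transformers versions attach to the module when checkpointing is enabled, and it forwards `x.past_states` and `x.output_attentions` instead of hard-coding `None`. A self-contained sketch of roughly what that helper amounts to (a partial over `torch.utils.checkpoint.checkpoint` with a stand-in layer; not the library's exact code):

    from functools import partial

    import torch
    import torch.utils.checkpoint

    # Roughly what transformers installs as _gradient_checkpointing_func.
    gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

    layer = torch.nn.Linear(8, 8)                          # stand-in for a decoder layer
    hidden_states = torch.randn(2, 8, requires_grad=True)

    out = gradient_checkpointing_func(layer.__call__, hidden_states)
    out.sum().backward()                                   # activations are recomputed during backward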
@@ -1217,9 +1210,6 @@ class PlamoPreTrainedModel(PreTrainedModel):  # type: ignore
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module: torch.nn.Module, value: bool = False) -> None:
-        module.gradient_checkpointing = value  # type: ignore
-
 
 class PlamoModel(PlamoPreTrainedModel):
     def __init__(self, config: PlamoConfig):
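Commentary: the per-model `_set_gradient_checkpointing(module, value)` hook was retired in newer transformers; the base `PreTrainedModel` now walks submodules itself and toggles any `gradient_checkpointing` attribute it finds, which is why the decoder gains that flag above and this override can be deleted. A hedged sketch of that pattern, not the library's actual implementation:

    import torch

    def enable_gradient_checkpointing(model: torch.nn.Module) -> None:
        # Flip the flag on every submodule that declares checkpointing support.
        for module in model.modules():
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = True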
@@ -1613,4 +1603,4 @@ class Bias(nn.Module):
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        return x + self._bias
+        return x + self._bias
 