fix modeling_plamo.py (#1)

- fix plamo model (d406556e08aa20a112b0331fbb8ee95df615ab35)

Co-authored-by: Shuji Suzuki <[email protected]>

modeling_plamo.py  +17 -27

modeling_plamo.py CHANGED
@@ -240,6 +240,8 @@ class PlamoCache(torch.nn.Module):
 
     def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
         c = self.cache[layer_idx]
+        if c is None:
+            return key, value
         assert isinstance(c, PlamoAttentionCache)
 
         def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
@@ -257,11 +259,17 @@ class PlamoCache(torch.nn.Module):
     def update_attention(
         self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
     ) -> PlamoAttentionCache:
+        full_attn = layer_idx in self.config.full_attention_idx
+        window_size = self.config.attention_window_size
+
         if self.cache[layer_idx] is None:
-            self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
+            if full_attn:
+                self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
+            else:
+                self.cache[layer_idx] = PlamoAttentionCache(
+                    key_states[:, :, -window_size:, :], value_states[:, :, -window_size:, :]
+                )
         else:
-            full_attn = layer_idx in self.config.full_attention_idx
-            window_size = self.config.attention_window_size
             c = self.cache[layer_idx]
             assert isinstance(c, PlamoAttentionCache)
             k, v = self.append_kv(key_states, value_states, layer_idx)
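Note: after this change, `update_attention` decides between full attention and sliding-window attention before touching the cache, and a freshly created cache entry for a sliding-window layer keeps only the trailing `attention_window_size` positions. A minimal, self-contained sketch of that truncation (the tensor shapes are illustrative; `window_size` stands in for `config.attention_window_size` from the diff):

```python
import torch

window_size = 4                           # stands in for config.attention_window_size
key_states = torch.randn(1, 2, 10, 8)     # (batch, kv_heads, seq_len, head_dim)
value_states = torch.randn(1, 2, 10, 8)

# Sliding-window layer: the initial cache entry keeps only the last window_size positions.
cached_k = key_states[:, :, -window_size:, :]
cached_v = value_states[:, :, -window_size:, :]
assert cached_k.shape[2] == min(window_size, key_states.shape[2])
```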
@@ -968,15 +976,6 @@ class Attention(torch.nn.Module):
         query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
         key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
 
-        if past_states is not None and past_states[self.layer_idx] is None:
-            bsz, nhead_k, _, c_k = key_states.shape
-            _, nhead_v, _, c_v = value_states.shape
-            past_states.update_attention(
-                torch.zeros((bsz, nhead_k, 0, c_k), dtype=key_states.dtype, device=key_states.device),
-                torch.zeros((bsz, nhead_v, 0, c_v), dtype=value_states.dtype, device=value_states.device),
-                self.layer_idx,
-            )
-
         if past_states is not None:
             # reuse k, v, self_attention
             key_states_new = key_states
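Note: the block removed here primed each layer's cache with zero-length key/value tensors so that the subsequent cache update had something to concatenate with. With the `if c is None: return key, value` early return added to `append_kv`, that priming is unnecessary. A small sketch of why the two paths agree, assuming `append_kv` concatenates along the sequence dimension (dim 2), as the `(bsz, nhead, 0, c)` shapes above suggest:

```python
import torch

key_states = torch.randn(1, 2, 5, 8)  # (batch, kv_heads, seq_len, head_dim)
empty = torch.zeros((1, 2, 0, 8), dtype=key_states.dtype)

old_path = torch.cat([empty, key_states], dim=2)  # concat against a zero-length primed cache
new_path = key_states                             # early return when the cache entry is None
assert torch.equal(old_path, new_path)
```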
@@ -1154,6 +1153,7 @@ class PlamoDecoder(torch.nn.Module):
                 for i in range(config.num_hidden_layers)
             ]
         )
+        self.gradient_checkpointing = False
 
     def forward(self, x: DecoderInput) -> DecoderOutput:
         all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
@@ -1166,19 +1166,12 @@ class PlamoDecoder(torch.nn.Module):
                 all_hidden_states += (hidden_states,)
 
             if self.training and x.gradient_checkpointing:
-
-                def create_custom_forward(module):  # type: ignore
-                    def custom_forward(*inputs):  # type: ignore
-                        # None for past_key_value
-                        return module(*inputs, x.output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),  # type: ignore
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     x.attention_mask,
-
+                    x.past_states,
+                    x.output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
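Note: the hand-rolled `create_custom_forward` wrapper is replaced with `self._gradient_checkpointing_func`, which recent `transformers` releases attach to any submodule that exposes a `gradient_checkpointing` attribute when `gradient_checkpointing_enable()` is called (hence the `self.gradient_checkpointing = False` added to `PlamoDecoder.__init__` above). Roughly, the checkpointed call expands to something like the sketch below; `use_reentrant` is written out explicitly here and is normally chosen via `gradient_checkpointing_kwargs`:

```python
import torch.utils.checkpoint

def run_checkpointed(decoder_layer, hidden_states, attention_mask, past_states, output_attentions):
    # Approximation of self._gradient_checkpointing_func(decoder_layer.__call__, ...)
    return torch.utils.checkpoint.checkpoint(
        decoder_layer.__call__,
        hidden_states,
        attention_mask,
        past_states,
        output_attentions,
        use_reentrant=False,
    )
```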
@@ -1217,9 +1210,6 @@ class PlamoPreTrainedModel(PreTrainedModel): # type: ignore
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module: torch.nn.Module, value: bool = False) -> None:
-        module.gradient_checkpointing = value # type: ignore
-
 
 class PlamoModel(PlamoPreTrainedModel):
     def __init__(self, config: PlamoConfig):
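Note: the per-model `_set_gradient_checkpointing` override is only needed by older `transformers` versions; newer ones discover checkpointable submodules through the `gradient_checkpointing` attribute instead. A hedged usage sketch (the checkpoint path below is a placeholder, not the actual model repository name):

```python
from transformers import AutoModelForCausalLM

# "path/to/plamo-checkpoint" is a placeholder for the real model repository.
model = AutoModelForCausalLM.from_pretrained("path/to/plamo-checkpoint", trust_remote_code=True)
# On recent transformers versions this flips gradient_checkpointing on checkpointable
# submodules and attaches _gradient_checkpointing_func to them.
model.gradient_checkpointing_enable()
model.train()
```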
@@ -1613,4 +1603,4 @@ class Bias(nn.Module):
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        return x + self._bias
+        return x + self._bias
|