Upload AMPLIFY
amplify.py CHANGED (+10 -2)
@@ -295,6 +295,14 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         # Initialize
         hidden_states, attentions = [], []
 
+        # We will output all the hidden_states that have an index higher than output_hidden_index
+        if type(output_hidden_states) == bool and not output_hidden_states:
+            output_hidden_index = self.config.num_hidden_layers + 1
+        elif type(output_hidden_states) == int:
+            output_hidden_index = output_hidden_states
+        else:
+            output_hidden_index = 0
+
         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
         if pad_mask is not None:
             pad_mask = pad_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, pad_mask.size(-1), 1)
@@ -325,9 +333,9 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         x = self.layer_norm_1(x)
 
         # Transformer encoder
-        for layer in self.transformer_encoder:
+        for idx, layer in enumerate(self.transformer_encoder):
             x, attn = layer(x, pad_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
-            if output_hidden_states:
+            if idx >= output_hidden_index:
                 hidden_states.append(x)
             if output_attentions:
                 attentions.append(attn)
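For reference, a minimal sketch of how the new `output_hidden_states` handling behaves, factored out of the model purely for illustration. The helper name `resolve_output_hidden_index`, the `num_hidden_layers` parameter, and the toy 24-layer loop are assumptions standing in for `self.config.num_hidden_layers` and `self.transformer_encoder`; only the branching logic itself comes from the committed code.

from typing import Union

def resolve_output_hidden_index(output_hidden_states: Union[bool, int], num_hidden_layers: int) -> int:
    """Mirror of the logic added in this commit: layers whose index is >= the
    returned threshold have their hidden states collected."""
    if type(output_hidden_states) == bool and not output_hidden_states:
        # False: threshold is past the last layer, so nothing is collected.
        return num_hidden_layers + 1
    elif type(output_hidden_states) == int:
        # An int selects the first layer index to start collecting from.
        return output_hidden_states
    else:
        # True (or any other value): collect every layer's hidden state.
        return 0

# Illustration with a hypothetical 24-layer encoder (not the real model loop).
num_layers = 24
for requested in (False, True, 20):
    start = resolve_output_hidden_index(requested, num_layers)
    collected = [idx for idx in range(num_layers) if idx >= start]
    print(requested, "->", len(collected), "hidden states collected")

The practical effect of the change, as the added comment states, is that passing an integer collects hidden states only from that layer index onward instead of from every layer, while `True` and `False` keep their previous all-or-nothing behavior.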