Lolalb committed
Commit 99d97ed · verified · 1 parent: 2650f1f

Upload AMPLIFY

Files changed (1): amplify.py (+10 −2)
amplify.py CHANGED
@@ -295,6 +295,14 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         # Initialize
         hidden_states, attentions = [], []
 
+        # We will output all the hidden_states that have an index higher than output_hidden_index
+        if type(output_hidden_states) == bool and not output_hidden_states:
+            output_hidden_index = self.config.num_hidden_layers + 1
+        elif type(output_hidden_states) == int:
+            output_hidden_index = output_hidden_states
+        else:
+            output_hidden_index = 0
+
         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
         if pad_mask is not None:
             pad_mask = pad_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, pad_mask.size(-1), 1)
@@ -325,9 +333,9 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         x = self.layer_norm_1(x)
 
         # Transformer encoder
-        for layer in self.transformer_encoder:
+        for idx, layer in enumerate(self.transformer_encoder):
             x, attn = layer(x, pad_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
-            if output_hidden_states:
+            if idx >= output_hidden_index:
                 hidden_states.append(x)
             if output_attentions:
                 attentions.append(attn)
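
In effect, output_hidden_states now accepts either a bool (the previous all-or-nothing behavior) or an int: only layers whose index is greater than or equal to that value have their hidden states collected. A minimal standalone sketch of the resolution logic follows; the helper name and the demo value of 24 layers are illustrative, not part of the commit (AMPLIFY reads the layer count from self.config.num_hidden_layers):

def resolve_output_hidden_index(output_hidden_states, num_hidden_layers):
    # False: the threshold lands past the last layer, so no hidden states are kept.
    if type(output_hidden_states) == bool and not output_hidden_states:
        return num_hidden_layers + 1
    # int: keep hidden states for layers with idx >= output_hidden_states.
    # (type(True) == int is False, so True does not fall into this branch.)
    elif type(output_hidden_states) == int:
        return output_hidden_states
    # True (or anything else): keep every layer's hidden state.
    else:
        return 0

if __name__ == "__main__":
    num_hidden_layers = 24  # illustrative layer count
    for request in (False, True, 20):
        threshold = resolve_output_hidden_index(request, num_hidden_layers)
        kept = [idx for idx in range(num_hidden_layers) if idx >= threshold]
        print(f"output_hidden_states={request!r}: keeps {len(kept)} layer output(s)")

Passing an int such as 20 keeps only the last few layers' outputs, which is presumably the motivation: avoiding materializing every layer's hidden states when only the late layers are needed.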