zpn committed · Commit f169ad7 · 1 Parent(s): 9f26eba

fix: changes

Files changed (1):
  1. modeling_hf_nomic_bert.py +9 -2
modeling_hf_nomic_bert.py CHANGED
@@ -1616,6 +1616,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             if config.activation_function == "glu"
             else (F.silu if config.activation_function == "swiglu" else F.gelu)
         )
+        self.moe = moe
         if moe:
             if dmoe is not None:
                 megablocks_args = Arguments(
@@ -1702,7 +1703,10 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             dropped = self.dropout2(hidden_states)
             residual = (dropped + residual) if residual is not None else dropped
             hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
-            hidden_states = self.mlp(hidden_states)
+            if self.moe:
+                hidden_states = self.mlp(hidden_states, attention_mask)
+            else:
+                hidden_states = self.mlp(hidden_states)
 
             return hidden_states, None, residual
         else:
@@ -1716,7 +1720,10 @@ class NomicBertBlock(NomicBertPreTrainedModel):
                 rope=rope,
             )
             hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
-            mlp_out = self.mlp(hidden_states)
+            if self.moe:
+                mlp_out = self.mlp(hidden_states, attention_mask)
+            else:
+                mlp_out = self.mlp(hidden_states)
 
             hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
             return hidden_states, None, None
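
In effect, the commit stores the `moe` flag on the block (`self.moe = moe`) and, in both forward paths, passes `attention_mask` to the MLP only when that flag is set, presumably because the mixture-of-experts MLP's forward takes the mask while the dense MLP's does not. Below is a minimal, self-contained sketch of that dispatch pattern; `DenseMLP`, `MoEMLP`, and `Block` are illustrative stand-ins, not the repository's actual classes.

```python
# Sketch of the dispatch pattern introduced by this commit: the block remembers
# whether it was built with an MoE MLP and, if so, forwards the attention mask.
# All class/module names here are toy placeholders, not NomicBertBlock itself.
import torch
import torch.nn as nn


class DenseMLP(nn.Module):
    """Plain feed-forward MLP; its forward takes hidden states only."""

    def __init__(self, dim: int):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc(hidden_states)


class MoEMLP(nn.Module):
    """Toy MoE-style MLP whose forward also accepts the attention mask."""

    def __init__(self, dim: int, num_experts: int = 2):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])
        self.router = nn.Linear(dim, num_experts)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Weight each token's expert outputs by the router, then zero out padding.
        weights = self.router(hidden_states).softmax(dim=-1)                        # (B, T, E)
        expert_out = torch.stack([e(hidden_states) for e in self.experts], dim=-1)  # (B, T, D, E)
        out = (expert_out * weights.unsqueeze(-2)).sum(dim=-1)                      # (B, T, D)
        return out * attention_mask.unsqueeze(-1)


class Block(nn.Module):
    def __init__(self, dim: int, moe: bool):
        super().__init__()
        self.moe = moe  # remembered flag, as added in this commit
        self.mlp = MoEMLP(dim) if moe else DenseMLP(dim)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Same branching as the patched forward: only the MoE MLP gets the mask.
        if self.moe:
            return self.mlp(hidden_states, attention_mask)
        return self.mlp(hidden_states)


if __name__ == "__main__":
    x = torch.randn(2, 4, 8)
    mask = torch.ones(2, 4)
    print(Block(8, moe=True)(x, mask).shape)   # torch.Size([2, 4, 8])
    print(Block(8, moe=False)(x, mask).shape)  # torch.Size([2, 4, 8])
```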