fix: changes
modeling_hf_nomic_bert.py
@@ -1616,6 +1616,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             if config.activation_function == "glu"
             else (F.silu if config.activation_function == "swiglu" else F.gelu)
         )
+        self.moe = moe
         if moe:
             if dmoe is not None:
                 megablocks_args = Arguments(
@@ -1702,7 +1703,10 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             dropped = self.dropout2(hidden_states)
             residual = (dropped + residual) if residual is not None else dropped
             hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
-            hidden_states = self.mlp(hidden_states)
+            if self.moe:
+                hidden_states = self.mlp(hidden_states, attention_mask)
+            else:
+                hidden_states = self.mlp(hidden_states)

             return hidden_states, None, residual
         else:
@@ -1716,7 +1720,10 @@ class NomicBertBlock(NomicBertPreTrainedModel):
                 rope=rope,
             )
             hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
-            mlp_out = self.mlp(hidden_states)
+            if self.moe:
+                mlp_out = self.mlp(hidden_states, attention_mask)
+            else:
+                mlp_out = self.mlp(hidden_states)

             hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
             return hidden_states, None, None
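In short, the commit stores the MoE flag on the block (self.moe = moe) and, when calling the MLP in forward, passes attention_mask only when the MLP is an expert-routed (dMoE) layer. The sketch below is a minimal, hypothetical illustration of that dispatch pattern; ToyBlock and ToyMoEMLP are illustration-only names, and the mask handling is simplified rather than the actual megablocks implementation.

import torch
import torch.nn as nn


class ToyMoEMLP(nn.Module):
    # Stand-in for an expert-routed (dMoE) MLP: real routers use the
    # attention mask to skip padded tokens; here we simply zero them out.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, attention_mask):
        return self.proj(x) * attention_mask.unsqueeze(-1).to(x.dtype)


class ToyBlock(nn.Module):
    def __init__(self, dim, moe=False):
        super().__init__()
        self.moe = moe  # mirrors the "self.moe = moe" added in __init__
        self.mlp = ToyMoEMLP(dim) if moe else nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None):
        # Mirrors the branch added in NomicBertBlock.forward:
        # only the MoE MLP receives the attention mask.
        if self.moe:
            return self.mlp(hidden_states, attention_mask)
        return self.mlp(hidden_states)


block = ToyBlock(dim=8, moe=True)
x = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
out = block(x, mask)  # shape: (2, 4, 8)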