Update raven_modeling_minimal.py

raven_modeling_minimal.py  (+14 -13)
@@ -23,6 +23,16 @@ import torch.nn.functional as F
 from transformers import GenerationConfig
 
 
+@cache
+def _init_func(dim, num_layers) -> dict[str, float]:
+    return {
+        "std": math.sqrt(2 / (5 * dim)),
+        "out_proj": math.sqrt(2 / (5 * dim)) / math.sqrt(2 * num_layers),
+        "embedding": math.sqrt(2 / (5 * dim)),
+        "embed_scale": math.sqrt(dim),
+    }
+
+
 class RavenPreTrainedModel(PreTrainedModel):
     config_class = RavenConfig
     base_model_prefix = "model"
@@ -37,18 +47,9 @@ class RavenPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _tp_plan = {}
 
-    @cache
-    def _init_func(self, dim, num_layers):
-        return {
-            "std": math.sqrt(2 / (5 * dim)),
-            "out_proj": math.sqrt(2 / (5 * dim)) / math.sqrt(2 * num_layers),
-            "embedding": math.sqrt(2 / (5 * dim)),
-            "embed_scale": math.sqrt(dim),
-        }
-
     @property
     def emb_scale(self):
-        return self._init_func(self.config.n_embd, self.config.effective_expected_depth)["embed_scale"]
+        return _init_func(self.config.n_embd, self.config.effective_expected_depth)["embed_scale"]
 
     def _normal_(self, tensor, std):
         return torch.nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=-3 * std, b=3 * std)
@@ -86,7 +87,7 @@ class RavenPreTrainedModel(PreTrainedModel):
 
     @torch.no_grad()
     def _init_weights(self, module):
-        _init_values = self._init_func(self.config.n_embd, self.config.effective_expected_depth)
+        _init_values = _init_func(self.config.n_embd, self.config.effective_expected_depth)
         name = self._full_name_of_module_lookup[id(module)]
         if isinstance(module, RMSNorm):
             torch.nn.init.ones_(module.weight)
@@ -703,14 +704,14 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             loss = torch.nn.functional.cross_entropy(
                 logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100
             )
-            log_ppl = loss.clone().detach()
+            log_ppl = loss.clone().detach()
         else:
             logits = self.lm_head(x).float()
            loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)
 
         return CausalLMOutputRecurrentLatents(
             loss=loss,
-            log_ppl=log_ppl,
+            log_ppl=log_ppl,  # this value is returned only for compatibility reasons. For this model loss=log-ppl
             logits=logits if output_details["return_logits"] else None,
             past_key_values=past_key_values,
             hidden_states=x if output_details["return_head"] else None,
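
The functional change in this commit is that _init_func moves from a @cache-decorated instance method to a module-level cached function of (dim, num_layers) alone, which is what lets emb_scale and _init_weights call it without going through self. The sketch below is a standalone illustration of that pattern, not part of the repository (the example dimensions are made up): functools.cache memoizes on the argument tuple, so repeated lookups of the same initialization table reuse one dict, and the cache no longer keys on (or holds a reference to) a model instance.

import math
from functools import cache

@cache
def _init_func(dim: int, num_layers: int) -> dict[str, float]:
    # Same table as in the diff: truncated-normal stds plus the embedding scale.
    return {
        "std": math.sqrt(2 / (5 * dim)),
        "out_proj": math.sqrt(2 / (5 * dim)) / math.sqrt(2 * num_layers),
        "embedding": math.sqrt(2 / (5 * dim)),
        "embed_scale": math.sqrt(dim),
    }

vals = _init_func(4096, 32)          # computed once for this (dim, num_layers) pair
assert _init_func(4096, 32) is vals  # second call is a cache hit: same dict object
print(vals["embed_scale"])           # sqrt(4096) = 64.0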