Update modeling_custom.py
Browse files- modeling_custom.py +1 -1
modeling_custom.py
CHANGED
@@ -195,7 +195,7 @@ class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
|
|
195 |
|
196 |
# [B, num_objectives, num_quantiles, ]
|
197 |
reward_quantiles = torch.mean(
|
198 |
-
rewards * gating_output.unsqueeze(-1).repeat(1, 1, self.num_quantiles), dim=
|
199 |
|
200 |
rewards_expectation = rewards.mean(dim=2)
|
201 |
score = torch.sum(rewards_expectation.float() * gating_output.float(), dim=-1, keepdim=True)
|
|
|
195 |
|
196 |
# [B, num_objectives, num_quantiles, ]
|
197 |
reward_quantiles = torch.mean(
|
198 |
+
rewards * gating_output.unsqueeze(-1).repeat(1, 1, self.num_quantiles), dim=1)
|
199 |
|
200 |
rewards_expectation = rewards.mean(dim=2)
|
201 |
score = torch.sum(rewards_expectation.float() * gating_output.float(), dim=-1, keepdim=True)
|