nicolinho commited on
Commit
eef8d85
·
verified ·
1 Parent(s): 7d50068

Update modeling_custom.py

Browse files
Files changed (1) hide show
  1. modeling_custom.py +1 -1
modeling_custom.py CHANGED
@@ -195,7 +195,7 @@ class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
195
 
196
  # [B, num_objectives, num_quantiles, ]
197
  reward_quantiles = torch.mean(
198
- rewards * gating_output.unsqueeze(-1).repeat(1, 1, self.num_quantiles), dim=2)
199
 
200
  rewards_expectation = rewards.mean(dim=2)
201
  score = torch.sum(rewards_expectation.float() * gating_output.float(), dim=-1, keepdim=True)
 
195
 
196
  # [B, num_objectives, num_quantiles, ]
197
  reward_quantiles = torch.mean(
198
+ rewards * gating_output.unsqueeze(-1).repeat(1, 1, self.num_quantiles), dim=1)
199
 
200
  rewards_expectation = rewards.mean(dim=2)
201
  score = torch.sum(rewards_expectation.float() * gating_output.float(), dim=-1, keepdim=True)