import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class GPTRewardModel(nn.Module):
    def __init__(self, model_path):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(model_path)
        self.config = model.config
        # `gpt-neo(x)` models use the `hidden_size` attribute name instead of `n_embd`
        self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        self.transformer = model.transformer
        # Scalar value head that maps each hidden state to a reward
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        loss = None
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]
        # Per-token rewards, shape (batch_size, seq_len)
        rewards = self.v_head(hidden_states).squeeze(-1)
        chosen_end_scores = []
        rejected_end_scores = []

        # Split the inputs and rewards into two parts, chosen and rejected
        assert len(input_ids.shape) == 2
        bs = input_ids.shape[0] // 2
        chosen = input_ids[:bs]
        rejected = input_ids[bs:]
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        loss = 0
        inference = False
        for i in range(bs):
            if torch.all(torch.eq(chosen[i], rejected[i])).item():
                c_inds = (chosen[i] == self.PAD_ID).nonzero()
                c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
                chosen_end_scores.append(chosen_rewards[i, c_ind - 1])
                inference = True
                continue

            # Check if there is any padding, otherwise take the length of the sequence
            c_inds = (chosen[i] == self.PAD_ID).nonzero()
            c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
            r_inds = (rejected[i] == self.PAD_ID).nonzero()
            r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected.shape[1]
            end_ind = max(c_ind, r_ind)

            # Retrieve first index where trajectories diverge
            divergence_ind = (chosen[i] != rejected[i]).nonzero()[0]
            assert divergence_ind > 0

            # Index into the correct rewards
            c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind]
            r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind]

            # Append the last rewards to the list of end scores
            chosen_end_scores.append(c_truncated_reward[-1])
            rejected_end_scores.append(r_truncated_reward[-1])

            # Compute loss based on truncated rewards (ignore padding)
            loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()
        loss = loss / bs
        if not inference:
            chosen_end_scores = torch.stack(chosen_end_scores)
            rejected_end_scores = torch.stack(rejected_end_scores)

        if inference:
            chosen_end_scores = torch.stack(chosen_end_scores)
            return {"chosen_end_scores": chosen_end_scores}

        return {
            "loss": loss,
            "chosen_end_scores": chosen_end_scores,
            "rejected_end_scores": rejected_end_scores,
        }
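

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file).
# It shows how this reward model is typically called: the chosen completions
# form the first half of the batch and the rejected completions the second
# half, so forward() can split them at bs = batch_size // 2. Passing the base
# "EleutherAI/gpt-j-6B" checkpoint as model_path, the example texts, and
# max_length=550 are assumptions for illustration only; in practice a
# fine-tuned SFT checkpoint and the project's own sequence length would be
# used instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    reward_model = GPTRewardModel("EleutherAI/gpt-j-6B")
    tokenizer = reward_model.tokenizer

    chosen = ["Prompt: ... TL;DR: a concise, faithful summary."]
    rejected = ["Prompt: ... TL;DR: a rambling, off-topic summary."]

    # Pad chosen and rejected to the same length, then concatenate them along
    # the batch dimension in (chosen, rejected) order, as forward() expects.
    encodings = tokenizer(
        chosen + rejected,
        padding="max_length",
        max_length=550,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = reward_model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
        )

    # Pairwise ranking loss plus the per-sequence end scores for each half
    print(outputs["loss"])
    print(outputs["chosen_end_scores"], outputs["rejected_end_scores"])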