import random

import numpy as np
import torch
from datasets import load_dataset
from reward_model import GPTRewardModel
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
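
# Evaluates a pairwise reward model on held-out comparison data: for each
# (chosen, rejected) summary pair, the model should score the human-preferred
# summary higher, so accuracy is the fraction of pairs ranked correctly.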


def set_seed(seed_val=42):
    # Seed every RNG in play (Python, NumPy, PyTorch CPU and CUDA) for reproducibility.
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
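
# Note: set_seed() is provided for reproducibility but is never invoked below;
# the evaluation itself is deterministic (shuffle=False, no sampling).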


def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train"):
    dataset = load_dataset(path, split=split)
    if split == "test":
        # Cap the test split at 5,000 comparisons to keep evaluation fast.
        dataset = dataset.select(range(5000))
    pairs = []
    for sample in tqdm(dataset):
        pair = {}
        prompt = sample["prompt"]
        chosen_summary = sample["chosen"]
        rejected_summary = sample["rejected"]
        # Skip degenerate pairs: identical summaries, or summaries under five words.
        if chosen_summary == rejected_summary:
            continue
        if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5:
            continue
        pair["chosen"] = prompt + "\n" + chosen_summary
        pair["rejected"] = prompt + "\n" + rejected_summary
        pairs.append(pair)
    return pairs
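
# Each returned pair concatenates the same prompt with the two candidate
# summaries, e.g. (illustrative values, not real dataset rows):
#   {"chosen":   "SUBREDDIT: r/...\n<preferred summary>",
#    "rejected": "SUBREDDIT: r/...\n<dispreferred summary>"}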


class PairwiseDataset(Dataset):
    def __init__(self, pairs, tokenizer, max_length):
        self.chosen_input_ids = []
        self.chosen_attn_masks = []
        self.rejected_input_ids = []
        self.rejected_attn_masks = []
        for pair in pairs:
            chosen, rejected = pair["chosen"], pair["rejected"]
            chosen_encodings_dict = tokenizer(
                "<|startoftext|>" + chosen + "<|endoftext|>",
                truncation=True,
                max_length=max_length,
                padding="max_length",
                return_tensors="pt",
            )
            rejected_encodings_dict = tokenizer(
                "<|startoftext|>" + rejected + "<|endoftext|>",
                truncation=True,
                max_length=max_length,
                padding="max_length",
                return_tensors="pt",
            )
            # Drop pairs whose token IDs are identical after truncation, typically
            # because the summaries differ only beyond max_length; such pairs
            # carry no ranking signal.
            if not torch.all(
                torch.eq(chosen_encodings_dict["input_ids"], rejected_encodings_dict["input_ids"])
            ).item():
                self.chosen_input_ids.append(chosen_encodings_dict["input_ids"])
                self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"])
                self.rejected_input_ids.append(rejected_encodings_dict["input_ids"])
                self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"])

    def __len__(self):
        return len(self.chosen_input_ids)

    def __getitem__(self, idx):
        return (
            self.chosen_input_ids[idx],
            self.chosen_attn_masks[idx],
            self.rejected_input_ids[idx],
            self.rejected_attn_masks[idx],
        )
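
# Each dataset item is a 4-tuple of (1, max_length) tensors:
# (chosen_ids, chosen_mask, rejected_ids, rejected_mask).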


class DataCollatorReward:
    def __call__(self, data):
        # Stack all chosen sequences first, then all rejected ones, so a batch of
        # B pairs becomes 2B rows; labels mark rows as chosen (0) or rejected (1).
        batch = {}
        batch["input_ids"] = torch.cat([f[0] for f in data] + [f[2] for f in data])
        batch["attention_mask"] = torch.cat([f[1] for f in data] + [f[3] for f in data])
        batch["labels"] = torch.tensor([0] * len(data) + [1] * len(data))
        return batch
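
# Illustrative layout for a batch of two pairs (assumed values, not real data):
# batch["input_ids"] has shape (4, max_length), rows 0-1 are the chosen
# encodings, rows 2-3 the rejected ones, and batch["labels"] == tensor([0, 0, 1, 1]).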


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    tokenizer.pad_token = tokenizer.eos_token
    PAD_ID = tokenizer(tokenizer.pad_token)["input_ids"][0]

    # Build the reward model on top of the SFT checkpoint, then restore the
    # trained reward-model weights.
    model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft")
    model.load_state_dict(torch.load("rm_checkpoint/pytorch_model.bin"))

    max_length = 550
    val_pairs = create_comparison_dataset("CarperAI/openai_summarize_comparisons", "test")
    dev_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=6, collate_fn=DataCollatorReward())
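
    # With batch_size=6 pairs, each collated batch holds 12 sequences
    # (6 chosen followed by 6 rejected).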
    model.cuda()
    model.eval()
    model.half()  # fp16 inference to reduce memory; scores are moved to CPU below

    correct = 0
    chosen_list = []
    reject_list = []
    with torch.no_grad():
        for step, batch in tqdm(enumerate(dev_dataloader), total=len(dev_dataloader)):
            for x in batch:
                batch[x] = batch[x].cuda()
            outputs = model(**batch)
            # A pair counts as correct when the chosen summary outscores the rejected one.
            correct += (outputs["chosen_end_scores"] > outputs["rejected_end_scores"]).sum().item()
            chosen_list.append(outputs["chosen_end_scores"].cpu())
            reject_list.append(outputs["rejected_end_scores"].cpu())
    print("Total accuracy: ", correct / len(dev_dataset))