import json

import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import Dataset


def get_dataset_from_jsonl(jsonl_file, return_summary=True):
    # if return_summary is True, return a list of posts with the summary concatenated
    # if return_summary is False, return a list of posts and a list of summaries
    with open(jsonl_file, "r") as f:
        dataset = [json.loads(line) for line in f]
    post_list = []
    summary_list = []
    for d in dataset:
        if return_summary:
            post = f"SUBREDDIT: r/{d['subreddit']}\nTITLE: {d['title']}\nPOST: {d['post']}\nTL;DR: {d['summary']}"
        else:
            post = f"SUBREDDIT: r/{d['subreddit']}\nTITLE: {d['title']}\nPOST: {d['post']}\nTL;DR: "
            summary_list.append(d["summary"])
        post_list.append(post)
    if not return_summary:
        return post_list, summary_list
    return post_list
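
# Usage sketch (added for illustration, not part of the original module): each JSONL
# line is expected to carry "subreddit", "title", "post" and "summary" fields, matching
# the keys read above. The file name below is a hypothetical example.
#
#   prompts, refs = get_dataset_from_jsonl("tldr_test.jsonl", return_summary=False)
#   # prompts[i] ends with "TL;DR: " and refs[i] holds the reference summary,
#   # the usual shape for generation followed by ROUGE-style evaluation.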


class TLDRDataset(Dataset):
    def __init__(self, train_path, tokenizer, split, max_length=550):
        self.post_list = []
        dataset = load_dataset(train_path, split=split)
        for sample in dataset:
            self.post_list.append(sample["prompt"] + sample["label"])
        if "valid" in split:
            self.post_list = self.post_list[0:2000]
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        txt = self.post_list[idx]
        encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length")
        input_ids = torch.tensor(encodings_dict["input_ids"])
        attn_masks = torch.tensor(encodings_dict["attention_mask"])
        return {
            "input_ids": input_ids,
            "attention_mask": attn_masks,
            "labels": input_ids,
        }


class ComparisonDataset(Dataset):
    def __init__(self, comparison_path, tokenizer, max_length=550):
        with open(comparison_path, "r") as f:
            dataset = [json.loads(line) for line in f]

        self.tokenizer = tokenizer
        self.post_list = []
        self.summaries_0 = []
        self.summaries_1 = []
        self.labels = []
        self.max_length = max_length

        def make_text(post, summarize):
            return f"SUBREDDIT: r/{post['subreddit']}\nTITLE: {post['title']}\nPOST: {post['post']}\nTL;DR: {summarize}"

        for sample in dataset:
            self.post_list.append(sample["info"]["post"])
            # NOTE: The chosen summary is always the first one, i.e. `sample["summaries"][0]`
            if sample["choice"] == 0:
                self.summaries_0.append(make_text(sample["info"], sample["summaries"][0]["text"]))
                self.summaries_1.append(make_text(sample["info"], sample["summaries"][1]["text"]))
            else:
                self.summaries_0.append(make_text(sample["info"], sample["summaries"][1]["text"]))
                self.summaries_1.append(make_text(sample["info"], sample["summaries"][0]["text"]))
            self.labels.append(0)

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        summ0 = self.summaries_0[idx]
        summ1 = self.summaries_1[idx]
        encodings_dict = self.tokenizer(
            [summ0, summ1],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        input_ids = torch.tensor(encodings_dict["input_ids"])
        attention_mask = torch.tensor(encodings_dict["attention_mask"])
        return {"input_ids": input_ids, "attention_mask": attention_mask}
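
# Note (added for illustration): ComparisonDataset tokenizes the preferred and the
# rejected summary together, so each item is a stacked pair of tensors of shape
# (2, max_length) with the chosen text at index 0. A pairwise reward model can score
# both candidates in one pass; a minimal, hypothetical loss sketch (`reward_model`
# is an assumed callable returning one scalar score per sequence):
#
#   rewards = reward_model(item["input_ids"], item["attention_mask"])  # shape (2,)
#   loss = -torch.nn.functional.logsigmoid(rewards[0] - rewards[1])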


class AllSummDataset(Dataset):
    def __init__(self, train_path, tokenizer, split, max_length=1024):
        df = pd.read_parquet(train_path)
        if split == "valid":
            df = df.sample(n=5000)
        self.summarizes = []
        for i, row in df.iterrows():
            self.summarizes.append(f"Summarize: {row['text']}. TL;DR: {row['summary']}")
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        return len(self.summarizes)

    def __getitem__(self, idx):
        txt = self.summarizes[idx]
        encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length")
        input_ids = torch.tensor(encodings_dict["input_ids"])
        attn_masks = torch.tensor(encodings_dict["attention_mask"])
        return {
            "input_ids": input_ids,
            "attention_mask": attn_masks,
            "labels": input_ids,
        }
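

if __name__ == "__main__":
    # Smoke-test sketch (added for illustration; the dataset path and tokenizer name
    # are assumptions, not taken from the original file). Any causal-LM tokenizer with
    # a pad token and any HF dataset exposing "prompt"/"label" columns works the same way.
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

    train_ds = TLDRDataset("CarperAI/openai_summarize_tldr", tokenizer, split="train")
    loader = DataLoader(train_ds, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    # input_ids, attention_mask and labels all come back as (4, 550) tensors,
    # ready for a standard causal-LM fine-tuning loop.
    print({k: v.shape for k, v in batch.items()})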