File size: 2,993 Bytes
8b414b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer

from src.utils import get_target_columns


def collate_fn(data):
    """Collate a list of per-sample feature dicts into a single batch dict.

    Each sample's tensors are squeezed (dropping any leftover length-1 batch
    dim from the tokenizer) and stacked along a new leading batch dimension.
    """
    keys = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')
    return {
        key: torch.stack([sample[key].squeeze() for sample in data])
        for key in keys
    }


class ClassificationDataset(Dataset):
    """Dataset that tokenizes a dataframe of texts once, up front.

    Serves dicts of (input_ids, token_type_ids, attention_mask, labels).
    When the target columns are absent (detected via the 'cohesion' column),
    labels are filled with -1 so the same pipeline can run on unlabeled data.
    """

    def __init__(self, tokenizer: BertTokenizer, df: pd.DataFrame, config: dict):
        self.config = config
        self.tokenizer = tokenizer
        self.df = df

        # Tokenize the entire corpus in one call; padding/truncation to
        # config['max_length'] makes __getitem__ a cheap tensor slice.
        self.features = self.tokenizer(
            text=df.full_text.tolist(),
            max_length=self.config['max_length'],
            padding=True,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )

        if 'cohesion' in self.df.columns:
            labels = torch.as_tensor(df[get_target_columns()].values, dtype=torch.float32)
        else:
            # Unlabeled data: placeholder of -1s keeps the (N, 6) shape.
            labels = torch.full((len(df), 6), -1., dtype=torch.float32)
        self.features['labels'] = labels

    def __getitem__(self, item):
        """Returns dict with input_ids, token_type_ids, attention_mask, labels
        """
        fields = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')
        return {name: self.features[name][item] for name in fields}

    def __len__(self):
        return self.df.shape[0]


class ClassificationDataloader(pl.LightningDataModule):
    """LightningDataModule wrapping train/val ClassificationDatasets."""

    def __init__(
            self,
            tokenizer: BertTokenizer,
            train_df: pd.DataFrame,
            val_df: pd.DataFrame,
            config: dict
    ):
        super().__init__()
        self.config = config

        self.train_data = ClassificationDataset(tokenizer, train_df, config)
        self.val_data = ClassificationDataset(tokenizer, val_df, config)

    def _build_loader(self, dataset, shuffle):
        # Shared DataLoader construction; only shuffling differs per split.
        return DataLoader(
            dataset=dataset,
            shuffle=shuffle,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_workers'],
            collate_fn=collate_fn
        )

    def train_dataloader(self):
        return self._build_loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        return self._build_loader(self.val_data, shuffle=False)