import transformers
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import DistilBertTokenizer, DistilBertModel
import logging

logging.basicConfig(level=logging.ERROR)

import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
# Training configuration
MAX_LEN = 100
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4
EPOCHS = 1
LEARNING_RATE = 1e-05
# Pretrained DistilBERT tokenizer; truncation and padding are applied per call in encode_plus below.
tokenizer_DB = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', do_lower_case=True)
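# Optional sanity check (a minimal sketch, not part of the original pipeline): encoding a
# sample sentence shows the padded/truncated output that the Dataset below builds on.
# The sample text here is made up for illustration.
_sample = tokenizer_DB.encode_plus(
    "a short example sentence",
    None,
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding='max_length',
    truncation=True,
    return_token_type_ids=True
)
# Both lists are padded to MAX_LEN (100) tokens.
print(len(_sample['input_ids']), len(_sample['attention_mask']))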
class BinaryLabel(Dataset):
    """Wraps a dataframe with 'text' and 'label' columns for DistilBERT."""

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe.reset_index(drop=True)  # positional indexing in __getitem__
        self.text = self.data.text
        self.targets = self.data.label
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        text = str(self.text[index])
        text = " ".join(text.split())  # collapse repeated whitespace

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'targets': torch.tensor(self.targets[index], dtype=torch.float)
        }
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': VALID_BATCH_SIZE,
               'shuffle': True,
               'num_workers': 0
               }
training_set = BinaryLabel(train_df_DB, tokenizer_DB, MAX_LEN)
testing_set = BinaryLabel(test_df_DB, tokenizer_DB, MAX_LEN)

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)
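# Optional sanity check (a sketch; assumes train_df_DB / test_df_DB are dataframes with
# 'text' and 'label' columns, as the Dataset above expects): pull one batch and confirm
# the tensor shapes the model will receive.
batch = next(iter(training_loader))
print(batch['ids'].shape)      # torch.Size([TRAIN_BATCH_SIZE, MAX_LEN]) -> [4, 100]
print(batch['mask'].shape)     # torch.Size([TRAIN_BATCH_SIZE, MAX_LEN]) -> [4, 100]
print(batch['targets'].shape)  # torch.Size([TRAIN_BATCH_SIZE]) -> [4]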