# NOTE(review): removed "Spaces / Sleeping" web-page scrape residue here —
# it was not Python and broke parsing of this file.
import torch | |
from dateutil.parser import parse as parse_date | |
from sklearn.model_selection import train_test_split | |
from transformers import ( | |
pipeline, | |
AutoTokenizer, | |
AutoModelForSequenceClassification, | |
TrainingArguments, | |
Trainer | |
) | |
from torch.utils.data import Dataset | |
class GroundingDataset(Dataset):
    """Dataset of (question, answer, context) pairs for binary grounding
    classification.

    Each element of *data* is expected to be a dict with keys
    "question", "answer", "context" (strings) and an integer "label"
    (0/1 — presumably grounded vs. ungrounded; confirm with the caller).
    """

    def __init__(self, data, tokenizer, max_length: int = 512):
        self.data = data              # list of example dicts (see class docstring)
        self.tokenizer = tokenizer    # HF-style tokenizer callable
        self.max_length = max_length  # pad/truncate length for every example

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Encode as a sentence pair: question vs. "answer [SEP] context".
        encoding = self.tokenizer(
            item["question"],
            text_pair=item["answer"] + " [SEP] " + item["context"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            # squeeze(0), not squeeze(): drop only the batch dim added by
            # return_tensors="pt". A bare squeeze() would also collapse a
            # length-1 sequence dimension and yield a 0-d tensor.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            # Explicit int64 dtype — classification losses expect long targets.
            "labels": torch.tensor(item["label"], dtype=torch.long),
        }
class GroundingTrainer:
    """Fine-tunes DistilBERT as a 2-label sequence classifier and saves
    the resulting model and tokenizer to ./grounding_detector."""

    def __init__(self):
        # Single base checkpoint used for both tokenizer and model;
        # num_labels=2 configures a binary classification head.
        checkpoint = "distilbert-base-uncased"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            checkpoint, num_labels=2
        )

    def train(self, dataset):
        """Split *dataset* 80/20 into train/validation, fine-tune for
        3 epochs with per-epoch evaluation, then persist the artifacts."""
        train_data, val_data = train_test_split(dataset, test_size=0.2)
        args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=3,
            per_device_train_batch_size=8,
            evaluation_strategy="epoch",
            logging_dir="./logs",
        )
        trainer = Trainer(
            model=self.model,
            args=args,
            train_dataset=GroundingDataset(train_data, self.tokenizer),
            eval_dataset=GroundingDataset(val_data, self.tokenizer),
        )
        trainer.train()
        # Save model and tokenizer side by side so they can be reloaded
        # together from one directory.
        self.model.save_pretrained("./grounding_detector")
        self.tokenizer.save_pretrained("./grounding_detector")