import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)

class GroundingDataset(Dataset):
    """Pairs a question with its answer + context and a binary grounding label."""

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Encode the question against "answer [SEP] context" as a sentence pair,
        # padded/truncated to a fixed length so examples batch cleanly.
        encoding = self.tokenizer(
            item["question"],
            text_pair=item["answer"] + " [SEP] " + item["context"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            # squeeze() drops the batch dimension added by return_tensors="pt"
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": torch.tensor(item["label"]),
        }
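
# Each record in `data` must carry the fields read in __getitem__ above. A sketch
# of the assumed schema (the 0/1 label convention is an assumption; the original
# code only fixes num_labels=2, not which class means "grounded"):
#
#   {"question": "Who wrote Hamlet?",
#    "answer": "William Shakespeare",
#    "context": "Hamlet is a tragedy written by William Shakespeare.",
#    "label": 1}   # assumed: 1 = answer grounded in context, 0 = not grounded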

class GroundingTrainer:
    def __init__(self):
        # DistilBERT with a 2-way classification head (grounded vs. not grounded)
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased", num_labels=2
        )

    def train(self, dataset):
        # Hold out 20% of the examples for per-epoch evaluation
        train_data, val_data = train_test_split(dataset, test_size=0.2)
        trainer = Trainer(
            model=self.model,
            args=TrainingArguments(
                output_dir="./results",
                num_train_epochs=3,
                per_device_train_batch_size=8,
                evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
                logging_dir="./logs",
            ),
            train_dataset=GroundingDataset(train_data, self.tokenizer),
            eval_dataset=GroundingDataset(val_data, self.tokenizer),
        )
        trainer.train()
        # Persist weights and tokenizer together so the detector reloads as one checkpoint
        self.model.save_pretrained("./grounding_detector")
        self.tokenizer.save_pretrained("./grounding_detector")
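
if __name__ == "__main__":
    # Minimal usage sketch with hand-written toy rows; a real run would load a
    # labeled dataset instead. The 1 = grounded convention is the assumption
    # noted above, not something the original code states.
    toy_data = [
        {
            "question": "Who wrote Hamlet?",
            "answer": "William Shakespeare",
            "context": "Hamlet is a tragedy written by William Shakespeare.",
            "label": 1,
        },
        {
            "question": "Who wrote Hamlet?",
            "answer": "Charles Dickens",
            "context": "Hamlet is a tragedy written by William Shakespeare.",
            "label": 0,
        },
    ] * 10  # repeat so the 80/20 split and batch size of 8 have enough rows

    GroundingTrainer().train(toy_data)

    # The saved checkpoint can then be served with the pipeline imported above.
    # Inference input mirrors the training encoding: question as text, and
    # "answer [SEP] context" as text_pair.
    detector = pipeline(
        "text-classification",
        model="./grounding_detector",
        tokenizer="./grounding_detector",
    )
    result = detector({
        "text": "Who wrote Hamlet?",
        "text_pair": "Charles Dickens [SEP] Hamlet is a tragedy written by William Shakespeare.",
    })
    print(result)  # e.g. [{'label': 'LABEL_0', 'score': ...}]; label ids map to the 0/1 convention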