|
""" |
|
This script trains sentence transformers with a batch-hard triplet loss function.

The TREC dataset will be automatically downloaded and put in the datasets/ directory.

The usual triplet loss takes three inputs (anchor, positive, negative) and optimizes the network
such that the positive sentence is closer to the anchor than the negative sentence. A challenge,
however, is selecting good triplets: if the negative sentence is picked at random, the training
objective is often too easy and the network fails to learn good representations.

Batch-hard triplet loss (https://arxiv.org/abs/1703.07737) creates triplets on the fly. It requires
labeled data (e.g. labels 1, 2, 3) and assumes that samples with the same label are similar.

Within a batch, it finds for each anchor sentence the sentence with the same label that is furthest
away (hard positive) and the sentence with a different label that is closest (hard negative). It then
optimizes the network such that all sentences with the same label end up close together while
sentences with different labels are clearly separated.
""" |
|
|
|
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util
from sentence_transformers.datasets import SentenceLabelDataset
from sentence_transformers.readers import InputExample
from sentence_transformers.evaluation import TripletEvaluator
from torch.utils.data import DataLoader
from datetime import datetime
|
|
|
|
|
import logging |
|
import os |
|
import random |
|
from collections import defaultdict |
|
|
|
logging.basicConfig( |
|
format="%(asctime)s - %(message)s", |
|
datefmt="%Y-%m-%d %H:%M:%S", |
|
level=logging.INFO, |
|
handlers=[LoggingHandler()], |
|
) |
|
|
|
|
|
def trec_dataset(
    directory="datasets/trec/",
    train_filename="train_5500.label",
    test_filename="TREC_10.label",
    validation_dataset_nb=500,
    urls=(
        "https://cogcomp.seas.upenn.edu/Data/QA/QC/train_5500.label",
        "https://cogcomp.seas.upenn.edu/Data/QA/QC/TREC_10.label",
    ),
):
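    """Download the TREC question classification dataset into `directory` (if it is
    not already present) and return a (train_set, dev_triplets, test_triplets) tuple."""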
|
os.makedirs(directory, exist_ok=True) |
|
|
|
ret = [] |
|
for url, filename in zip(urls, [train_filename, test_filename]): |
|
full_path = os.path.join(directory, filename) |
|
if not os.path.exists(full_path): |
|
util.http_get(url, full_path) |
|
|
|
        examples = []
        label_map = {}
        guid = 1
        with open(full_path, "rb") as f_in:
            for line in f_in:
                # Strip the stray 0xf0 bytes in the raw file so the line decodes cleanly,
                # then split off the label (the token before the first space)
                label, _, text = line.replace(b"\xf0", b" ").strip().decode().partition(" ")

                # Map each label string to a running integer id
                if label not in label_map:
                    label_map[label] = len(label_map)

                label_id = label_map[label]
                guid += 1
                examples.append(InputExample(guid=guid, texts=[text], label=label_id))
        ret.append(examples)
|
|
|
    train_set, test_set = ret
    dev_set = None

    # Carve the validation split off the end of the training set
    if validation_dataset_nb > 0:
        dev_set = train_set[-validation_dataset_nb:]
        train_set = train_set[:-validation_dataset_nb]

    # Fix the seed so the sampled dev/test triplets are reproducible
    random.seed(42)
    dev_triplets = triplets_from_labeled_dataset(dev_set) if dev_set else []
    test_triplets = triplets_from_labeled_dataset(test_set)

    return train_set, dev_triplets, test_triplets
|
|
|
|
|
def triplets_from_labeled_dataset(input_examples): |
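    """Create (anchor, positive, negative) triplets from a labeled dataset: the positive
    shares the anchor's label, while the negative carries a different label."""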
|
|
|
|
|
|
|
    triplets = []

    # Group the examples by label so positives can be drawn per label
    label2sentence = defaultdict(list)
    for inp_example in input_examples:
        label2sentence[inp_example.label].append(inp_example)
|
|
|
    for inp_example in input_examples:
        anchor = inp_example

        # Labels with a single example cannot yield a positive; skip them
        if len(label2sentence[inp_example.label]) < 2:
            continue

        # Draw a positive with the same label but a different guid than the anchor
        positive = None
        while positive is None or positive.guid == anchor.guid:
            positive = random.choice(label2sentence[inp_example.label])

        # Draw a negative from any other label
        negative = None
        while negative is None or negative.label == anchor.label:
            negative = random.choice(input_examples)

        triplets.append(InputExample(texts=[anchor.texts[0], positive.texts[0], negative.texts[0]]))

    return triplets
|
|
|
|
|
|
|
|
|
model_name = "all-distilroberta-v1"
|
|
|
|
|
train_batch_size = 32 |
|
output_path = ( |
|
"output/finetune-batch-hard-trec-" |
|
+ model_name |
|
+ "-" |
|
+ datetime.now().strftime("%Y-%m-%d_%H-%M-%S") |
|
) |
|
num_epochs = 1 |
|
|
|
logging.info("Loading TREC dataset") |
|
train_set, dev_set, test_set = trec_dataset() |
|
|
|
|
|
|
|
# SentenceLabelDataset arranges the samples so that every batch contains several
# examples per label, which the batch triplet losses need to build triplets in-batch
train_data_sampler = SentenceLabelDataset(train_set)
train_dataloader = DataLoader(train_data_sampler, batch_size=train_batch_size, drop_last=True)
|
|
|
|
|
|
|
logging.info("Load model") |
|
model = SentenceTransformer(model_name) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Batch-hard triplet loss, matching the mining strategy described in the module docstring
train_loss = losses.BatchHardTripletLoss(model=model)
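# sentence_transformers.losses ships three sibling batch triplet losses that can be
# swapped in here unchanged:
# train_loss = losses.BatchAllTripletLoss(model=model)
# train_loss = losses.BatchHardSoftMarginTripletLoss(model=model)
# train_loss = losses.BatchSemiHardTripletLoss(model=model)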
|
|
|
|
|
|
|
|
|
|
|
logging.info("Read TREC val dataset") |
|
dev_evaluator = TripletEvaluator.from_input_examples(dev_set, name='trec-dev') |
|
|
|
logging.info("Performance before fine-tuning:") |
|
dev_evaluator(model) |
|
|
|
# Warm up the learning rate over the first 10% of the training steps
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)
|
|
|
|
|
model.fit( |
|
train_objectives=[(train_dataloader, train_loss)], |
|
evaluator=dev_evaluator, |
|
epochs=num_epochs, |
|
evaluation_steps=1000, |
|
warmup_steps=warmup_steps, |
|
output_path=output_path, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logging.info("Evaluating model on test set") |
|
test_evaluator = TripletEvaluator.from_input_examples(test_set, name="trec-test")
|
model.evaluate(test_evaluator) |
|
|