"""
This script trains sentence transformers with a batch hard loss function.

The TREC dataset will be automatically downloaded and put in the datasets/ directory

Usual triplet loss takes 3 inputs: anchor, positive, negative and optimizes the network such that
the positive sentence is closer to the anchor than the negative sentence. However, a challenge here is
to select good triplets. If the negative sentence is selected randomly, the training objective is often
too easy and the network fails to learn good representations.

Batch hard triplet loss (https://arxiv.org/abs/1703.07737) creates triplets on the fly. It requires that the
data is labeled (e.g. labels 1, 2, 3) and we assume that samples with the same label are similar:

In a batch, it checks for sent1 with label 1 which other sentence with label 1 is the furthest away (hard positive)
and which sentence with another label is the closest (hard negative example). It then tries to optimize this, i.e.
all sentences with the same label should be close and sentences with different labels should be clearly separated.
"""

from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util
from sentence_transformers.datasets import SentenceLabelDataset
from torch.utils.data import DataLoader
from sentence_transformers.readers import InputExample
from sentence_transformers.evaluation import TripletEvaluator
from datetime import datetime


import logging
import os
import random
from collections import defaultdict

# Configure root logging: timestamped messages at INFO level and above,
# emitted through the LoggingHandler imported from sentence_transformers.
logging.basicConfig(
    level=logging.INFO,
    handlers=[LoggingHandler()],
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Inspired from torchnlp
def trec_dataset(
    directory="datasets/trec/",
    train_filename="train_5500.label",
    test_filename="TREC_10.label",
    validation_dataset_nb=500,
    urls=(
        "https://cogcomp.seas.upenn.edu/Data/QA/QC/train_5500.label",
        "https://cogcomp.seas.upenn.edu/Data/QA/QC/TREC_10.label",
    ),
):
    """Download (if needed) and load the TREC question classification data.

    Each line of a label file is split at its first space: the part before it
    is the class label, the rest is the sentence. Labels are mapped to
    consecutive integer ids in order of first appearance; the mapping is shared
    between the train and test file so the same label gets the same id in both.

    Args:
        directory: Folder the label files are stored in (created if missing).
        train_filename: File name of the training split inside ``directory``.
        test_filename: File name of the test split inside ``directory``.
        validation_dataset_nb: Number of examples taken from the end of the
            training split to form the dev set. With a value <= 0 no dev set
            is created and an empty list of dev triplets is returned.
        urls: Download locations for the train and test file, in that order.

    Returns:
        A tuple ``(train_set, dev_triplets, test_triplets)``. ``train_set`` is
        a list of single-sentence ``InputExample`` objects with integer labels;
        the other two are lists of (anchor, positive, negative) triplets built
        by ``triplets_from_labeled_dataset``.
    """
    os.makedirs(directory, exist_ok=True)

    # One mapping for all splits keeps label ids consistent across files.
    label_map = {}

    ret = []
    for url, filename in zip(urls, [train_filename, test_filename]):
        full_path = os.path.join(directory, filename)
        if not os.path.exists(full_path):
            util.http_get(url, full_path)

        examples = []
        guid = 1
        # Read as bytes: the file is not clean ASCII (see below).
        with open(full_path, "rb") as label_file:
            for line in label_file:
                # there is one non-ASCII byte: sisterBADBYTEcity; replaced with space
                cleaned = line.replace(b"\xf0", b" ").strip().decode()
                label, _, text = cleaned.partition(" ")

                if label not in label_map:
                    label_map[label] = len(label_map)

                guid += 1
                examples.append(
                    InputExample(guid=guid, texts=[text], label=label_map[label])
                )
        ret.append(examples)

    train_set, test_set = ret
    dev_set = None

    # Create a dev set from train set
    if validation_dataset_nb > 0:
        dev_set = train_set[-validation_dataset_nb:]
        train_set = train_set[:-validation_dataset_nb]

    # For dev & test set, we return triplets (anchor, positive, negative)
    random.seed(42)  # Fix seed, so that we always get the same triplets
    # Without a dev split there is nothing to build dev triplets from.
    dev_triplets = (
        triplets_from_labeled_dataset(dev_set) if dev_set is not None else []
    )
    test_triplets = triplets_from_labeled_dataset(test_set)

    return train_set, dev_triplets, test_triplets


def triplets_from_labeled_dataset(input_examples):
    """Build (anchor, positive, negative) triplets from labeled examples.

    Every example is used once as the anchor. The positive is drawn at random
    from the other examples with the same label, the negative at random from
    the examples with a different label. Draws use the module-level ``random``
    generator, so seed it beforehand for reproducible triplets.

    Args:
        input_examples: Sequence of examples exposing ``guid``, ``label`` and
            ``texts`` (only ``texts[0]`` is used). Examples are told apart by
            their ``guid``.

    Returns:
        A list of ``InputExample`` objects whose ``texts`` are
        ``[anchor, positive, negative]``. Anchors without any possible
        positive are skipped; if fewer than two distinct labels are present,
        no negative exists and the list is empty.
    """
    triplets = []
    label2sentence = defaultdict(list)
    for inp_example in input_examples:
        label2sentence[inp_example.label].append(inp_example)

    # With a single label (or no data) a negative can never be found; bail out
    # instead of spinning forever in the rejection-sampling loop below.
    if len(label2sentence) < 2:
        return triplets

    for anchor in input_examples:
        same_label = label2sentence[anchor.label]

        # We need at least one other example (different guid) with the same
        # label to create a triplet; this also covers single-example labels.
        if all(candidate.guid == anchor.guid for candidate in same_label):
            continue

        # Rejection sampling: redraw until the positive is not the anchor.
        positive = random.choice(same_label)
        while positive.guid == anchor.guid:
            positive = random.choice(same_label)

        # Rejection sampling: redraw until the label differs from the anchor's.
        negative = random.choice(input_examples)
        while negative.label == anchor.label:
            negative = random.choice(input_examples)

        triplets.append(
            InputExample(texts=[anchor.texts[0], positive.texts[0], negative.texts[0]])
        )

    return triplets



# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = 'all-distilroberta-v1'

### Create a torch.DataLoader that passes training batch instances to our model
train_batch_size = 32
output_path = (
    "output/finetune-batch-hard-trec-"
    + model_name
    + "-"
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)
num_epochs = 1

logging.info("Loading TREC dataset")
train_set, dev_set, test_set = trec_dataset()

# We create a special dataset "SentenceLabelDataset" to wrap our train_set
# It will yield batches that contain at least two samples with the same label
train_data_sampler = SentenceLabelDataset(train_set)
# Use the configured batch size so that changing train_batch_size takes effect.
train_dataloader = DataLoader(
    train_data_sampler, batch_size=train_batch_size, drop_last=True
)


# Load pretrained model
logging.info("Load model")
model = SentenceTransformer(model_name)


### Triplet losses ####################
### There are 4 triplet loss variants:
### - BatchHardTripletLoss
### - BatchHardSoftMarginTripletLoss
### - BatchSemiHardTripletLoss
### - BatchAllTripletLoss
#######################################

train_loss = losses.BatchAllTripletLoss(model=model)
#train_loss = losses.BatchHardTripletLoss(model=model)
#train_loss = losses.BatchHardSoftMarginTripletLoss(model=model)
#train_loss = losses.BatchSemiHardTripletLoss(model=model)


logging.info("Read TREC val dataset")
dev_evaluator = TripletEvaluator.from_input_examples(dev_set, name='trec-dev')

logging.info("Performance before fine-tuning:")
dev_evaluator(model)

# Warm up the learning rate over the first 10% of all training steps
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=num_epochs,
    evaluation_steps=1000,
    warmup_steps=warmup_steps,
    output_path=output_path,
)

##############################################################################
#
# Evaluate the fine-tuned model (the in-memory instance that was just trained)
# on the TREC test triplets
#
##############################################################################

logging.info("Evaluating model on test set")
test_evaluator = TripletEvaluator.from_input_examples(test_set, name='trec-test')
model.evaluate(test_evaluator)