|
""" |
|
This examples show how to train a Cross-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking). |
|
|
|
In this example we use a knowledge distillation setup. Sebastian Hofstätter et al. trained in https://arxiv.org/abs/2010.02666 |
|
an ensemble of large Transformer models for the MS MARCO datasets and combines the scores from a BERT-base, BERT-large, and ALBERT-large model. |
|
|
|
We use the logits scores from the ensemble to train a smaller model. We found that the MiniLM model gives the best performance while |
|
offering the highest speed. |
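
The distillation objective below is a plain mean squared error between the student's raw logit
and the teacher ensemble score. A minimal sketch of the idea (the scores are made-up values):

    import torch
    teacher_score = torch.tensor([8.3])  # ensemble logit for a (query, passage) pair
    student_score = torch.tensor([7.1])  # raw logit of the small Cross-Encoder
    loss = torch.nn.MSELoss()(student_score, teacher_score)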
|
|
|
The resulting Cross-Encoder can then be used for passage re-ranking: you retrieve, for example, 100 passages
for a given query (e.g. with ElasticSearch) and pass each (query, retrieved_passage) pair to the CrossEncoder
for scoring. You then sort the results according to the CrossEncoder scores.

This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking.
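
A minimal sketch of that re-ranking step (the model path, query, and passages are placeholders;
CrossEncoder.predict scores a list of sentence pairs):

    from sentence_transformers.cross_encoder import CrossEncoder
    model = CrossEncoder('output/training_ms-marco_cross-encoder-v2-...')
    query = "how many people live in berlin"
    passages = ["...", "..."]  # e.g. the top-100 BM25 hits from ElasticSearch
    scores = model.predict([(query, passage) for passage in passages])
    reranked = [passage for _, passage in sorted(zip(scores, passages), reverse=True)]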
|
|
|
Running this script:
python train_cross-encoder-v2.py
"""
|
import gzip
import logging
import os
import tarfile
from datetime import datetime

import torch
import tqdm
from torch.utils.data import DataLoader

from sentence_transformers import InputExample, LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
|
#### Print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
|
# First, we define the transformer model we want to fine-tune
model_name = 'microsoft/MiniLM-L12-H384-uncased'
train_batch_size = 32
num_epochs = 1
model_save_path = 'output/training_ms-marco_cross-encoder-v2-' + model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
# We set num_labels=1 and the activation function to Identity, so that the model
# outputs raw logits that can be regressed against the teacher scores
model = CrossEncoder(model_name, num_labels=1, max_length=512, default_activation_function=torch.nn.Identity())
|
### Now we read the MS MARCO dataset
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)
|
#### Read the corpus file that contains all the passages. Store them in the corpus dict
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(collection_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        corpus[pid] = passage
|
### Read the train queries, store them in the queries dict
queries = {}
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[qid] = query
|
### Now we create our train & dev data
train_samples = []
dev_samples = {}

# We use 200 random queries from the train set for evaluation during training.
# Each dev query keeps its positive passages and up to 200 negative passages
num_dev_queries = 200
num_max_dev_negatives = 200
|
# The qidpidtriples file contains (query_id, positive_passage_id, negative_passage_id) triples.
# We use the first num_dev_queries distinct queries from it as our dev set
train_eval_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz')
if not os.path.exists(train_eval_filepath):
    logging.info("Download " + os.path.basename(train_eval_filepath))
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz', train_eval_filepath)

with gzip.open(train_eval_filepath, 'rt') as fIn:
    for line in fIn:
        qid, pos_id, neg_id = line.strip().split()

        if qid not in dev_samples and len(dev_samples) < num_dev_queries:
            dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}

        if qid in dev_samples:
            dev_samples[qid]['positive'].add(corpus[pos_id])

            if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
                dev_samples[qid]['negative'].add(corpus[neg_id])

dev_qids = set(dev_samples.keys())
|
# Read the teacher scores from the ensemble of Hofstätter et al. Each line contains the
# teacher scores for the positive and the negative passage of one training triple
teacher_logits_filepath = os.path.join(data_folder, 'bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
if not os.path.exists(teacher_logits_filepath):
    logging.info("Download " + os.path.basename(teacher_logits_filepath))
    util.http_get('https://zenodo.org/record/4068216/files/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv?download=1', teacher_logits_filepath)

with open(teacher_logits_filepath) as fIn:
    for line in fIn:
        pos_score, neg_score, qid, pid1, pid2 = line.strip().split("\t")

        # Skip queries that are part of our dev set
        if qid in dev_qids:
            continue

        train_samples.append(InputExample(texts=[queries[qid], corpus[pid1]], label=float(pos_score)))
        train_samples.append(InputExample(texts=[queries[qid], corpus[pid2]], label=float(neg_score)))
|
# We create a DataLoader to load our train samples
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size, drop_last=True)
|
# We add an evaluator that re-ranks the dev queries and reports MRR during training
evaluator = CERerankingEvaluator(dev_samples, name='train-eval')
|
warmup_steps = 5000
logging.info("Warmup-steps: {}".format(warmup_steps))
|
# Train the model. With loss_fct=MSELoss, the raw student logits are regressed
# against the teacher ensemble scores
model.fit(train_dataloader=train_dataloader,
          loss_fct=torch.nn.MSELoss(),
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=5000,
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          optimizer_params={'lr': 7e-6},
          use_amp=True)
|
|
|
|
|
# Save the latest model state after training
model.save(model_save_path + '-latest')