File size: 6,996 Bytes
2359bda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
"""
This example shows how to train a Cross-Encoder on the MS MARCO dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).
In this example we use a knowledge distillation setup: in https://arxiv.org/abs/2010.02666, Sebastian Hofstätter et al. trained
an ensemble of large Transformer models for MS MARCO that combines the scores of a BERT-base, a BERT-large, and an ALBERT-large model.
We use the logit scores of this ensemble as training targets for a smaller (student) model. We found that the MiniLM model gives the best
performance while offering the highest speed.
The resulting Cross-Encoder can then be used for passage re-ranking: You retrieve for example 100 passages
for a given query, for example with ElasticSearch, and pass the query+retrieved_passage to the CrossEncoder
for scoring. You sort the results then according to the output of the CrossEncoder.
This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking.
Running this script:
python train_cross-encoder-v2.py
"""
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers import InputExample
import logging
from datetime import datetime
import gzip
import os
import tarfile
import tqdm
import torch
#### Just some code to print debug information to stdout
# Configure the root logger at INFO level with a "timestamp - message" format,
# sending records through sentence-transformers' LoggingHandler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
#### /print debug information to stdout
# First, we define the transformer model we want to fine-tune (the student)
model_name = 'microsoft/MiniLM-L12-H384-uncased'
train_batch_size = 32
num_epochs = 1

# Output directory: the model name with "/" replaced by "-", followed by a
# timestamp so that repeated runs write to distinct directories.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = f"output/training_ms-marco_cross-encoder-v2-{model_name.replace('/', '-')}-{run_timestamp}"

# num_labels=1 together with an Identity activation makes the CrossEncoder emit
# the raw logit, which is what the teacher scores are regressed against later.
model = CrossEncoder(
    model_name,
    num_labels=1,
    max_length=512,
    default_activation_function=torch.nn.Identity(),
)
### Now we read the MS Marco dataset
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)

#### Read the corpus file that contains all the passages. Store them in the corpus dict
# corpus maps passage id (kept as a string, as read from the file) -> passage text
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)
    with tarfile.open(tar_filepath, "r:gz") as tar:
        # The archive is downloaded from the network: when this Python provides
        # tar extraction filters, use 'data' to reject absolute paths, '..'
        # components and links pointing outside data_folder.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=data_folder, filter='data')
        else:
            tar.extractall(path=data_folder)

with open(collection_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        # Each line is "<pid>\t<passage>"; maxsplit=1 keeps any tab inside the passage text
        pid, passage = line.strip().split("\t", 1)
        corpus[pid] = passage
### Read the train queries, store in queries dict
# queries maps query id (kept as a string, as read from the file) -> query text
queries = {}
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)
    with tarfile.open(tar_filepath, "r:gz") as tar:
        # The archive is downloaded from the network: when this Python provides
        # tar extraction filters, use 'data' to reject absolute paths, '..'
        # components and links pointing outside data_folder.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=data_folder, filter='data')
        else:
            tar.extractall(path=data_folder)

with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        # Each line is "<qid>\t<query>"; maxsplit=1 keeps any tab inside the query text
        qid, query = line.strip().split("\t", 1)
        queries[qid] = query
### Now we create our dev data
# dev_samples maps qid -> {'query': str, 'positive': set of passages, 'negative': set of passages}
dev_samples = {}

# We use 200 random queries from the train set for evaluation during training
# Each query has at least one relevant and up to 200 irrelevant (negative) passages
num_dev_queries = 200
num_max_dev_negatives = 200

# msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz and msmarco-qidpidtriples.rnd-shuf.train.tsv.gz are randomly
# shuffled versions of qidpidtriples.train.full.2.tsv.gz from the MS Marco website.
# The train-eval split contains 500 random queries that we extracted for evaluation during training.
train_eval_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz')
if not os.path.exists(train_eval_filepath):
    logging.info("Download "+os.path.basename(train_eval_filepath))
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz', train_eval_filepath)

with gzip.open(train_eval_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        # Each line is a (qid, positive pid, negative pid) triple
        qid, pos_id, neg_id = line.strip().split()
        # Start tracking a new query only while fewer than num_dev_queries are collected
        if qid not in dev_samples and len(dev_samples) < num_dev_queries:
            dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}
        if qid in dev_samples:
            dev_samples[qid]['positive'].add(corpus[pos_id])
            # Cap the number of negatives per dev query
            if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
                dev_samples[qid]['negative'].add(corpus[neg_id])

# Query ids reserved for evaluation; the training loop below skips them
dev_qids = set(dev_samples.keys())
# Read our training file
# As input examples, we provide the (query, passage) pair together with the logit score from the teacher ensemble
teacher_logits_filepath = os.path.join(data_folder, 'bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
train_samples = []
if not os.path.exists(teacher_logits_filepath):
    logging.info("Download "+os.path.basename(teacher_logits_filepath))
    util.http_get('https://zenodo.org/record/4068216/files/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv?download=1', teacher_logits_filepath)

with open(teacher_logits_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        # Tab-separated: teacher score for pid1, teacher score for pid2, qid, pid1, pid2
        pos_score, neg_score, qid, pid1, pid2 = line.strip().split("\t")
        if qid in dev_qids:  # Skip queries in our dev dataset
            continue
        query_text = queries[qid]
        # Two regression examples per line; the label is the teacher's raw score
        train_samples.append(InputExample(texts=[query_text, corpus[pid1]], label=float(pos_score)))
        train_samples.append(InputExample(texts=[query_text, corpus[pid2]], label=float(neg_score)))
# We create a DataLoader to load our train samples.
# drop_last=True discards a final incomplete batch, so every step has train_batch_size examples.
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size, drop_last=True)

# We add an evaluator, which evaluates the performance during training.
# CERerankingEvaluator scores every (query, passage) pair of each dev query's
# positive and negative passages, ranks the passages by score and reports MRR@10.
evaluator = CERerankingEvaluator(dev_samples, name='train-eval')
# Configure the training
warmup_steps = 5000
logging.info("Warmup-steps: %s", warmup_steps)

# Train the model: MSELoss regresses the student's raw logit onto the teacher's
# score (the InputExample labels), evaluating on the dev set every 5000 steps.
# NOTE(review): per the CrossEncoder.fit docs, checkpoints are written to
# output_path when the evaluator score improves — confirm for the installed version.
model.fit(train_dataloader=train_dataloader,
          loss_fct=torch.nn.MSELoss(),
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=5000,
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          optimizer_params={'lr': 7e-6},
          use_amp=True)

# Save the model state after the last training step, separately from output_path
model.save(model_save_path+'-latest')