SentenceTransformer / examples /training /ms_marco /train_cross-encoder_scratch.py
lengocduc195's picture
pushNe
2359bda
"""
This example shows how to train a Cross-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).
The query and the passage are passed simultaneously to a Transformer network. The network then returns
a score between 0 and 1 indicating how relevant the passage is for a given query.
The resulting Cross-Encoder can then be used for passage re-ranking: You retrieve for example 100 passages
for a given query, for example with ElasticSearch, and pass the query+retrieved_passage to the CrossEncoder
for scoring. You sort the results then according to the output of the CrossEncoder.
This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking.
Running this script:
python train_cross-encoder_scratch.py
"""
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers import InputExample
import logging
from datetime import datetime
import gzip
import os
import tarfile
import tqdm
# Configure root logging: INFO level, timestamped messages, with records
# emitted through the LoggingHandler imported from sentence_transformers.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
# Transformer checkpoint that is fine-tuned as a cross-encoder.
model_name = 'distilroberta-base'
train_batch_size = 32
num_epochs = 1

# Output directory: fixed prefix + model name (any "/" replaced by "-") +
# the start timestamp, so separate runs get separate directories.
model_save_path = (
    'output/training_ms-marco_cross-encoder-'
    + model_name.replace("/", "-")
    + '-'
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)

# Training is framed as a binary label task on [query, passage] pairs:
# label 1 = relevant, label 0 = irrelevant.
# Positive-to-negative ratio: for each positive pair (label 1) the training
# set includes 4 negative pairs (label 0). Negatives are taken from the
# (query, positive sample, negative sample) triplets provided by MS Marco.
pos_neg_ration = 4

# Upper bound on the number of training samples that are read.
max_train_samples = 2e7

# num_labels=1 makes the model predict a single continuous score
# (between 0 and 1) per [query, passage] pair.
model = CrossEncoder(model_name, num_labels=1, max_length=512)
### Now we read the MS Marco dataset
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)

#### Read the corpus file, which contains all the passages, and store them
#### in the corpus dict (passage id -> passage text).
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    # collection.tsv is missing: fetch the archive (unless it is already on
    # disk) and unpack it into data_folder.
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        # The archive comes from the network, so member paths must not be
        # allowed to escape data_folder. Extraction filters are available on
        # Python 3.12+ (and patched older releases); the tarfile docs
        # recommend probing for `data_filter` rather than the version.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=data_folder, filter='data')
        else:
            tar.extractall(path=data_folder)

with open(collection_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        # Split on the first tab only, so a tab inside the passage text does
        # not abort the whole load with an unpacking error.
        pid, passage = line.strip().split("\t", 1)
        corpus[pid] = passage
### Read the train queries and store them in the queries dict
### (query id -> query text).
queries = {}
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    # queries.train.tsv is missing: fetch the archive (unless it is already
    # on disk) and unpack it into data_folder.
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        # The archive comes from the network, so member paths must not be
        # allowed to escape data_folder. Extraction filters are available on
        # Python 3.12+ (and patched older releases); the tarfile docs
        # recommend probing for `data_filter` rather than the version.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=data_folder, filter='data')
        else:
            tar.extractall(path=data_folder)

with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        # Split on the first tab only, so a tab inside the query text does
        # not abort the whole load with an unpacking error.
        qid, query = line.strip().split("\t", 1)
        queries[qid] = query
### Containers for the training pairs and the dev (evaluation) queries
train_samples = []
dev_samples = {}

# Evaluation during training uses 200 queries taken from the train-eval file.
# Each of them keeps at least one relevant passage and at most 200
# irrelevant (negative) passages.
num_dev_queries = 200
num_max_dev_negatives = 200

# msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz and
# msmarco-qidpidtriples.rnd-shuf.train.tsv.gz are randomly shuffled versions
# of qidpidtriples.train.full.2.tsv.gz from the MS Marco website. The
# train-eval split holds 500 random queries that can be used for evaluation
# during training.
train_eval_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz')
if not os.path.exists(train_eval_filepath):
    logging.info("Download "+os.path.basename(train_eval_filepath))
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz', train_eval_filepath)

with gzip.open(train_eval_filepath, 'rt') as eval_file:
    for triple_line in eval_file:
        qid, pos_id, neg_id = triple_line.strip().split()

        entry = dev_samples.get(qid)
        if entry is None:
            # Unseen query: admit it only while the dev set still has room,
            # otherwise this triple is ignored.
            if len(dev_samples) >= num_dev_queries:
                continue
            entry = {'query': queries[qid], 'positive': set(), 'negative': set()}
            dev_samples[qid] = entry

        # Positives are always collected; negatives are capped per query.
        entry['positive'].add(corpus[pos_id])
        if len(entry['negative']) < num_max_dev_negatives:
            entry['negative'].add(corpus[neg_id])
# Locate (and download if necessary) the shuffled training triples.
train_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train.tsv.gz')
if not os.path.exists(train_filepath):
    logging.info("Download "+os.path.basename(train_filepath))
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train.tsv.gz', train_filepath)

# cnt counts the samples appended so far; it also decides whether the next
# sample is a positive or a negative pair.
cnt = 0
with gzip.open(train_filepath, 'rt') as triples_file:
    for triple_line in tqdm.tqdm(triples_file, unit_scale=True):
        qid, pos_id, neg_id = triple_line.strip().split()

        # Queries held out for the dev set contribute no training samples.
        if qid in dev_samples:
            continue

        query = queries[qid]

        # Every (pos_neg_ration + 1)-th sample uses the positive passage of
        # its triple (label 1); all others use the negative passage (label 0).
        is_positive = (cnt % (pos_neg_ration + 1)) == 0
        passage = corpus[pos_id] if is_positive else corpus[neg_id]
        label = 1 if is_positive else 0

        train_samples.append(InputExample(texts=[query, passage], label=label))
        cnt += 1

        # Stop reading once the configured sample budget is reached.
        if cnt >= max_train_samples:
            break
# Wrap the training samples in a shuffling DataLoader that yields batches of
# train_batch_size samples.
train_dataloader = DataLoader(
    train_samples,
    shuffle=True,
    batch_size=train_batch_size,
)

# Evaluator that is run on the held-out dev queries during training.
# NOTE(review): the previous comment described a classification task with
# F1 / Average Precision; CERerankingEvaluator is a re-ranking evaluator
# (MRR@10 according to the sentence-transformers docs) - confirm.
evaluator = CERerankingEvaluator(dev_samples, name='train-eval')
# Number of warmup steps handed to the trainer; logged for reference.
warmup_steps = 5000
logging.info("Warmup-steps: {}".format(warmup_steps))

# Fine-tune the cross-encoder. The evaluator is passed along with
# evaluation_steps=10000, and output_path points at model_save_path.
# use_amp=True requests mixed-precision training (see CrossEncoder.fit docs).
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=10000,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=True,
)

# Additionally store the model as it is after the final training step,
# under a separate '-latest' path.
model.save(model_save_path + '-latest')