|
""" |
|
This example shows how to train a Bi-Encoder for the MS MARCO dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).
|
|
|
The queries and passages are passed independently to the transformer network to produce fixed-size embeddings.
|
These embeddings can then be compared using cosine-similarity to find matching passages for a given query. |
|
|
|
For training, we use MultipleNegativesRankingLoss, where we pass triplets in the format:
|
(query, positive_passage, negative_passage) |
|
|
|
Negative passages are hard negatives that were mined using different dense embedding and lexical search methods.
Each positive and negative passage comes with a score from a Cross-Encoder. This allows denoising, i.e. removing
false negative passages that are actually relevant for the query.
|
|
|
With a distilbert-base-uncased model, this script should achieve about 33.79 MRR@10 on the MS MARCO passage dev set.
|
|
|
Running this script: |
|
python train_bi-encoder-v3.py --model_name distilbert-base-uncased
|
""" |
|
import argparse
import gzip
import json
import logging
import os
import pickle
import random
import tarfile
from datetime import datetime

import tqdm
from torch.utils.data import DataLoader, Dataset

from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, losses, InputExample
|
|
|
|
|
logging.basicConfig(format='%(asctime)s - %(message)s', |
|
datefmt='%Y-%m-%d %H:%M:%S', |
|
level=logging.INFO, |
|
handlers=[LoggingHandler()]) |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--train_batch_size", default=64, type=int) |
|
parser.add_argument("--max_seq_length", default=300, type=int) |
|
parser.add_argument("--model_name", required=True) |
|
parser.add_argument("--max_passages", default=0, type=int) |
|
parser.add_argument("--epochs", default=10, type=int) |
|
parser.add_argument("--pooling", default="mean") |
|
parser.add_argument("--negs_to_use", default=None, help="From which systems should negatives be used? Multiple systems seperated by comma. None = all") |
|
parser.add_argument("--warmup_steps", default=1000, type=int) |
|
parser.add_argument("--lr", default=2e-5, type=float) |
|
parser.add_argument("--num_negs_per_system", default=5, type=int) |
|
parser.add_argument("--use_pre_trained_model", default=False, action="store_true") |
|
parser.add_argument("--use_all_queries", default=False, action="store_true") |
|
parser.add_argument("--ce_score_margin", default=3.0, type=float) |
|
args = parser.parse_args() |
|
|
|
print(args) |
|
|
|
|
|
model_name = args.model_name |
|
|
|
train_batch_size = args.train_batch_size |
|
max_seq_length = args.max_seq_length |
|
ce_score_margin = args.ce_score_margin |
|
num_negs_per_system = args.num_negs_per_system |
|
num_epochs = args.epochs |
|
|
|
|
|
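# Either load an existing SentenceTransformer model, or build a new one from a
# Hugging Face transformer checkpoint plus a pooling layer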
if args.use_pre_trained_model: |
|
logging.info("use pretrained SBERT model") |
|
model = SentenceTransformer(model_name) |
|
model.max_seq_length = max_seq_length |
|
else: |
|
logging.info("Create new SBERT model") |
|
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) |
|
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), args.pooling) |
|
model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) |
|
|
|
model_save_path = 'output/train_bi-encoder-mnrl-{}-margin_{:.1f}-{}'.format(model_name.replace("/", "-"), ce_score_margin, datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) |
|
|
|
|
|
|
|
data_folder = 'msmarco-data' |
|
|
|
|
|
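# corpus: dict mapping passage id (pid) -> passage text.
# Download and extract collection.tar.gz if collection.tsv is not yet present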
corpus = {} |
|
collection_filepath = os.path.join(data_folder, 'collection.tsv') |
|
if not os.path.exists(collection_filepath): |
|
tar_filepath = os.path.join(data_folder, 'collection.tar.gz') |
|
if not os.path.exists(tar_filepath): |
|
logging.info("Download collection.tar.gz") |
|
util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath) |
|
|
|
with tarfile.open(tar_filepath, "r:gz") as tar: |
|
tar.extractall(path=data_folder) |
|
|
|
logging.info("Read corpus: collection.tsv") |
|
with open(collection_filepath, 'r', encoding='utf8') as fIn: |
|
for line in fIn: |
|
pid, passage = line.strip().split("\t") |
|
pid = int(pid) |
|
corpus[pid] = passage |
|
|
|
|
|
|
|
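# queries: dict mapping query id (qid) -> query text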
queries = {} |
|
queries_filepath = os.path.join(data_folder, 'queries.train.tsv') |
|
if not os.path.exists(queries_filepath): |
|
tar_filepath = os.path.join(data_folder, 'queries.tar.gz') |
|
if not os.path.exists(tar_filepath): |
|
logging.info("Download queries.tar.gz") |
|
util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath) |
|
|
|
with tarfile.open(tar_filepath, "r:gz") as tar: |
|
tar.extractall(path=data_folder) |
|
|
|
|
|
with open(queries_filepath, 'r', encoding='utf8') as fIn: |
|
for line in fIn: |
|
qid, query = line.strip().split("\t") |
|
qid = int(qid) |
|
queries[qid] = query |
|
|
|
|
|
|
|
|
|
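# ce_scores: nested dict mapping qid -> pid -> Cross-Encoder
# (ms-marco-MiniLM-L-6-v2) relevance score; used below to denoise the negatives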
ce_scores_file = os.path.join(data_folder, 'cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz') |
|
if not os.path.exists(ce_scores_file): |
|
logging.info("Download cross-encoder scores file") |
|
util.http_get('https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz', ce_scores_file) |
|
|
|
logging.info("Load CrossEncoder scores dict") |
|
with gzip.open(ce_scores_file, 'rb') as fIn: |
|
ce_scores = pickle.load(fIn) |
|
|
|
|
|
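# Hard negatives file: one JSON object per line of the form
# {"qid": ..., "pos": [pid, ...], "neg": {"system_name": [pid, ...], ...}},
# where each retrieval system (dense or lexical) contributes its own mined negatives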
hard_negatives_filepath = os.path.join(data_folder, 'msmarco-hard-negatives.jsonl.gz') |
|
if not os.path.exists(hard_negatives_filepath): |
|
logging.info("Download cross-encoder scores file") |
|
util.http_get('https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/msmarco-hard-negatives.jsonl.gz', hard_negatives_filepath) |
|
|
|
|
|
logging.info("Read hard negatives train file") |
|
train_queries = {} |
|
negs_to_use = None |
|
with gzip.open(hard_negatives_filepath, 'rt') as fIn: |
|
for line in tqdm.tqdm(fIn): |
|
data = json.loads(line) |
|
|
|
|
|
qid = data['qid'] |
|
pos_pids = data['pos'] |
|
|
|
if len(pos_pids) == 0: |
|
continue |
|
|
|
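        # Denoising: negatives whose Cross-Encoder score lies within
        # ce_score_margin of (or above) the weakest positive are likely
        # false negatives and will be skipped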
        pos_min_ce_score = min(ce_scores[qid][pid] for pid in pos_pids)
|
ce_score_threshold = pos_min_ce_score - ce_score_margin |
|
|
|
|
|
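        # On the first record, fix the set of systems to draw negatives from:
        # either --negs_to_use, or all systems present in the file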
neg_pids = set() |
|
if negs_to_use is None: |
|
if args.negs_to_use is not None: |
|
negs_to_use = args.negs_to_use.split(",") |
|
else: |
|
negs_to_use = list(data['neg'].keys()) |
|
logging.info("Using negatives from the following systems: {}".format(", ".join(negs_to_use))) |
|
|
|
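        # Take up to num_negs_per_system negatives from each selected system,
        # skipping passages above the denoising threshold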
for system_name in negs_to_use: |
|
if system_name not in data['neg']: |
|
continue |
|
|
|
system_negs = data['neg'][system_name] |
|
negs_added = 0 |
|
for pid in system_negs: |
|
if ce_scores[qid][pid] > ce_score_threshold: |
|
continue |
|
|
|
if pid not in neg_pids: |
|
neg_pids.add(pid) |
|
negs_added += 1 |
|
if negs_added >= num_negs_per_system: |
|
break |
|
|
|
if args.use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0): |
|
            train_queries[qid] = {'qid': qid, 'query': queries[qid], 'pos': pos_pids, 'neg': neg_pids}
|
|
|
del ce_scores |
|
|
|
logging.info("Train queries: {}".format(len(train_queries))) |
|
|
|
|
|
|
|
|
|
class MSMARCODataset(Dataset): |
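    """
    Yields (query, positive_passage, negative_passage) triplets as
    InputExamples for MultipleNegativesRankingLoss.
    """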
|
def __init__(self, queries, corpus): |
|
self.queries = queries |
|
self.queries_ids = list(queries.keys()) |
|
self.corpus = corpus |
|
|
|
for qid in self.queries: |
|
self.queries[qid]['pos'] = list(self.queries[qid]['pos']) |
|
self.queries[qid]['neg'] = list(self.queries[qid]['neg']) |
|
random.shuffle(self.queries[qid]['neg']) |
|
|
|
def __getitem__(self, item): |
|
query = self.queries[self.queries_ids[item]] |
|
query_text = query['query'] |
|
|
|
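        # Rotate positives and negatives round-robin (pop from the front,
        # append to the back) so that successive epochs pair each query
        # with different passages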
pos_id = query['pos'].pop(0) |
|
pos_text = self.corpus[pos_id] |
|
query['pos'].append(pos_id) |
|
|
|
neg_id = query['neg'].pop(0) |
|
neg_text = self.corpus[neg_id] |
|
query['neg'].append(neg_id) |
|
|
|
return InputExample(texts=[query_text, pos_text, neg_text]) |
|
|
|
def __len__(self): |
|
return len(self.queries) |
|
|
|
|
|
train_dataset = MSMARCODataset(train_queries, corpus=corpus) |
|
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size) |
|
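# MultipleNegativesRankingLoss: besides the provided hard negative, the
# positives and negatives of all other queries in the batch serve as
# additional in-batch negatives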
train_loss = losses.MultipleNegativesRankingLoss(model=model) |
|
|
|
|
|
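# Train with automatic mixed precision; checkpoint_save_steps=len(train_dataloader)
# saves one checkpoint per epoch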
model.fit(train_objectives=[(train_dataloader, train_loss)], |
|
epochs=num_epochs, |
|
warmup_steps=args.warmup_steps, |
|
use_amp=True, |
|
checkpoint_path=model_save_path, |
|
checkpoint_save_steps=len(train_dataloader), |
|
          optimizer_params={'lr': args.lr},
|
) |
|
|
|
|
|
model.save(model_save_path) |
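# Minimal usage sketch for the trained model (the query and passages are
# made-up illustrations, not part of the dataset):
#
#   model = SentenceTransformer(model_save_path)
#   query_emb = model.encode("what is the capital of france")
#   passage_embs = model.encode(["Paris is the capital of France.",
#                                "Berlin is the capital of Germany."])
#   print(util.cos_sim(query_emb, passage_embs))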
|
|