|  | """ | 
					
						
This example shows how to train a Cross-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).
					
						
						|  |  | 
					
						
The query and the passage are passed simultaneously to a Transformer network. The network then returns
a score between 0 and 1 indicating how relevant the passage is for a given query.
					
						
						|  |  | 
					
						
						|  | The resulting Cross-Encoder can then be used for passage re-ranking: You retrieve for example 100 passages | 
					
						
						|  | for a given query, for example with ElasticSearch, and pass the query+retrieved_passage to the CrossEncoder | 
					
						
						|  | for scoring. You sort the results then according to the output of the CrossEncoder. | 
					
						
						|  |  | 
					
						
						|  | This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking. | 
					
						
						|  |  | 
					
						
						|  | Running this script: | 
					
						
						|  | python train_cross-encoder.py | 
					
						
						|  | """ | 
					
						
						|  | from torch.utils.data import DataLoader | 
					
						
						|  | from sentence_transformers import LoggingHandler, util | 
					
						
						|  | from sentence_transformers.cross_encoder import CrossEncoder | 
					
						
						|  | from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator | 
					
						
						|  | from sentence_transformers import InputExample | 
					
						
						|  | import logging | 
					
						
						|  | from datetime import datetime | 
					
						
						|  | import gzip | 
					
						
						|  | import os | 
					
						
						|  | import tarfile | 
					
						
						|  | import tqdm | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | logging.basicConfig(format='%(asctime)s - %(message)s', | 
					
						
						|  | datefmt='%Y-%m-%d %H:%M:%S', | 
					
						
						|  | level=logging.INFO, | 
					
						
						|  | handlers=[LoggingHandler()]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model_name = 'distilroberta-base' | 
					
						
						|  | train_batch_size = 32 | 
					
						
						|  | num_epochs = 1 | 
					
						
						|  | model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Ratio of negative (non-relevant) to positive (relevant) passages when
# building training pairs: one positive for every `pos_neg_ration` negatives.
# NOTE(review): the name has a typo ("ration" -> "ratio"); kept as-is because
# later code in this file references it under this name.
pos_neg_ration = 4

# Hard cap on the number of training pairs read from the triples file.
# Float literal (2e7 == 20,000,000); it is only compared against an int
# counter below, where the float/int comparison is exact.
max_train_samples = 2e7
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Cross-encoder: query and passage are fed jointly to the transformer, which
# emits a single relevance score (num_labels=1); max_length=512 caps the
# tokenized input length.
model = CrossEncoder(model_name, num_labels=1, max_length=512)

# All MS MARCO files are downloaded into / read from this folder.
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | corpus = {} | 
					
						
						|  | collection_filepath = os.path.join(data_folder, 'collection.tsv') | 
					
						
						|  | if not os.path.exists(collection_filepath): | 
					
						
						|  | tar_filepath = os.path.join(data_folder, 'collection.tar.gz') | 
					
						
						|  | if not os.path.exists(tar_filepath): | 
					
						
						|  | logging.info("Download collection.tar.gz") | 
					
						
						|  | util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath) | 
					
						
						|  |  | 
					
						
						|  | with tarfile.open(tar_filepath, "r:gz") as tar: | 
					
						
						|  | tar.extractall(path=data_folder) | 
					
						
						|  |  | 
					
						
						|  | with open(collection_filepath, 'r', encoding='utf8') as fIn: | 
					
						
						|  | for line in fIn: | 
					
						
						|  | pid, passage = line.strip().split("\t") | 
					
						
						|  | corpus[pid] = passage | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | queries = {} | 
					
						
						|  | queries_filepath = os.path.join(data_folder, 'queries.train.tsv') | 
					
						
						|  | if not os.path.exists(queries_filepath): | 
					
						
						|  | tar_filepath = os.path.join(data_folder, 'queries.tar.gz') | 
					
						
						|  | if not os.path.exists(tar_filepath): | 
					
						
						|  | logging.info("Download queries.tar.gz") | 
					
						
						|  | util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath) | 
					
						
						|  |  | 
					
						
						|  | with tarfile.open(tar_filepath, "r:gz") as tar: | 
					
						
						|  | tar.extractall(path=data_folder) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with open(queries_filepath, 'r', encoding='utf8') as fIn: | 
					
						
						|  | for line in fIn: | 
					
						
						|  | qid, query = line.strip().split("\t") | 
					
						
						|  | queries[qid] = query | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | train_samples = [] | 
					
						
						|  | dev_samples = {} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | num_dev_queries = 200 | 
					
						
						|  | num_max_dev_negatives = 200 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | train_eval_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz') | 
					
						
						|  | if not os.path.exists(train_eval_filepath): | 
					
						
						|  | logging.info("Download "+os.path.basename(train_eval_filepath)) | 
					
						
						|  | util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz', train_eval_filepath) | 
					
						
						|  |  | 
					
						
						|  | with gzip.open(train_eval_filepath, 'rt') as fIn: | 
					
						
						|  | for line in fIn: | 
					
						
						|  | qid, pos_id, neg_id = line.strip().split() | 
					
						
						|  |  | 
					
						
						|  | if qid not in dev_samples and len(dev_samples) < num_dev_queries: | 
					
						
						|  | dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()} | 
					
						
						|  |  | 
					
						
						|  | if qid in dev_samples: | 
					
						
						|  | dev_samples[qid]['positive'].add(corpus[pos_id]) | 
					
						
						|  |  | 
					
						
						|  | if len(dev_samples[qid]['negative']) < num_max_dev_negatives: | 
					
						
						|  | dev_samples[qid]['negative'].add(corpus[neg_id]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | train_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train.tsv.gz') | 
					
						
						|  | if not os.path.exists(train_filepath): | 
					
						
						|  | logging.info("Download "+os.path.basename(train_filepath)) | 
					
						
						|  | util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train.tsv.gz', train_filepath) | 
					
						
						|  |  | 
					
						
						|  | cnt = 0 | 
					
						
						|  | with gzip.open(train_filepath, 'rt') as fIn: | 
					
						
						|  | for line in tqdm.tqdm(fIn, unit_scale=True): | 
					
						
						|  | qid, pos_id, neg_id = line.strip().split() | 
					
						
						|  |  | 
					
						
						|  | if qid in dev_samples: | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | query = queries[qid] | 
					
						
						|  | if (cnt % (pos_neg_ration+1)) == 0: | 
					
						
						|  | passage = corpus[pos_id] | 
					
						
						|  | label = 1 | 
					
						
						|  | else: | 
					
						
						|  | passage = corpus[neg_id] | 
					
						
						|  | label = 0 | 
					
						
						|  |  | 
					
						
						|  | train_samples.append(InputExample(texts=[query, passage], label=label)) | 
					
						
						|  | cnt += 1 | 
					
						
						|  |  | 
					
						
						|  | if cnt >= max_train_samples: | 
					
						
						|  | break | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Shuffling DataLoader over the InputExample training pairs.
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)

# Evaluates re-ranking quality on the held-out dev queries during training
# (see sentence_transformers' CERerankingEvaluator for the reported metric).
evaluator = CERerankingEvaluator(dev_samples, name='train-eval')

# Number of learning-rate warmup steps passed to model.fit below.
warmup_steps = 5000
logging.info("Warmup-steps: {}".format(warmup_steps))
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Train the cross-encoder. The evaluator runs every 10000 training steps and
# checkpoints are written to model_save_path; use_amp enables mixed-precision.
model.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=10000,
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          use_amp=True)

# Save the final model state separately from the checkpoints above.
model.save(model_save_path+'-latest')