# examples/unsupervised_learning/CT_In-Batch_Negatives/train_stsb_ct-improved.py
import csv
import gzip
import logging
import os
from datetime import datetime

import torch
from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer, LoggingHandler, models, util, InputExample
from sentence_transformers import losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout
## Training parameters
model_name = 'distilbert-base-uncased'
batch_size = 128
epochs = 1
max_seq_length = 75

# Path where the trained model will be stored
model_save_path = 'output/training_stsb_ct-improved-{}-{}'.format(model_name, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
################# Train sentences #################
# We use 1 million sentences from Wikipedia to train our model
wikipedia_dataset_path = 'data/wiki1m_for_simcse.txt'
if not os.path.exists(wikipedia_dataset_path):
    util.http_get('https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt', wikipedia_dataset_path)

# Each training example pairs a sentence with itself; the loss treats these
# identical pairs as positives and the other sentences in the batch as negatives
train_sentences = []
with open(wikipedia_dataset_path, 'r', encoding='utf8') as fIn:
    for line in fIn:
        line = line.strip()
        train_sentences.append(InputExample(texts=[line, line]))
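# Optional preprocessing (an assumption, not part of the original recipe):
# very short lines carry little signal for contrastive training, so a common
# tweak is to skip them. Hypothetical 10-character threshold shown here:
# train_sentences = [ex for ex in train_sentences if len(ex.texts[0]) >= 10]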
################# Download and load STSb #################
data_folder = 'data/stsbenchmark'
sts_dataset_path = f'{data_folder}/stsbenchmark.tsv.gz'
if not os.path.exists(sts_dataset_path):
    util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)
dev_samples = []
test_samples = []
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1
        inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)

        if row['split'] == 'dev':
            dev_samples.append(inp_example)
        elif row['split'] == 'test':
            test_samples.append(inp_example)

# The evaluators report the Spearman correlation between the cosine similarity
# of the sentence embeddings and the gold STSb similarity scores
dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
################# Initialize an SBERT model #################
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)

# Mean pooling (the default) over the token embeddings yields one fixed-size sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
# Unlike the plain CT script, the in-batch negatives variant works with a standard
# PyTorch DataLoader: each example is an identical sentence pair, and the remaining
# sentences in the batch serve as negatives
train_dataloader = DataLoader(train_sentences, batch_size=batch_size, shuffle=True, drop_last=True)

# As loss, we use losses.ContrastiveTensionLossInBatchNegatives
train_loss = losses.ContrastiveTensionLossInBatchNegatives(model, scale=1, similarity_fct=util.dot_score)
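# Sketch of what this loss computes (an illustration, simplified from the actual
# implementation): two copies of the model encode the two sides of each pair; the
# scaled pairwise similarities form a (batch x batch) logits matrix, and
# cross-entropy pushes the diagonal (identical pairs) above all other entries:
#
#   scores = similarity_fct(embeddings1, embeddings2) * scale  # (batch, batch)
#   labels = torch.arange(batch_size)                          # positives on the diagonal
#   loss = torch.nn.functional.cross_entropy(scores, labels)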
# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=dev_evaluator,
          epochs=epochs,
          evaluation_steps=1000,
          warmup_steps=1000,
          output_path=model_save_path,
          optimizer_params={'lr': 5e-5},
          use_amp=True  # Set to True if your GPU supports FP16 operations
          )
########### Load the model and evaluate on the test set ###########
model = SentenceTransformer(model_save_path)
test_evaluator(model)
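# A quick sanity check (illustrative sketch; the sentences below are made up):
# encode a few sentences with the trained model and compare them pairwise
sentences = ['A man is eating food.',
             'A man is eating a piece of bread.',
             'The girl is carrying a baby.']
embeddings = model.encode(sentences, convert_to_tensor=True)
cosine_scores = util.cos_sim(embeddings, embeddings)
print(cosine_scores)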