File size: 3,030 Bytes
2359bda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
"""
This script loads sentences from a provided text file (plain text, or gzip-compressed if the name ends in .gz). The file is expected to contain one sentence per line.
SimCSE is then trained on these sentences. Checkpoints are stored every 500 steps to the output folder.
Usage:
python train_simcse_from_file.py path/to/sentences.txt [optional_run_name]
"""
from torch.utils.data import DataLoader
import math
from sentence_transformers import models, losses
from sentence_transformers import LoggingHandler, SentenceTransformer, InputExample
import logging
from datetime import datetime
import gzip
import sys
import tqdm
#### Configure logging so progress information is printed while the script runs
# Records at INFO level and above are emitted with a timestamp prefix and are
# handled by the LoggingHandler imported from sentence_transformers.
log_handler = LoggingHandler()
logging.basicConfig(
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(message)s',
    handlers=[log_handler],
)
#### /configure logging
# Training parameters
model_name = 'distilroberta-base'  # passed to models.Transformer below
train_batch_size = 128             # batch size of the training DataLoader
max_seq_length = 32                # forwarded to models.Transformer as max_seq_length
num_epochs = 1

# Input file path (a text file, each line a sentence)
if len(sys.argv) < 2:
    print("Run this script with: python {} path/to/sentences.txt".format(sys.argv[0]))
    # Use sys.exit() instead of the site-provided exit(): the latter is only
    # intended for the interactive shell and may be missing (e.g. python -S).
    # A non-zero status tells calling shells/schedulers that nothing was trained.
    sys.exit(1)
filepath = sys.argv[1]

# Save path to store our model.
# An optional second command-line argument is embedded in the folder name;
# spaces and path separators are replaced so it cannot create sub-folders.
output_name = ''
if len(sys.argv) >= 3:
    output_name = "-" + sys.argv[2].replace(" ", "_").replace("/", "_").replace("\\", "_")

# The timestamp suffix keeps separate runs from overwriting each other.
model_output_path = 'output/train_simcse{}-{}'.format(output_name, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
# Token-level encoder: any Huggingface/transformers model (like BERT, RoBERTa,
# XLNet, XLM-R) can be used for mapping tokens to embeddings
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)

# Apply mean pooling to get one fixed sized sentence vector; the pooling layer
# is sized from the encoder's word embedding dimension
embedding_dim = word_embedding_model.get_word_embedding_dimension()
pooling_model = models.Pooling(embedding_dim)

# Chain encoder and pooling into the sentence embedding model that is trained below
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
################# Read the train corpus #################
# Files ending in .gz are read through gzip; everything else is opened as
# plain text. Both openers are used in text mode with UTF-8 decoding.
open_corpus = gzip.open if filepath.endswith('.gz') else open
with open_corpus(filepath, 'rt', encoding='utf8') as corpus_file:
    # Strip surrounding whitespace lazily while tqdm reports reading progress
    sentences = (raw_line.strip() for raw_line in tqdm.tqdm(corpus_file, desc='Read file'))
    # Lines shorter than 10 characters (after stripping) are skipped. Every
    # kept sentence is paired with itself as the two texts of one InputExample.
    train_samples = [
        InputExample(texts=[sentence, sentence])
        for sentence in sentences
        if len(sentence) >= 10
    ]

logging.info("Train sentences: {}".format(len(train_samples)))
# The DataLoader below drops incomplete batches (drop_last=True), so with fewer
# samples than one batch there would be nothing to train on. Stop with a clear
# message instead of failing obscurely or finishing after zero training steps.
if len(train_samples) < train_batch_size:
    sys.exit("Not enough training sentences: found {}, but at least {} (one full batch) are required.".format(
        len(train_samples), train_batch_size))

# We train our model using the MultipleNegativesRankingLoss
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size, drop_last=True)
train_loss = losses.MultipleNegativesRankingLoss(model)

# 10% of all training steps (batches per epoch * epochs) are used for warm-up
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)
logging.info("Warmup-steps: {}".format(warmup_steps))

# Train the model; checkpoints are written below model_output_path
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params={'lr': 5e-5},
    checkpoint_path=model_output_path,
    show_progress_bar=True,
    use_amp=False,  # Set to True, if your GPU supports FP16 cores
)
|