tomaarsen (HF staff) committed
Commit 975224a (verified) · 1 Parent(s): 585e6d3

Create train_script.py

Files changed (1):
  1. train_script.py +171 -0
train_script.py ADDED
@@ -0,0 +1,171 @@
import logging
import traceback

import torch
from datasets import load_dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderModelCardData
from sentence_transformers.cross_encoder.evaluation import (
    CrossEncoderNanoBEIREvaluator,
    CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
from sentence_transformers.evaluation.SequentialEvaluator import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "prajjwal1/bert-tiny"

    train_batch_size = 2048
    num_epochs = 1
    num_hard_negatives = 5  # How many hard negatives should be mined for each question-answer pair

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name="BERT-tiny trained on GooAQ",
        ),
    )
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2a. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
    logging.info("Read the gooaq training dataset")
    full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 2b. Modify our training dataset to include hard negatives using a very efficient embedding model
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,  # How many negatives per question-answer pair
        margin=0,  # Similarity between query and negative samples should be x lower than query-positive similarity
        range_min=0,  # Skip the x most similar samples
        range_max=100,  # Consider only the x most similar samples
        sampling_strategy="top",  # Sample the top negatives from the range
        batch_size=4096,  # Use a batch size of 4096 for the embedding model
        output_format="labeled-pair",  # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
        use_faiss=True,
    )
    logging.info(hard_train_dataset)

    # 2c. (Optionally) Save the hard training dataset to disk
    # hard_train_dataset.save_to_disk("gooaq-hard-train")
    # Load again with:
    # hard_train_dataset = load_from_disk("gooaq-hard-train")

    # 3. Define our training loss.
    # pos_weight is recommended to be set as the ratio between positives to negatives, a.k.a. `num_hard_negatives`
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    # 4a. Define evaluators. We use the CrossEncoderNanoBEIREvaluator, which is a light-weight evaluator for English reranking
    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )

    # 4b. Define a reranking evaluator by mining hard negatives given query-answer pairs
    # We include the positive answer in the list of negatives, so the evaluator can use the performance of the
    # embedding model as a baseline.
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],  # Use the full dataset as the corpus
        num_negatives=30,  # How many documents to rerank
        batch_size=4096,
        disqualify_positives=False,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=[
            {
                "query": sample["question"],
                "positive": [sample["answer"]],
                "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
            }
            for sample in hard_eval_dataset
        ],
        batch_size=train_batch_size,
        name="gooaq-dev",
    )

    # 4c. Combine the evaluators & run the base model on them
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    evaluator(model)

    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-{short_model_name}-gooaq-bce"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=5e-4,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_R100_mean_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=20,
        save_strategy="steps",
        save_steps=20,
        save_total_limit=2,
        logging_steps=20,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model; useful for including the results in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(f"cross-encoder-testing/{run_name}")
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
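
Note: once training finishes, the saved reranker can be scored with the standard CrossEncoder API. The snippet below is a minimal usage sketch, not part of the committed script: it assumes the final model was written to models/reranker-bert-tiny-gooaq-bce/final (the path produced by step 8 with the defaults above), and the query and documents are made-up examples.

from sentence_transformers.cross_encoder import CrossEncoder

# Load the trained reranker from the local output directory (or from the Hub repo it was pushed to)
reranker = CrossEncoder("models/reranker-bert-tiny-gooaq-bce/final")

# Score (query, passage) pairs directly; higher scores indicate a better match
query = "how many people live in berlin?"  # made-up example query
documents = [
    "Berlin has a population of roughly 3.7 million people.",
    "The Eiffel Tower is located in Paris.",
]
scores = reranker.predict([(query, doc) for doc in documents])
print(scores)

# Or rerank the documents for the query in a single call
print(reranker.rank(query, documents, return_documents=True))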