import os
import argparse
import pandas as pd
from datasets import Dataset
from sacrebleu.metrics import BLEU, CHRF
from peft import LoraConfig, get_peft_model
from IndicTransToolkit import IndicProcessor, IndicDataCollator
from transformers import (
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
AutoModelForSeq2SeqLM,
AutoTokenizer,
EarlyStoppingCallback,
)
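

# corpus-level metric objects, reused across evaluation runs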
bleu_metric = BLEU()
chrf_metric = CHRF()


def get_arg_parse():
    """Build the command-line argument parser for LoRA fine-tuning."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
)
parser.add_argument(
"--src_lang_list",
type=str,
help="comma separated list of source languages",
)
parser.add_argument(
"--tgt_lang_list",
type=str,
help="comma separated list of target languages",
)
parser.add_argument("--data_dir", type=str)
parser.add_argument("--output_dir", type=str)
parser.add_argument("--save_steps", type=int, default=1000)
parser.add_argument("--eval_steps", type=int, default=1000)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument("--max_steps", type=int, default=1000000)
parser.add_argument("--grad_accum_steps", type=int, default=4)
parser.add_argument("--warmup_steps", type=int, default=4000)
parser.add_argument("--warmup_ratio", type=int, default=0.0)
parser.add_argument("--max_grad_norm", type=float, default=1.0)
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--weight_decay", type=float, default=0.0)
parser.add_argument("--adam_beta1", type=float, default=0.9)
parser.add_argument("--adam_beta2", type=float, default=0.98)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--print_samples", action="store_true")
parser.add_argument(
"--optimizer",
type=str,
default="adamw_torch",
choices=[
"adam_hf",
"adamw_torch",
"adamw_torch_fused",
"adamw_apex_fused",
"adafactor",
],
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="inverse_sqrt",
choices=[
"inverse_sqrt",
"linear",
"polynomial",
"cosine",
"constant",
"constant_with_warmup",
],
)
parser.add_argument("--label_smoothing", type=float, default=0.0)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--metric_for_best_model", type=str, default="eval_loss")
parser.add_argument("--greater_is_better", action="store_true")
parser.add_argument("--lora_target_modules", type=str, default="q_proj,k_proj")
parser.add_argument("--lora_dropout", type=float, default=0.1)
parser.add_argument("--lora_r", type=int, default=16)
parser.add_argument("--lora_alpha", type=int, default=32)
parser.add_argument(
"--report_to",
type=str,
default="none",
choices=["wandb", "tensorboard", "azure_ml", "none"],
)
parser.add_argument("--patience", type=int, default=5),
parser.add_argument("--threshold", type=float, default=1e-3)
return parser


def load_and_process_translation_dataset(
data_dir,
split="train",
tokenizer=None,
processor=None,
src_lang_list=None,
tgt_lang_list=None,
num_proc=8,
seed=42
):
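    """Read parallel files for every src->tgt pair, preprocess them with
    IndicProcessor, shuffle, and tokenize into a Dataset."""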
complete_dataset = {
"sentence_SRC": [],
"sentence_TGT": [],
}
for src_lang in src_lang_list:
for tgt_lang in tgt_lang_list:
if src_lang == tgt_lang:
continue
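            # expected layout: <data_dir>/<split>/<src>-<tgt>/<split>.<lang>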
src_path = os.path.join(
data_dir, split, f"{src_lang}-{tgt_lang}", f"{split}.{src_lang}"
)
tgt_path = os.path.join(
data_dir, split, f"{src_lang}-{tgt_lang}", f"{split}.{tgt_lang}"
)
if not os.path.exists(src_path) or not os.path.exists(tgt_path):
raise FileNotFoundError(
f"Source ({split}.{src_lang}) or Target ({split}.{tgt_lang}) file not found in {data_dir}"
)
with open(src_path, encoding="utf-8") as src_file, open(
tgt_path, encoding="utf-8"
) as tgt_file:
src_lines = src_file.readlines()
tgt_lines = tgt_file.readlines()
# Ensure both files have the same number of lines
assert len(src_lines) == len(
tgt_lines
), f"Source and Target files have different number of lines for {split}.{src_lang} and {split}.{tgt_lang}"
complete_dataset["sentence_SRC"] += processor.preprocess_batch(
src_lines, src_lang=src_lang, tgt_lang=tgt_lang, is_target=False
)
complete_dataset["sentence_TGT"] += processor.preprocess_batch(
tgt_lines, src_lang=tgt_lang, tgt_lang=src_lang, is_target=True
)
complete_dataset = Dataset.from_dict(complete_dataset).shuffle(seed=seed)
return complete_dataset.map(
lambda example: preprocess_fn(
example,
tokenizer=tokenizer
),
batched=True,
num_proc=num_proc,
)


def compute_metrics_factory(
tokenizer, metric_dict=None, print_samples=False, n_samples=10
):
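    """Build a compute_metrics closure that decodes predictions and labels,
    then scores them with every metric in metric_dict (e.g. BLEU, chrF)."""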
def compute_metrics(eval_preds):
preds, labels = eval_preds
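        # -100 is the ignore index used for label padding; map it back to the
        # pad token id so batch_decode can handle the sequences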
labels[labels == -100] = tokenizer.pad_token_id
preds[preds == -100] = tokenizer.pad_token_id
with tokenizer.as_target_tokenizer():
preds = [
x.strip()
for x in tokenizer.batch_decode(
preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
]
labels = [
x.strip()
for x in tokenizer.batch_decode(
labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
]
assert len(preds) == len(
labels
), "Predictions and Labels have different lengths"
        if print_samples:
            # sample at most n_samples pairs for display; guards against a
            # ValueError when the eval set is smaller than n_samples
            df = pd.DataFrame({"Predictions": preds, "References": labels}).sample(
                n=min(n_samples, len(preds))
            )
            for pred, label in zip(df["Predictions"].values, df["References"].values):
                print(f" | > Prediction: {pred}")
                print(f" | > Reference: {label}\n")
return {
metric_name: metric.corpus_score(preds, [labels]).score
for (metric_name, metric) in metric_dict.items()
}
return compute_metrics


def preprocess_fn(example, tokenizer, **kwargs):
    """Tokenize source/target sentences; padding is deferred to the collator."""
model_inputs = tokenizer(
example["sentence_SRC"], truncation=True, padding=False, max_length=256
)
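    # switch to the target-side tokenizer context for the labels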
with tokenizer.as_target_tokenizer():
labels = tokenizer(
example["sentence_TGT"], truncation=True, padding=False, max_length=256
)
model_inputs["labels"] = labels["input_ids"]
return model_inputs


def main(args):
print(f" | > Loading {args.model} and tokenizer ...")
model = AutoModelForSeq2SeqLM.from_pretrained(
args.model,
trust_remote_code=True,
attn_implementation="eager",
dropout=args.dropout
)
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
processor = IndicProcessor(inference=False) # pre-process before tokenization
data_collator = IndicDataCollator(
tokenizer=tokenizer,
model=model,
padding="longest", # saves padding tokens
pad_to_multiple_of=8, # better to have it as 8 when using fp16
label_pad_token_id=-100
)
if args.data_dir is not None:
train_dataset = load_and_process_translation_dataset(
args.data_dir,
split="train",
tokenizer=tokenizer,
processor=processor,
src_lang_list=args.src_lang_list.split(","),
tgt_lang_list=args.tgt_lang_list.split(","),
)
print(f" | > Loaded train dataset from {args.data_dir}. Size: {len(train_dataset)} ...")
eval_dataset = load_and_process_translation_dataset(
args.data_dir,
split="dev",
tokenizer=tokenizer,
processor=processor,
src_lang_list=args.src_lang_list.split(","),
tgt_lang_list=args.tgt_lang_list.split(","),
)
print(f" | > Loaded eval dataset from {args.data_dir}. Size: {len(eval_dataset)} ...")
else:
raise ValueError(" | > Data directory not provided")
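
    # attach LoRA adapters: only the adapter weights will be trained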
lora_config = LoraConfig(
r=args.lora_r,
bias="none",
inference_mode=False,
task_type="SEQ_2_SEQ_LM",
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=args.lora_target_modules.split(","),
)
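    # set_label_smoothing is exposed by the IndicTrans2 remote-code model class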
model.set_label_smoothing(args.label_smoothing)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print(f" | > Loading metrics factory with BLEU and chrF ...")
seq2seq_compute_metrics = compute_metrics_factory(
tokenizer=tokenizer,
print_samples=args.print_samples,
metric_dict={"BLEU": bleu_metric, "chrF": chrf_metric},
)
training_args = Seq2SeqTrainingArguments(
output_dir=args.output_dir,
do_train=True,
do_eval=True,
fp16=True, # use fp16 for faster training
logging_strategy="steps",
evaluation_strategy="steps",
save_strategy="steps",
logging_steps=100,
save_total_limit=1,
predict_with_generate=True,
load_best_model_at_end=True,
max_steps=args.max_steps, # max_steps overrides num_train_epochs
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
gradient_accumulation_steps=args.grad_accum_steps,
eval_accumulation_steps=args.grad_accum_steps,
weight_decay=args.weight_decay,
adam_beta1=args.adam_beta1,
adam_beta2=args.adam_beta2,
max_grad_norm=args.max_grad_norm,
optim=args.optimizer,
lr_scheduler_type=args.lr_scheduler,
warmup_ratio=args.warmup_ratio,
warmup_steps=args.warmup_steps,
learning_rate=args.learning_rate,
num_train_epochs=args.num_train_epochs,
save_steps=args.save_steps,
eval_steps=args.eval_steps,
dataloader_num_workers=args.num_workers,
metric_for_best_model=args.metric_for_best_model,
greater_is_better=args.greater_is_better,
report_to=args.report_to,
generation_max_length=256,
generation_num_beams=5,
sortish_sampler=True,
group_by_length=True,
include_tokens_per_second=True,
include_num_input_tokens_seen=True,
dataloader_prefetch_factor=2,
)
# Create Trainer instance
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=seq2seq_compute_metrics,
callbacks=[
EarlyStoppingCallback(
early_stopping_patience=args.patience,
early_stopping_threshold=args.threshold,
)
],
)
print(f" | > Starting training ...")
try:
trainer.train()
except KeyboardInterrupt:
print(f" | > Training interrupted ...")
# this will only save the LoRA adapter weights
model.save_pretrained(args.output_dir)


if __name__ == "__main__":
parser = get_arg_parse()
args = parser.parse_args()
main(args)
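

# Example invocation (illustrative sketch: the script filename, paths, and
# hyperparameter values below are placeholders; language codes follow the
# FLORES-200 convention expected by IndicProcessor):
#
#   python finetune_lora.py \
#       --model ai4bharat/indictrans2-en-indic-dist-200M \
#       --src_lang_list eng_Latn \
#       --tgt_lang_list hin_Deva,tam_Taml \
#       --data_dir ./data \
#       --output_dir ./checkpoints \
#       --batch_size 16 \
#       --print_samples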