waqas56jb committed
Commit 45b4e04 · verified · 1 Parent(s): d31883f

Create train.py

Files changed (1): train.py +78 -0
train.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+ import pickle
+ from transformers import MarianMTModel, MarianTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
+ from datasets import load_dataset
+ from torch.utils.data import Dataset
+
+ # Load dataset (limit to 100 samples)
+ dataset = load_dataset("Helsinki-NLP/tatoeba_mt", "ara-eng")
+ train_data = dataset["test"].select(range(100))  # use only the first 100 samples
+ val_data = dataset["validation"].select(range(100))
+
+ # Load tokenizer and model
+ model_name = "Helsinki-NLP/opus-mt-ar-en"
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
+ model = MarianMTModel.from_pretrained(model_name)
+
+ # Custom Dataset class
+ class TranslationDataset(Dataset):
+     def __init__(self, data, tokenizer, max_length=128):
+         self.data = data
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         src_text = self.data[idx]["sourceString"]
+         tgt_text = self.data[idx]["targetString"]
+         # Tokenize source and target together; text_target routes the target text
+         # through the target-side tokenizer and returns it as "labels"
+         encoded = self.tokenizer(
+             src_text,
+             text_target=tgt_text,
+             truncation=True,
+             max_length=self.max_length,
+             padding="max_length",
+             return_tensors="pt",
+         )
+         labels = encoded["labels"].squeeze(0)
+         # Replace padding token ids in the labels with -100 so padding is ignored by the loss
+         labels[labels == self.tokenizer.pad_token_id] = -100
+         return {
+             "input_ids": encoded["input_ids"].squeeze(0),
+             "attention_mask": encoded["attention_mask"].squeeze(0),
+             "labels": labels,
+         }
+
+ # Create dataset instances
+ train_dataset = TranslationDataset(train_data, tokenizer)
+ val_dataset = TranslationDataset(val_data, tokenizer)
+
+ # Data collator
+ data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
+
+ # Training arguments (reduced epochs & batch size)
+ training_args = Seq2SeqTrainingArguments(
+     output_dir="./results",
+     evaluation_strategy="epoch",
+     save_strategy="epoch",
+     per_device_train_batch_size=8,  # reduced batch size
+     per_device_eval_batch_size=8,
+     learning_rate=5e-5,
+     weight_decay=0.01,
+     num_train_epochs=2,  # reduced epochs
+     logging_dir="./logs",
+     logging_steps=5,  # log more frequently
+     save_total_limit=1,
+     predict_with_generate=True,
+ )
+
+ # Trainer setup
+ trainer = Seq2SeqTrainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+     eval_dataset=val_dataset,
+     tokenizer=tokenizer,
+     data_collator=data_collator,
+ )
+
+ # Train model
+ trainer.train()
+
+ # Save model (pickled as-is; the weights could also be saved with model.save_pretrained)
+ with open("nmt_model.pkl", "wb") as f:
+     pickle.dump(model, f)
+
+ print("Model training complete and saved as nmt_model.pkl")
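
For reference, a minimal sketch of how the saved nmt_model.pkl could be loaded back for inference; the checkpoint name matches train.py, while the Arabic example sentence is an illustrative assumption:

import pickle
from transformers import MarianTokenizer

# Reload the matching tokenizer from the Hub and the pickled model from disk
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")  # same checkpoint as train.py
with open("nmt_model.pkl", "rb") as f:
    model = pickle.load(f)

# Translate one Arabic sentence ("Hello, world") to English
inputs = tokenizer("مرحبا بالعالم", return_tensors="pt")
outputs = model.generate(**inputs, max_length=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))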