Update README.md

Haryni committed (verified) · commit 8db5b66 · 1 parent: 8b5dd38

Files changed (1):
  1. README.md +104 -0
README.md CHANGED
@@ -11,3 +11,107 @@ base_model:
 - deepseek-ai/DeepSeek-R1
 pipeline_tag: translation
 ---
+import os
+import argparse
+import pandas as pd
+from datasets import Dataset
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM,
+    Seq2SeqTrainingArguments,
+    Seq2SeqTrainer,
+    DataCollatorForSeq2Seq,
+)
+from utils import compute_metrics
+
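
The script imports `compute_metrics` from a local `utils` module that is not part of this commit. A minimal sketch of what that helper might look like, assuming sacreBLEU via the `evaluate` library (the metric choice and the module contents are assumptions, not shown in the diff):

```python
# utils.py -- hypothetical sketch; the real module is not included in this commit.
import evaluate
import numpy as np
from transformers import AutoTokenizer

bleu = evaluate.load("sacrebleu")
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")

def compute_metrics(eval_preds):
    """Decode generated ids and reference labels, then score with sacreBLEU."""
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Labels may use -100 as padding; swap it for the pad token before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = bleu.compute(
        predictions=decoded_preds,
        references=[[label] for label in decoded_labels],
    )
    return {"bleu": result["score"]}
```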
+def load_dataset(file_path):
+    """Load and prepare the dataset."""
+    df = pd.read_csv(file_path)
+    dataset = Dataset.from_pandas(df)
+    # Split dataset into train and validation
+    split_dataset = dataset.train_test_split(test_size=0.1)
+    return split_dataset
+
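
`load_dataset` reads a flat CSV; given the column names used in `preprocess_function` below, the file needs `english_text` and `malayalam_text` columns. A hypothetical way to produce a toy `dataset/malayalam_dataset.csv` in the expected layout (the rows are illustrative; the real dataset is not in this commit):

```python
import pandas as pd

# Two example sentence pairs; the real corpus would have many more rows.
pd.DataFrame(
    {
        "english_text": ["Good morning", "How are you?"],
        "malayalam_text": ["സുപ്രഭാതം", "സുഖമാണോ?"],
    }
).to_csv("dataset/malayalam_dataset.csv", index=False)
```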
+def preprocess_function(examples, tokenizer, max_length=128):
+    """Tokenize the texts."""
+    inputs = list(examples["english_text"])
+    targets = list(examples["malayalam_text"])
+
+    model_inputs = tokenizer(
+        inputs,
+        max_length=max_length,
+        truncation=True,
+        padding="max_length",
+    )
+
+    with tokenizer.as_target_tokenizer():
+        labels = tokenizer(
+            targets,
+            max_length=max_length,
+            truncation=True,
+            padding="max_length",
+        )
+
+    model_inputs["labels"] = labels["input_ids"]
+    return model_inputs
+
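
`as_target_tokenizer()` still works but has been deprecated in recent `transformers` releases. On versions that support it (roughly 4.22 and later), the targets can be passed directly via `text_target`, which yields the labels in one call; a drop-in alternative, assuming a recent install:

```python
# Equivalent tokenization without the deprecated context manager:
model_inputs = tokenizer(
    inputs,
    text_target=targets,
    max_length=max_length,
    truncation=True,
    padding="max_length",
)
# model_inputs now contains input_ids, attention_mask, and labels.
```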
+def main(args):
+    # Load tokenizer and model
+    model_name = "google/mt5-small"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+    # Load and preprocess dataset
+    dataset = load_dataset("dataset/malayalam_dataset.csv")
+
+    # Tokenize datasets
+    tokenized_datasets = dataset.map(
+        lambda x: preprocess_function(x, tokenizer),
+        batched=True,
+        remove_columns=dataset["train"].column_names,
+    )
+
+    # Define training arguments
+    training_args = Seq2SeqTrainingArguments(
+        output_dir="./model",
+        evaluation_strategy="epoch",
+        learning_rate=args.learning_rate,
+        per_device_train_batch_size=args.batch_size,
+        per_device_eval_batch_size=args.batch_size,
+        num_train_epochs=args.epochs,
+        weight_decay=0.01,
+        save_total_limit=2,
+        predict_with_generate=True,
+        logging_dir="./logs",
+        logging_steps=100,
+        # Uploading checkpoints requires a Hugging Face login or HF_TOKEN.
+        push_to_hub=True,
+    )
+
+    # Create data collator
+    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+
+    # Initialize trainer
+    trainer = Seq2SeqTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_datasets["train"],
+        eval_dataset=tokenized_datasets["test"],
+        data_collator=data_collator,
+        tokenizer=tokenizer,
+        compute_metrics=compute_metrics,
+    )
+
+    # Train the model
+    trainer.train()
+
+    # Save the model
+    trainer.save_model("./model")
+    tokenizer.save_pretrained("./model")
+
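
Once training finishes, the checkpoint saved in `./model` can be loaded for translation. A minimal inference sketch (the input sentence and generation settings are illustrative, not part of the script):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("./model")
model = AutoModelForSeq2SeqLM.from_pretrained("./model")

inputs = tokenizer("How are you?", return_tensors="pt")
# Beam search tends to give steadier short translations; settings are illustrative.
outputs = model.generate(**inputs, max_new_tokens=64, num_beams=4)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```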
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--epochs", type=int, default=3)
+    parser.add_argument("--batch_size", type=int, default=8)
+    parser.add_argument("--learning_rate", type=float, default=2e-5)
+    args = parser.parse_args()
+    main(args)
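
Assuming the script is saved as `train.py` (the commit embeds it in README.md, so the filename is an assumption), a typical run with the exposed hyperparameters looks like:

```bash
python train.py --epochs 3 --batch_size 8 --learning_rate 2e-5
```

Because `push_to_hub=True` is set, the run will try to upload checkpoints, so a Hugging Face login (`huggingface-cli login`) is needed beforehand.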