Redmind committed on
Commit 3b692d3 · verified · 1 Parent(s): 5506cd2

Create app.py

Files changed (1)
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
import os
import pandas as pd
from torch.utils.data import Dataset
from transformers import (
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
    Trainer,
    TrainingArguments,
)
from huggingface_hub import HfFolder  # HfFolder is provided by huggingface_hub, not transformers

# Save the Hugging Face token (if not already saved)
token = os.getenv("HF_TOKEN")
if token:
    HfFolder.save_token(token)
    print("Token saved successfully!")
else:
    print("HF_TOKEN environment variable not set. Ensure your token is saved for authentication.")

# Step 1: Define Dataset Class
class HindiDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        """
        Dataset class for Hindi translation tasks.

        Args:
            data_path (str): Path to the dataset file (e.g., TSV with source-target pairs).
            tokenizer (MBart50TokenizerFast): Tokenizer for mBART.
            max_length (int): Maximum sequence length for tokenization.
        """
        self.data = pd.read_csv(data_path, sep="\t")
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        source = self.data.iloc[idx]["source"]
        target = self.data.iloc[idx]["target"]

        # Tokenize the source text; text_target= tokenizes the target with the
        # tokenizer's target-language settings (tgt_lang is set in Step 2).
        encodings = self.tokenizer(
            source,
            text_target=target,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        labels = encodings["labels"].squeeze()
        # Mask padding positions so they are ignored by the loss
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "input_ids": encodings["input_ids"].squeeze(),
            "attention_mask": encodings["attention_mask"].squeeze(),
            "labels": labels,
        }

# Step 2: Load Tokenizer and Dataset
data_path = "hindi_dataset.tsv"  # Path to your dataset file
# facebook/mbart-large-50 ships with the mBART-50 tokenizer; src_lang/tgt_lang
# assume English-to-Hindi pairs, so adjust src_lang to match your data.
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="hi_IN")
train_dataset = HindiDataset(data_path, tokenizer)

# Step 3: Load Pre-trained mBART Model
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

# Step 4: Define Training Arguments
training_args = TrainingArguments(
    output_dir="./mbart-hindi",        # Output directory for model checkpoints
    per_device_train_batch_size=4,     # Training batch size per GPU
    per_device_eval_batch_size=4,      # Evaluation batch size per GPU
    evaluation_strategy="no",          # No eval_dataset is passed to the Trainer, so skip evaluation
    save_steps=500,                    # Save a checkpoint every 500 steps
    save_total_limit=2,                # Keep only the 2 most recent checkpoints
    logging_dir="./logs",              # Directory for training logs
    num_train_epochs=3,                # Number of training epochs
    learning_rate=5e-5,                # Learning rate
    weight_decay=0.01,                 # Weight decay for the optimizer
    report_to="none",                  # Disable third-party logging
)

# Step 5: Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,  # Passing the tokenizer lets save_model() store it with the checkpoint
)

# Step 6: Train the Model
print("Starting training...")
trainer.train()

# Step 7: Save the Fine-Tuned Model
output_dir = "./mbart-hindi-model"
print(f"Saving fine-tuned model to {output_dir}...")
trainer.save_model(output_dir)

# Step 8: Test the Fine-Tuned Model
print("Testing the fine-tuned model...")
model = MBartForConditionalGeneration.from_pretrained(output_dir)
tokenizer = MBart50TokenizerFast.from_pretrained(output_dir)

test_text = "Translate this to Hindi."
tokenizer.src_lang = "en_XX"  # The test sentence is English
inputs = tokenizer(test_text, return_tensors="pt")
# Force the decoder to start with the Hindi language token so output is generated in Hindi
outputs = model.generate(**inputs, forced_bos_token_id=tokenizer.convert_tokens_to_ids("hi_IN"))
print("Generated Translation:", tokenizer.decode(outputs[0], skip_special_tokens=True))