# Fine-tune facebook/mbart-large-50 on a Hindi translation dataset (TSV of source/target pairs).
import os

from huggingface_hub import HfFolder
import pandas as pd
from torch.utils.data import Dataset
from transformers import (
    MBartForConditionalGeneration,
    MBartTokenizer,
    Trainer,
    TrainingArguments,
)
# Save the Hugging Face token so Hub downloads/uploads are authenticated.
# The token is read from the environment rather than hard-coded.
token = os.getenv("HF_TOKEN")
if token:
    HfFolder.save_token(token)
    print("Token saved successfully!")
else:
    print("HF_TOKEN environment variable not set. Ensure your token is saved for authentication.")
# Step 1: Define Dataset Class
class HindiDataset(Dataset):
    """Torch dataset yielding tokenized source/target pairs for seq2seq fine-tuning."""

    def __init__(self, data_path, tokenizer, max_length=512):
        """
        Dataset class for Hindi translation tasks.

        Args:
            data_path (str): Path to the dataset file (TSV with 'source' and 'target' columns).
            tokenizer: Tokenizer for mBART (any HF-style callable tokenizer with pad_token_id).
            max_length (int): Maximum sequence length for tokenization.
        """
        self.data = pd.read_csv(data_path, sep="\t")
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        source = self.data.iloc[idx]["source"]
        target = self.data.iloc[idx]["target"]
        source_encodings = self.tokenizer(
            source, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt"
        )
        target_encodings = self.tokenizer(
            target, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt"
        )
        labels = target_encodings["input_ids"].squeeze()
        # Padding positions in the labels must be -100 so the cross-entropy loss
        # ignores them; otherwise the model is trained to predict pad tokens.
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)
        return {
            "input_ids": source_encodings["input_ids"].squeeze(),
            "attention_mask": source_encodings["attention_mask"].squeeze(),
            "labels": labels,
        }
# Step 2: Load Tokenizer and Dataset
data_path = "hindi_dataset.tsv"  # Path to your dataset file
# NOTE(review): "facebook/mbart-large-50" is normally paired with MBart50Tokenizer;
# MBartTokenizer targets mbart-large-cc25 — confirm the intended checkpoint/tokenizer pairing.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-50")
train_dataset = HindiDataset(data_path, tokenizer)

# Step 3: Load Pre-trained mBART Model
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
# Step 4: Define Training Arguments
training_args = TrainingArguments(
    output_dir="./mbart-hindi",          # Output directory for model checkpoints
    per_device_train_batch_size=4,       # Training batch size per GPU
    per_device_eval_batch_size=4,        # Evaluation batch size per GPU
    # No eval_dataset is passed to the Trainer below, so evaluation must stay
    # disabled; evaluation_strategy="steps" would crash at the first eval step.
    evaluation_strategy="no",
    save_steps=500,                      # Save model every 500 steps
    save_total_limit=2,                  # Keep only 2 checkpoints
    logging_dir="./logs",                # Directory for training logs
    num_train_epochs=3,                  # Number of training epochs
    learning_rate=5e-5,                  # Learning rate
    weight_decay=0.01,                   # Weight decay for optimizer
    report_to="none",                    # Disable third-party logging
)
# Step 5: Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,  # also makes Trainer save the tokenizer with checkpoints
)

# Step 6: Train the Model
print("Starting training...")
trainer.train()
# Step 7: Save the Fine-Tuned Model
output_dir = "./mbart-hindi-model"
print(f"Saving fine-tuned model to {output_dir}...")
trainer.save_model(output_dir)

# Step 8: Test the Fine-Tuned Model
print("Testing the fine-tuned model...")
model = MBartForConditionalGeneration.from_pretrained(output_dir)
tokenizer = MBartTokenizer.from_pretrained(output_dir)
test_text = "Translate this to Hindi."
tokenizer.src_lang = "en_XX"  # input sentence is English
inputs = tokenizer(test_text, return_tensors="pt")
# mBART must be told which language to decode into via the forced BOS token;
# without it, generation does not target Hindi.
outputs = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"])
print("Generated Translation:", tokenizer.decode(outputs[0], skip_special_tokens=True))