import os
import pandas as pd
from torch.utils.data import Dataset
from transformers import (
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
    Trainer,
    TrainingArguments,
)
from huggingface_hub import HfFolder
# Save the Hugging Face token (if not already saved)
token = os.getenv("HF_TOKEN")
if token:
    HfFolder.save_token(token)
    print("Token saved successfully!")
else:
    print("HF_TOKEN environment variable not set. Ensure your token is saved for authentication.")
# Step 1: Define Dataset Class
class HindiDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        """
        Dataset class for Hindi translation tasks.

        Args:
            data_path (str): Path to the dataset file (a TSV with "source" and "target" columns).
            tokenizer (MBart50TokenizerFast): Tokenizer for mBART-50.
            max_length (int): Maximum sequence length for tokenization.
        """
        self.data = pd.read_csv(data_path, sep="\t")
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        source = self.data.iloc[idx]["source"]
        target = self.data.iloc[idx]["target"]
        source_encodings = self.tokenizer(
            source, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt"
        )
        # Tokenize the target in target-language mode so the correct language code is prepended.
        target_encodings = self.tokenizer(
            text_target=target, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt"
        )
        labels = target_encodings["input_ids"].squeeze()
        # Replace padding token ids with -100 so padded positions are ignored by the loss.
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": source_encodings["input_ids"].squeeze(),
            "attention_mask": source_encodings["attention_mask"].squeeze(),
            "labels": labels,
        }
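
# The dataset file is assumed to be tab-separated with a header row and
# "source"/"target" columns, e.g. (illustrative rows, not real data):
#
#   source<TAB>target
#   How are you?<TAB>आप कैसे हैं?
#   Thank you very much.<TAB>बहुत बहुत धन्यवाद।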
# Step 2: Load Tokenizer and Dataset
data_path = "hindi_dataset.tsv"  # Path to your dataset file
# MBart50TokenizerFast matches the facebook/mbart-large-50 checkpoint.
# The language codes assume an English -> Hindi dataset; adjust src_lang/tgt_lang to your data.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50", src_lang="en_XX", tgt_lang="hi_IN"
)
train_dataset = HindiDataset(data_path, tokenizer)
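
# Optional sanity check: inspect one tokenized example to confirm the expected
# tensor shapes before launching a full training run.
sample = train_dataset[0]
print("Sample input_ids shape:", sample["input_ids"].shape)
print("Sample labels shape:", sample["labels"].shape)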
# Step 3: Load Pre-trained mBART Model
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
# Step 4: Define Training Arguments
training_args = TrainingArguments(
    output_dir="./mbart-hindi",        # Output directory for model checkpoints
    per_device_train_batch_size=4,     # Training batch size per GPU
    per_device_eval_batch_size=4,      # Evaluation batch size per GPU
    evaluation_strategy="no",          # No eval_dataset is passed to the Trainer, so evaluation is disabled
    save_steps=500,                    # Save a checkpoint every 500 steps
    save_total_limit=2,                # Keep only the 2 most recent checkpoints
    logging_dir="./logs",              # Directory for training logs
    num_train_epochs=3,                # Number of training epochs
    learning_rate=5e-5,                # Learning rate
    weight_decay=0.01,                 # Weight decay for the optimizer
    report_to="none",                  # Disable third-party logging (e.g., W&B)
)
# Step 5: Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
# Step 6: Train the Model
print("Starting training...")
trainer.train()
# Step 7: Save the Fine-Tuned Model
output_dir = "./mbart-hindi-model"
print(f"Saving fine-tuned model to {output_dir}...")
trainer.save_model(output_dir)
# Step 8: Test the Fine-Tuned Model
print("Testing the fine-tuned model...")
model = MBartForConditionalGeneration.from_pretrained(output_dir)
tokenizer = MBart50TokenizerFast.from_pretrained(output_dir, src_lang="en_XX")
test_text = "Translate this to Hindi."
inputs = tokenizer(test_text, return_tensors="pt")
# Force Hindi as the decoder's target language when generating.
outputs = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"])
print("Generated Translation:", tokenizer.decode(outputs[0], skip_special_tokens=True))