import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from utils.config import config
from utils.data_loader import get_data_loaders
from models.encoder import Encoder
from models.decoder import Decoder
from models.seq2seq import Seq2Seq

def init_weights(m):
    # Initialise all weights from N(0, 0.01) and all biases to zero
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

def train():
    train_loader, val_loader, eng_vocab, hin_vocab = get_data_loaders()

    print(f"Final English vocab size: {len(eng_vocab)}")
    print(f"Final Hindi vocab size: {len(hin_vocab)}")
    
    # Model initialization
    enc = Encoder(
        len(eng_vocab), 
        config.embedding_dim, 
        config.hidden_size, 
        config.num_layers, 
        config.dropout
    ).to(config.device)
    
    dec = Decoder(
        len(hin_vocab),
        config.embedding_dim,
        config.hidden_size,
        config.num_layers,
        config.dropout
    ).to(config.device)
    
    model = Seq2Seq(enc, dec, config.device).to(config.device)
    model.apply(init_weights)
    
    # Optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding (assumes <pad> has index 0)
    
    # Training loop
    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0
        
        for src, trg in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            src, trg = src.to(config.device), trg.to(config.device)
            
            optimizer.zero_grad()
            # Forward pass with teacher forcing; output is [batch, trg_len, vocab_size]
            output = model(src, trg, config.teacher_forcing_ratio)
            
            # Drop the first target position (the <sos> token) and flatten
            # so CrossEntropyLoss sees [batch * (trg_len - 1), vocab_size]
            output_dim = output.shape[-1]
            output = output[:, 1:].reshape(-1, output_dim)
            trg = trg[:, 1:].reshape(-1)
            
            loss = criterion(output, trg)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)  # Guard against exploding gradients
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch: {epoch+1}, Loss: {avg_loss:.4f}")
        
        # Save a checkpoint after every epoch
        torch.save(model.state_dict(), f"seq2seq_epoch_{epoch+1}.pth")

if __name__ == "__main__":
    train()