File size: 2,402 Bytes
9a41f63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from utils.config import config
from utils.preprocessing import clean_text, clean_hindi
from utils.data_loader import TranslationDataset
from models.encoder import Encoder
from models.decoder import Decoder
from models.seq2seq import Seq2Seq
import pickle

def translate_sentence(sentence, model, eng_vocab, hin_vocab, device):
    """Greedily translate one English sentence using the seq2seq model.

    The sentence is passed through ``clean_text``, split on whitespace and
    mapped to ids via ``eng_vocab`` (words not in the vocabulary use the
    ``'<unk>'`` id). The encoder is run once; the decoder is then called one
    token at a time, feeding back its own argmax prediction, until it emits
    ``'<end>'`` or ``config.max_length`` decoding steps have been made.

    Args:
        sentence: Raw input text.
        model: Object exposing ``eval()``, ``encoder(src)`` returning
            ``(encoder_outputs, hidden)`` and
            ``decoder(token, hidden, encoder_outputs)`` returning
            ``(output, hidden)``.
        eng_vocab: Mapping of source word -> id; must contain ``'<unk>'``.
        hin_vocab: Mapping of target token -> id; must contain ``'<start>'``
            and ``'<end>'``.
        device: Device the input tensors are moved to.

    Returns:
        The predicted target tokens joined by single spaces, without the
        ``'<start>'``/``'<end>'`` markers. Returns ``''`` when the cleaned
        sentence contains no tokens.

    Raises:
        KeyError: If a required special token is missing from a vocabulary.
    """
    model.eval()

    unk_id = eng_vocab['<unk>']
    src_ids = [eng_vocab.get(word, unk_id) for word in clean_text(sentence).split()]
    if not src_ids:
        # Nothing to encode (blank input, or everything was cleaned away);
        # avoid feeding a zero-length sequence to the encoder.
        return ''

    end_id = hin_vocab['<end>']
    predicted_ids = []

    with torch.no_grad():
        # Add a leading batch dimension of size 1.
        src_tensor = torch.LongTensor(src_ids).unsqueeze(0).to(device)
        encoder_outputs, hidden = model.encoder(src_tensor)

        prev_id = hin_vocab['<start>']
        for _ in range(config.max_length):
            trg_tensor = torch.LongTensor([prev_id]).to(device)
            output, hidden = model.decoder(trg_tensor, hidden, encoder_outputs)

            prev_id = output.argmax(1).item()
            if prev_id == end_id:
                break
            # Only real tokens are collected, so nothing needs to be sliced
            # off afterwards -- in particular the last token is kept when
            # decoding stops because max_length was reached without <end>.
            predicted_ids.append(prev_id)

    # Build the id -> token table once instead of scanning the vocabulary for
    # every output token. setdefault keeps the first token seen for an id,
    # matching a first-match linear search over the vocabulary.
    id_to_token = {}
    for token, idx in hin_vocab.items():
        id_to_token.setdefault(idx, token)

    # Ids that have no vocabulary entry are rendered as '<unk>' rather than
    # aborting the whole translation.
    return ' '.join(id_to_token.get(idx, '<unk>') for idx in predicted_ids)

def _load_pickle(path):
    """Unpickle and return the object stored at *path*."""
    # NOTE: unpickling can execute arbitrary code. Only load vocabulary files
    # that were produced by a trusted source.
    with open(path, 'rb') as f:
        return pickle.load(f)


def main():
    """Load the vocabularies and model weights, then translate interactively.

    Reads ``bin/eng_vocab.pkl`` and ``bin/hin_vocab.pkl``, builds the
    encoder/decoder from the sizes of those vocabularies and the values in
    ``config``, loads the weights from ``bin/seq2seq.pth`` and then prompts
    for sentences until the user types ``exit`` or closes/interrupts the
    input (Ctrl-D / Ctrl-C).
    """
    # Load vocabularies
    eng_vocab = _load_pickle('bin/eng_vocab.pkl')
    hin_vocab = _load_pickle('bin/hin_vocab.pkl')

    # Load model
    enc = Encoder(
        len(eng_vocab),
        config.embedding_dim,
        config.hidden_size,
        config.num_layers,
        config.dropout
    ).to(config.device)

    dec = Decoder(
        len(hin_vocab),
        config.embedding_dim,
        config.hidden_size,
        config.num_layers,
        config.dropout
    ).to(config.device)

    model = Seq2Seq(enc, dec, config.device).to(config.device)
    # map_location lets weights saved on another device load onto config.device.
    model.load_state_dict(torch.load("bin/seq2seq.pth", map_location=config.device))

    # Interactive translation
    while True:
        try:
            sentence = input("Enter English sentence (type 'exit' to quit): ")
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D, exhausted pipe) or Ctrl-C: leave the loop
            # quietly instead of dying with a traceback.
            print()
            break

        command = sentence.strip()
        if command.lower() == 'exit':
            break
        if not command:
            # Nothing to translate; ask again.
            continue

        translation = translate_sentence(sentence, model, eng_vocab, hin_vocab, config.device)
        print(f"Hindi Translation: {translation}\n")

# Start the interactive translator only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()