# train.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

from super_large_language_model import TransformerModel
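
# TransformerModel is assumed to be this project's encoder-decoder model: the
# script treats it as taking batch-first (src, tgt) index tensors and returning
# per-position vocabulary logits.

# Index 0 is reserved for padding so variable-length texts can be batched together.
PAD_IDX = 0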

class TextDataset(Dataset):
    """Character-level dataset: each text becomes a 1-D tensor of vocabulary indices."""

    def __init__(self, texts, vocab):
        self.texts = texts
        self.vocab = vocab

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        text_indices = [self.vocab[char] for char in text]
        return torch.tensor(text_indices, dtype=torch.long)
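
# Texts in a batch have different lengths, so the default collate would fail.
# A minimal padding collate sketch, assuming PAD_IDX above is a safe padding
# value for the model:
def collate_batch(batch):
    # Pad every sequence to the length of the longest in the batch -> shape [B, T].
    return pad_sequence(batch, batch_first=True, padding_value=PAD_IDX)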

def train_model(model, dataset, num_epochs=10, batch_size=32, learning_rate=0.001):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            collate_fn=collate_batch)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        for batch in dataloader:
            optimizer.zero_grad()
            # Teacher forcing: feed positions [0, T-1) and predict positions [1, T).
            src = batch[:, :-1]
            tgt = batch[:, 1:]
            output = model(src, tgt)
            loss = criterion(output.reshape(-1, output.size(-1)), tgt.reshape(-1))
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch + 1}, Loss: {loss.item():.4f}')

if __name__ == "__main__":
    # Example texts and a character vocabulary (indices start at 1; 0 is the padding index).
    texts = ["hello world", "pytorch is great"]
    vocab = {char: idx + 1 for idx, char in enumerate(sorted(set("".join(texts))))}

    dataset = TextDataset(texts, vocab)
    model = TransformerModel(
        vocab_size=len(vocab) + 1,  # +1 for the padding index
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
    )

    train_model(model, dataset)