Luna OpenLabs commited on
Commit
5ed5e0b
·
verified ·
1 Parent(s): 5a1602b

Create training/train.py

Browse files
Files changed (1) hide show
  1. training/train.py +56 -0
training/train.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # training/train.py
2
+ import pandas as pd
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from transformers import BertTokenizer
5
+ from model.luna_model import LunaAI
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers import AdamW
9
+
10
+ class TextDataset(Dataset):
11
+ def __init__(self, csv_file):
12
+ self.data = pd.read_csv(csv_file)
13
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
14
+
15
+ def __len__(self):
16
+ return len(self.data)
17
+
18
+ def __getitem__(self, idx):
19
+ text = self.data.iloc[idx, 0]
20
+ label = self.data.iloc[idx, 1]
21
+ encoding = self.tokenizer.encode_plus(
22
+ text,
23
+ add_special_tokens=True,
24
+ return_tensors='pt',
25
+ padding='max_length',
26
+ max_length=128,
27
+ truncation=True,
28
+ )
29
+ return {
30
+ 'input_ids': encoding['input_ids'].flatten(),
31
+ 'attention_mask': encoding['attention_mask'].flatten(),
32
+ 'labels': torch.tensor(label, dtype=torch.long),
33
+ }
34
+
35
+ def train_model(model, dataset):
36
+ dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
37
+ optimizer = AdamW(model.parameters(), lr=5e-5)
38
+
39
+ model.train()
40
+ for epoch in range(3): # Adjust the number of epochs
41
+ for batch in dataloader:
42
+ input_ids = batch['input_ids']
43
+ attention_mask = batch['attention_mask']
44
+ labels = batch['labels']
45
+
46
+ optimizer.zero_grad()
47
+ outputs = model(input_ids, attention_mask)
48
+ loss = nn.CrossEntropyLoss()(outputs, labels)
49
+ loss.backward()
50
+ optimizer.step()
51
+ print(f'Epoch {epoch}, Loss: {loss.item()}')
52
+
53
+ if __name__ == "__main__":
54
+ dataset = TextDataset('data/dataset.csv')
55
+ model = LunaAI()
56
+ train_model(model, dataset)