# Source: Hugging Face Hub upload by Lim0011 ("Upload 251 files", commit 85e3d20, verified).
# Original file size: 4.65 kB.
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.loader import NeighborLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn import MessagePassing, SAGEConv
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
import pandas as pd
# Load the OGB node-property-prediction dataset (cached under ./networks).
target_dataset = 'ogbn-arxiv'
dataset = PygNodePropPredDataset(name=target_dataset, root='networks')
data = dataset[0]  # graph object used throughout the script

# Official OGB node splits.
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = (split_idx[key] for key in ('train', 'valid', 'test'))

# Shuffled mini-batches of training seed nodes, each expanded by two hops of
# up to 30 sampled neighbours per hop.
# NOTE(review): if the model consuming these batches ignores graph structure,
# neighbour sampling here only adds overhead — consider reducing num_neighbors.
train_loader = NeighborLoader(
    data,
    input_nodes=train_idx,
    num_neighbors=[30, 30],
    batch_size=1024,
    shuffle=True,
    num_workers=1,
)

# Ordered pass over every node (input_nodes=None) with its full one-hop
# neighbourhood; used for inference/evaluation.
total_loader = NeighborLoader(
    data,
    input_nodes=None,
    num_neighbors=[-1],
    batch_size=4096,
    shuffle=False,
    num_workers=1,
)
class MLP(torch.nn.Module):
    """Multi-layer perceptron node classifier that uses node features only.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout. The last
    Linear layer is followed by log-softmax, so outputs are log-probabilities
    suitable for ``F.nll_loss``.

    Args:
        in_channels: Size of each input feature vector.
        hidden_channels: Width of every hidden layer.
        out_channels: Number of output classes.
        num_layers: Total number of Linear layers. Values below 2 still build
            one hidden layer plus the output layer.
        dropout: Dropout probability applied after every hidden layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()
        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every Linear and BatchNorm layer."""
        for lin in self.lins:
            lin.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x):
        """Map node features ``x`` to per-class log-probabilities."""
        for i, lin in enumerate(self.lins[:-1]):
            x = lin(x)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.log_softmax(x, dim=-1)

    @torch.no_grad()
    def inference(self, total_loader, device):
        """Run the model over every batch of ``total_loader`` and gather outputs.

        Only the first ``batch.batch_size`` rows of each batch (the seed nodes)
        are kept. Gradients are disabled so no autograd graph is built or kept
        alive across the whole dataset. The caller is responsible for putting
        the model in eval mode.

        Args:
            total_loader: Iterable of batches exposing ``x`` and ``batch_size``.
            device: Device on which the forward pass runs.

        Returns:
            CPU tensor of log-probabilities, one row per seed node, in loader order.
        """
        xs = []
        for batch in total_loader:
            out = self.forward(batch.x.to(device))
            out = out[:batch.batch_size]
            xs.append(out.cpu())
        out_all = torch.cat(xs, dim=0)
        return out_all
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Feature-only MLP baseline. The output width is read from the dataset instead
# of being hard-coded, so the log-softmax covers exactly the valid labels.
model = MLP(
    data.x.size(-1),
    hidden_channels=16,
    out_channels=dataset.num_classes,
    num_layers=2,
    dropout=0,
).to(device)

epochs = 4
# Adam with lr=1 takes steps far too large to converge; 0.01 is a conventional start.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Lowers the learning rate when the monitored metric (validation accuracy, hence
# mode='max') stops improving for `patience` epochs; it must be stepped per epoch.
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=7)
@torch.no_grad()
def test(model, device):
    """Evaluate ``model`` on every node and report OGB accuracy per split.

    Uses the module-level ``target_dataset``, ``total_loader``, ``data`` and
    ``split_idx``. Leaves the model in eval mode.

    Args:
        model: Network exposing ``inference(loader, device)``.
        device: Device on which inference runs.

    Returns:
        ``(train_acc, val_acc, test_acc)`` as computed by the OGB ``Evaluator``.
    """
    evaluator = Evaluator(name=target_dataset)
    model.eval()
    # The no_grad decorator keeps evaluation from building an autograd graph.
    out = model.inference(total_loader, device)
    y_true = data.y.cpu()
    y_pred = out.argmax(dim=-1, keepdim=True)

    def split_acc(split):
        """Accuracy restricted to the nodes of one named split."""
        idx = split_idx[split]
        return evaluator.eval({
            'y_true': y_true[idx],
            'y_pred': y_pred[idx],
        })['acc']

    return split_acc('train'), split_acc('valid'), split_acc('test')
# Train for `epochs` epochs. After each one, evaluate on all splits and let the
# plateau scheduler react to validation accuracy.
for epoch in range(epochs):
    model.train()
    pbar = tqdm(total=train_idx.size(0))
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = total_correct = 0
    for batch in train_loader:
        batch_size = batch.batch_size
        optimizer.zero_grad()
        # Only the first `batch_size` rows are the seed (training) nodes; the
        # remaining rows are sampled neighbours and are not supervised.
        out = model(batch.x.to(device))[:batch_size]
        batch_y = torch.reshape(batch.y[:batch_size].to(device), (-1,))
        loss = F.nll_loss(out, batch_y)
        loss.backward()
        optimizer.step()

        total_loss += float(loss)
        total_correct += int(out.argmax(dim=-1).eq(batch_y).sum())
        pbar.update(batch_size)
    pbar.close()

    loss = total_loss / len(train_loader)
    approx_acc = total_correct / train_idx.size(0)
    train_acc, val_acc, test_acc = test(model, device)
    print(f'Loss: {loss:.4f}, Approx. train: {approx_acc:.4f}, '
          f'Train: {train_acc:.4f}, Val: {val_acc:.4f}')
    # Without this call the ReduceLROnPlateau scheduler never changes the LR.
    scheduler.step(val_acc)
# Predict a class for every node and write the test-split predictions to
# submission.csv (a single column, no index).
model.eval()
with torch.no_grad():  # inference only; avoid building an autograd graph
    out = model.inference(total_loader, device)
y_pred = out.argmax(dim=-1, keepdim=True)
y_pred_np = y_pred[split_idx['test']].numpy()
df = pd.DataFrame(y_pred_np)
df.to_csv("submission.csv", index=False)