from cortex_DIM.nn_modules.mi_networks import MIFCNet, MI1x1ConvNet
from torch import optim
from torch.autograd import Variable
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class GlobalDiscriminator(nn.Module):
    """Scores (global summary y, graph M) pairs for the MI estimator.

    Note: the original file referenced ``self.encoder`` without ever
    assigning it, and its hard-coded layer widths (32 in, 512 out) did not
    line up. As an assumed repair, the encoder is taken as a constructor
    argument and the widths are derived from ``input_dim``.
    """

    def __init__(self, args, input_dim, encoder):
        super().__init__()
        self.encoder = encoder  # graph encoder; was used but never defined
        # Concatenating y with the encoded M doubles the feature dimension,
        # assuming both carry input_dim features.
        self.l0 = nn.Linear(2 * input_dim, input_dim)
        self.l1 = nn.Linear(input_dim, input_dim)
        self.l2 = nn.Linear(input_dim, 1)  # scalar discriminator score

    def forward(self, y, M, data):
        # Variable is deprecated since PyTorch 0.4; plain tensors suffice.
        adj = data['adj'].float().cuda()
        batch_num_nodes = data['num_nodes'].int().numpy()
        M, _ = self.encoder(M, adj, batch_num_nodes)
        h = torch.cat((y, M), dim=1)  # pair the global summary with the encoding
        h = F.relu(self.l0(h))
        h = F.relu(self.l1(h))
        return self.l2(h)

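# Usage sketch (an assumption, not from the original file): with the
# Jensen-Shannon MI estimator used in Deep InfoMax, matched (summary, graph)
# pairs should score high and shuffled negatives low, e.g.
#
#   pos = global_d(y, M, data)              # matched pairs
#   neg = global_d(y, M_neg, data_neg)      # shuffled negative pairs
#   mi_loss = F.softplus(-pos).mean() + F.softplus(neg).mean()
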
class PriorDiscriminator(nn.Module):
    """MLP that outputs the probability that its input was drawn from the
    prior rather than produced by the encoder (prior-matching term)."""

    def __init__(self, input_dim):
        super().__init__()
        self.l0 = nn.Linear(input_dim, input_dim)
        self.l1 = nn.Linear(input_dim, input_dim)
        self.l2 = nn.Linear(input_dim, 1)

    def forward(self, x):
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return torch.sigmoid(self.l2(h))

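# Usage sketch (assumed, following the Deep InfoMax prior-matching term):
# encodings y are pushed toward a uniform prior adversarially; gamma is an
# assumed loss weight.
#
#   prior = torch.rand_like(y)                   # samples from U(0, 1)
#   term_a = torch.log(prior_d(prior)).mean()    # prior scored as "real"
#   term_b = torch.log(1.0 - prior_d(y)).mean()  # encodings scored as "fake"
#   prior_loss = -(term_a + term_b) * gamma
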
class FF(nn.Module):
    """Feed-forward projection head: a three-layer ReLU MLP combined with a
    linear shortcut, i.e. a residual projection of the input features."""

    def __init__(self, input_dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
        )
        self.linear_shortcut = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        # Residual combination: non-linear block plus a linear skip path.
        return self.block(x) + self.linear_shortcut(x)

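# Minimal smoke test (an addition, not in the original file): exercises the
# two self-contained modules on random tensors. GlobalDiscriminator is
# omitted because it needs a graph encoder; the width below is assumed.
if __name__ == "__main__":
    dim = 32
    y = torch.rand(4, dim)       # stand-in batch of graph embeddings
    prior_d = PriorDiscriminator(dim)
    ff = FF(dim)
    print(prior_d(y).shape)      # -> torch.Size([4, 1])
    print(ff(y).shape)           # -> torch.Size([4, 32])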