Update graph_decoder/transformer.py
graph_decoder/transformer.py (CHANGED)

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from .layers import Attention, MLP
 from .conditions import TimestepEmbedder, ConditionEmbedder
-from .diffusion_utils import PlaceHolder
+# from .diffusion_utils import PlaceHolder
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -98,7 +98,8 @@ class Transformer(nn.Module):
 
         # X: B * N * dx, E: B * N * N * de
         X, E = self.output_layer(X, X_in, E_in, c, t, node_mask)
-        return PlaceHolder(X=X, E=E, y=None).mask(node_mask)
+        # return PlaceHolder(X=X, E=E, y=None).mask(node_mask)
+        return X, E
 
 class Block(nn.Module):
     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):