liuganghuggingface commited on
Commit
8c463a9
·
verified ·
1 Parent(s): e5bad6e

Update graph_decoder/transformer.py

Browse files
Files changed (1) hide show
  1. graph_decoder/transformer.py +3 -2
graph_decoder/transformer.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from .layers import Attention, MLP
4
  from .conditions import TimestepEmbedder, ConditionEmbedder
5
- from .diffusion_utils import PlaceHolder
6
 
7
  def modulate(x, shift, scale):
8
  return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -98,7 +98,8 @@ class Transformer(nn.Module):
98
 
99
  # X: B * N * dx, E: B * N * N * de
100
  X, E = self.output_layer(X, X_in, E_in, c, t, node_mask)
101
- return PlaceHolder(X=X, E=E, y=None).mask(node_mask)
 
102
 
103
  class Block(nn.Module):
104
  def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
 
2
  import torch.nn as nn
3
  from .layers import Attention, MLP
4
  from .conditions import TimestepEmbedder, ConditionEmbedder
5
+ # from .diffusion_utils import PlaceHolder
6
 
7
  def modulate(x, shift, scale):
8
  return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
 
98
 
99
  # X: B * N * dx, E: B * N * N * de
100
  X, E = self.output_layer(X, X_in, E_in, c, t, node_mask)
101
+ # return PlaceHolder(X=X, E=E, y=None).mask(node_mask)
102
+ return X, E
103
 
104
  class Block(nn.Module):
105
  def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):