|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import yaml |
|
import json |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from . import diffusion_utils as utils |
|
from .molecule_utils import graph_to_smiles, check_valid |
|
from .transformer import Transformer |
|
|
|
class GraphDiT(nn.Module): |
|
def __init__( |
|
self, |
|
model_config_path, |
|
data_info_path, |
|
model_dtype, |
|
): |
|
super().__init__() |
|
|
|
dm_cfg, data_info = utils.load_config(model_config_path, data_info_path) |
|
|
|
input_dims = data_info.input_dims |
|
output_dims = data_info.output_dims |
|
nodes_dist = data_info.nodes_dist |
|
active_index = data_info.active_index |
|
|
|
self.model_config = dm_cfg |
|
self.data_info = data_info |
|
self.T = dm_cfg.diffusion_steps |
|
self.guide_scale = dm_cfg.guide_scale |
|
self.Xdim = input_dims["X"] |
|
self.Edim = input_dims["E"] |
|
self.ydim = input_dims["y"] |
|
self.Xdim_output = output_dims["X"] |
|
self.Edim_output = output_dims["E"] |
|
self.ydim_output = output_dims["y"] |
|
self.node_dist = nodes_dist |
|
self.active_index = active_index |
|
self.max_n_nodes = data_info.max_n_nodes |
|
self.train_loss = TrainLossDiscrete(dm_cfg.lambda_train) |
|
self.atom_decoder = data_info.atom_decoder |
|
self.hidden_size = dm_cfg.hidden_size |
|
self.text_input_size = 768 |
|
|
|
self.denoiser = Transformer( |
|
max_n_nodes=self.max_n_nodes, |
|
hidden_size=dm_cfg.hidden_size, |
|
depth=dm_cfg.depth, |
|
num_heads=dm_cfg.num_heads, |
|
mlp_ratio=dm_cfg.mlp_ratio, |
|
drop_condition=dm_cfg.drop_condition, |
|
Xdim=self.Xdim, |
|
Edim=self.Edim, |
|
ydim=self.ydim, |
|
text_dim=self.text_input_size |
|
) |
|
self.model_dtype = model_dtype |
|
|
|
self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete( |
|
dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps |
|
) |
|
x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum( |
|
data_info.node_types.to(self.model_dtype) |
|
) |
|
e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum( |
|
data_info.edge_types.to(self.model_dtype) |
|
) |
|
x_marginals = x_marginals / x_marginals.sum() |
|
e_marginals = e_marginals / e_marginals.sum() |
|
|
|
xe_conditions = data_info.transition_E.to(self.model_dtype) |
|
xe_conditions = xe_conditions[self.active_index][:, self.active_index] |
|
|
|
xe_conditions = xe_conditions.sum(dim=1) |
|
ex_conditions = xe_conditions.t() |
|
xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True) |
|
ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True) |
|
|
|
self.transition_model = utils.MarginalTransition( |
|
x_marginals=x_marginals, |
|
e_marginals=e_marginals, |
|
xe_conditions=xe_conditions, |
|
ex_conditions=ex_conditions, |
|
y_classes=self.ydim_output, |
|
n_nodes=self.max_n_nodes, |
|
) |
|
self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None) |
|
|
|
def init_model(self, model_dir, verbose=False): |
|
model_file = os.path.join(model_dir, 'model.pt') |
|
if os.path.exists(model_file): |
|
self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True)) |
|
else: |
|
raise FileNotFoundError(f"Model file not found: {model_file}") |
|
|
|
if verbose: |
|
print('GraphDiT Denoiser Model initialized.') |
|
print('Denoiser model:\n', self.denoiser) |
|
|
|
def save_pretrained(self, output_dir): |
|
if not os.path.exists(output_dir): |
|
os.makedirs(output_dir) |
|
|
|
|
|
model_path = os.path.join(output_dir, 'model.pt') |
|
torch.save(self.denoiser.state_dict(), model_path) |
|
|
|
|
|
config_path = os.path.join(output_dir, 'model_config.yaml') |
|
with open(config_path, 'w') as f: |
|
yaml.dump(vars(self.model_config), f) |
|
|
|
|
|
data_info_path = os.path.join(output_dir, 'data.meta.json') |
|
data_info_dict = { |
|
"active_atoms": self.data_info.active_atoms, |
|
"max_node": self.data_info.max_n_nodes, |
|
"n_atoms_per_mol_dist": self.data_info.n_nodes.tolist(), |
|
"bond_type_dist": self.data_info.edge_types.tolist(), |
|
"transition_E": self.data_info.transition_E.tolist(), |
|
"atom_type_dist": self.data_info.node_types.tolist(), |
|
"valencies": self.data_info.valency_distribution.tolist() |
|
} |
|
with open(data_info_path, 'w') as f: |
|
json.dump(data_info_dict, f, indent=2) |
|
|
|
print('GraphDiT Model and configurations saved to:', output_dir) |
|
|
|
def disable_grads(self): |
|
self.denoiser.disable_grads() |
|
|
|
def forward( |
|
self, x, edge_index, edge_attr, graph_batch, properties, text_embedding, no_label_index |
|
): |
|
properties = torch.where(properties == no_label_index, float("nan"), properties) |
|
data_x = F.one_hot(x, num_classes=118).to(self.model_dtype)[ |
|
:, self.active_index |
|
] |
|
data_edge_attr = F.one_hot(edge_attr, num_classes=5).to(self.model_dtype) |
|
|
|
dense_data, node_mask = utils.to_dense( |
|
data_x, edge_index, data_edge_attr, graph_batch, self.max_n_nodes |
|
) |
|
X, E = dense_data.X, dense_data.E |
|
|
|
dense_data = dense_data.mask(node_mask) |
|
noisy_data = self.apply_noise(X, E, properties, node_mask) |
|
pred = self._forward(noisy_data, text_embedding) |
|
loss = self.train_loss( |
|
masked_pred_X=pred.X, |
|
masked_pred_E=pred.E, |
|
true_X=X, |
|
true_E=E, |
|
node_mask=node_mask, |
|
) |
|
return loss |
|
|
|
def _forward(self, noisy_data, text_embedding, unconditioned=False): |
|
noisy_x, noisy_e, properties = ( |
|
noisy_data["X_t"].to(self.model_dtype), |
|
noisy_data["E_t"].to(self.model_dtype), |
|
noisy_data["y_t"].to(self.model_dtype).clone(), |
|
) |
|
node_mask, timestep, text_embedding = ( |
|
noisy_data["node_mask"], |
|
noisy_data["t"], |
|
text_embedding.to(self.model_dtype), |
|
) |
|
|
|
pred = self.denoiser( |
|
noisy_x, |
|
noisy_e, |
|
node_mask, |
|
properties, |
|
text_embedding, |
|
timestep, |
|
unconditioned=unconditioned, |
|
) |
|
return pred |
|
|
|
def apply_noise(self, X, E, y, node_mask): |
|
"""Sample noise and apply it to the data.""" |
|
|
|
|
|
|
|
lowest_t = 0 if self.training else 1 |
|
t_int = torch.randint( |
|
lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device |
|
).to( |
|
self.model_dtype |
|
) |
|
s_int = t_int - 1 |
|
|
|
t_float = t_int / self.T |
|
s_float = s_int / self.T |
|
|
|
|
|
beta_t = self.noise_schedule(t_normalized=t_float) |
|
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) |
|
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) |
|
|
|
Qtb = self.transition_model.get_Qt_bar( |
|
alpha_t_bar, X.device |
|
) |
|
|
|
bs, n, d = X.shape |
|
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) |
|
prob_all = X_all @ Qtb.X |
|
probX = prob_all[:, :, : self.Xdim_output] |
|
probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1) |
|
|
|
sampled_t = utils.sample_discrete_features( |
|
probX=probX, probE=probE, node_mask=node_mask |
|
) |
|
|
|
X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output) |
|
E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output) |
|
assert (X.shape == X_t.shape) and (E.shape == E_t.shape) |
|
|
|
y_t = y |
|
z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask) |
|
|
|
noisy_data = { |
|
"t_int": t_int, |
|
"t": t_float, |
|
"beta_t": beta_t, |
|
"alpha_s_bar": alpha_s_bar, |
|
"alpha_t_bar": alpha_t_bar, |
|
"X_t": z_t.X, |
|
"E_t": z_t.E, |
|
"y_t": z_t.y, |
|
"node_mask": node_mask, |
|
} |
|
return noisy_data |
|
|
|
@torch.no_grad() |
|
def generate( |
|
self, |
|
properties, |
|
text_embedding, |
|
no_label_index, |
|
): |
|
properties = torch.where(properties == no_label_index, float("nan"), properties) |
|
batch_size = properties.size(0) |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
n_nodes = self.node_dist.sample_n(batch_size, device) |
|
arange = ( |
|
torch.arange(self.max_n_nodes, device=device) |
|
.unsqueeze(0) |
|
.expand(batch_size, -1) |
|
) |
|
node_mask = arange < n_nodes.unsqueeze(1) |
|
|
|
z_T = utils.sample_discrete_feature_noise( |
|
limit_dist=self.limit_dist, node_mask=node_mask |
|
) |
|
X, E = z_T.X, z_T.E |
|
|
|
assert (E == torch.transpose(E, 1, 2)).all() |
|
|
|
|
|
y = properties |
|
for s_int in reversed(range(0, self.T)): |
|
s_array = s_int * torch.ones((batch_size, 1)).type_as(y) |
|
t_array = s_array + 1 |
|
s_norm = s_array / self.T |
|
t_norm = t_array / self.T |
|
|
|
|
|
sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt( |
|
s_norm, t_norm, X, E, y, text_embedding, node_mask |
|
) |
|
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y |
|
|
|
|
|
sampled_s = sampled_s.mask(node_mask, collapse=True) |
|
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y |
|
|
|
molecule_list = [] |
|
for i in range(batch_size): |
|
n = n_nodes[i] |
|
atom_types = X[i, :n].cpu() |
|
edge_types = E[i, :n, :n].cpu() |
|
molecule_list.append([atom_types, edge_types]) |
|
|
|
smiles_list = graph_to_smiles(molecule_list, self.atom_decoder) |
|
|
|
return smiles_list |
|
|
|
def check_valid(self, smiles): |
|
return check_valid(smiles) |
|
|
|
def sample_p_zs_given_zt( |
|
self, s, t, X_t, E_t, properties, text_embedding, node_mask |
|
): |
|
"""Samples from zs ~ p(zs | zt). Only used during sampling. |
|
if last_step, return the graph prediction as well""" |
|
bs, n, _ = X_t.shape |
|
beta_t = self.noise_schedule(t_normalized=t) |
|
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s) |
|
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t) |
|
|
|
|
|
noisy_data = { |
|
"X_t": X_t, |
|
"E_t": E_t, |
|
"y_t": properties, |
|
"t": t, |
|
"node_mask": node_mask, |
|
} |
|
|
|
def get_prob(noisy_data, text_embedding, unconditioned=False): |
|
pred = self._forward(noisy_data, text_embedding, unconditioned=unconditioned) |
|
|
|
|
|
pred_X = F.softmax(pred.X, dim=-1) |
|
pred_E = F.softmax(pred.E, dim=-1) |
|
|
|
device = text_embedding.device |
|
|
|
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device) |
|
Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device) |
|
Qt = self.transition_model.get_Qt(beta_t, device) |
|
|
|
Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1) |
|
predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1) |
|
|
|
unnormalized_probX_all = utils.reverse_diffusion( |
|
predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X |
|
) |
|
|
|
unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output] |
|
unnormalized_prob_E = unnormalized_probX_all[ |
|
:, :, self.Xdim_output : |
|
].reshape(bs, n * n, -1) |
|
|
|
unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5 |
|
unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5 |
|
|
|
prob_X = unnormalized_prob_X / torch.sum( |
|
unnormalized_prob_X, dim=-1, keepdim=True |
|
) |
|
prob_E = unnormalized_prob_E / torch.sum( |
|
unnormalized_prob_E, dim=-1, keepdim=True |
|
) |
|
prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1]) |
|
|
|
return prob_X, prob_E |
|
|
|
prob_X, prob_E = get_prob(noisy_data, text_embedding) |
|
|
|
|
|
if self.guide_scale is not None and self.guide_scale != 1: |
|
uncon_prob_X, uncon_prob_E = get_prob( |
|
noisy_data, text_embedding, unconditioned=True |
|
) |
|
prob_X = ( |
|
uncon_prob_X |
|
* (prob_X / uncon_prob_X.clamp_min(1e-5)) ** self.guide_scale |
|
) |
|
prob_E = ( |
|
uncon_prob_E |
|
* (prob_E / uncon_prob_E.clamp_min(1e-5)) ** self.guide_scale |
|
) |
|
prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5) |
|
prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5) |
|
|
|
sampled_s = utils.sample_discrete_features( |
|
prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item() |
|
) |
|
|
|
X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype) |
|
E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype) |
|
|
|
assert (E_s == torch.transpose(E_s, 1, 2)).all() |
|
assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape) |
|
|
|
out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties) |
|
out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties) |
|
|
|
return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask( |
|
node_mask, collapse=True |
|
).type_as(properties) |
|
|
|
|
|
class TrainLossDiscrete(nn.Module): |
|
"""Train with Cross entropy""" |
|
|
|
def __init__(self, lambda_train): |
|
super().__init__() |
|
self.lambda_train = lambda_train |
|
|
|
def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, node_mask): |
|
|
|
true_X = torch.reshape(true_X, (-1, true_X.size(-1))) |
|
true_E = torch.reshape(true_E, (-1, true_E.size(-1))) |
|
masked_pred_X = torch.reshape( |
|
masked_pred_X, (-1, masked_pred_X.size(-1)) |
|
) |
|
masked_pred_E = torch.reshape( |
|
masked_pred_E, (-1, masked_pred_E.size(-1)) |
|
) |
|
|
|
|
|
mask_X = (true_X != 0.0).any(dim=-1) |
|
mask_E = (true_E != 0.0).any(dim=-1) |
|
|
|
flat_true_X = true_X[mask_X, :] |
|
flat_pred_X = masked_pred_X[mask_X, :] |
|
|
|
flat_true_E = true_E[mask_E, :] |
|
flat_pred_E = masked_pred_E[mask_E, :] |
|
|
|
target_X = torch.argmax(flat_true_X, dim=-1) |
|
loss_X = F.cross_entropy(flat_pred_X, target_X, reduction="mean") |
|
|
|
target_E = torch.argmax(flat_true_E, dim=-1) |
|
loss_E = F.cross_entropy(flat_pred_E, target_E, reduction="mean") |
|
|
|
total_loss = self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E |
|
|
|
return total_loss |
|
|