# Llamole/src/model/graph_decoder/diffusion_model.py
# Copyright 2024 the Llamole team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import diffusion_utils as utils
from .molecule_utils import graph_to_smiles, check_valid
from .transformer import Transformer
class GraphDiT(nn.Module):
    """Graph Diffusion Transformer for discrete molecular graph generation.

    Combines a Transformer denoiser with a predefined discrete noise
    schedule and a marginal transition model. Training uses ``forward``
    (cross-entropy between denoised predictions and the clean graph);
    sampling uses ``generate``, which runs the reverse diffusion chain
    and decodes the final graphs to SMILES strings.
    """

    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        """Build the denoiser, noise schedule and transition model.

        Args:
            model_config_path: path to the model config, parsed by
                ``utils.load_config``.
            data_info_path: path to the dataset metadata, parsed by
                ``utils.load_config``.
            model_dtype: torch dtype used for one-hot inputs and marginals.
        """
        super().__init__()
        dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)

        input_dims = data_info.input_dims
        output_dims = data_info.output_dims
        nodes_dist = data_info.nodes_dist
        active_index = data_info.active_index

        self.model_config = dm_cfg
        self.data_info = data_info
        self.T = dm_cfg.diffusion_steps  # number of diffusion timesteps
        self.guide_scale = dm_cfg.guide_scale  # classifier-free guidance scale
        self.Xdim = input_dims["X"]  # node (atom-type) input dimension
        self.Edim = input_dims["E"]  # edge (bond-type) input dimension
        self.ydim = input_dims["y"]  # property/condition input dimension
        self.Xdim_output = output_dims["X"]
        self.Edim_output = output_dims["E"]
        self.ydim_output = output_dims["y"]
        self.node_dist = nodes_dist  # distribution over molecule sizes, used in generate()
        self.active_index = active_index  # periodic-table indices of atom types in use
        self.max_n_nodes = data_info.max_n_nodes
        self.train_loss = TrainLossDiscrete(dm_cfg.lambda_train)
        self.atom_decoder = data_info.atom_decoder
        self.hidden_size = dm_cfg.hidden_size
        # Width of the text condition fed to the denoiser
        # (assumes a 768-dim text encoder output -- TODO confirm against caller).
        self.text_input_size = 768

        self.denoiser = Transformer(
            max_n_nodes=self.max_n_nodes,
            hidden_size=dm_cfg.hidden_size,
            depth=dm_cfg.depth,
            num_heads=dm_cfg.num_heads,
            mlp_ratio=dm_cfg.mlp_ratio,
            drop_condition=dm_cfg.drop_condition,
            Xdim=self.Xdim,
            Edim=self.Edim,
            ydim=self.ydim,
            text_dim=self.text_input_size
        )
        self.model_dtype = model_dtype

        self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
            dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
        )

        # Empirical marginals over node and edge types; normalized from counts,
        # then re-normalized to guard against numeric drift.
        x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
            data_info.node_types.to(self.model_dtype)
        )
        e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
            data_info.edge_types.to(self.model_dtype)
        )
        x_marginals = x_marginals / x_marginals.sum()
        e_marginals = e_marginals / e_marginals.sum()

        # Node-conditioned edge transition statistics, restricted to active atoms,
        # then row-normalized in both directions (node->edge and edge->node).
        xe_conditions = data_info.transition_E.to(self.model_dtype)
        xe_conditions = xe_conditions[self.active_index][:, self.active_index]

        xe_conditions = xe_conditions.sum(dim=1)
        ex_conditions = xe_conditions.t()
        xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
        ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)

        self.transition_model = utils.MarginalTransition(
            x_marginals=x_marginals,
            e_marginals=e_marginals,
            xe_conditions=xe_conditions,
            ex_conditions=ex_conditions,
            y_classes=self.ydim_output,
            n_nodes=self.max_n_nodes,
        )
        # Limit (fully-noised) distribution used to draw the initial sample z_T.
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)

    def init_model(self, model_dir, verbose=False):
        """Load denoiser weights from ``<model_dir>/model.pt``.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        model_file = os.path.join(model_dir, 'model.pt')
        if os.path.exists(model_file):
            # weights_only=True avoids unpickling arbitrary objects from the checkpoint.
            self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
        else:
            raise FileNotFoundError(f"Model file not found: {model_file}")

        if verbose:
            print('GraphDiT Denoiser Model initialized.')
            print('Denoiser model:\n', self.denoiser)

    def save_pretrained(self, output_dir):
        """Save denoiser weights, model config (YAML) and dataset metadata (JSON)."""
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Save model
        model_path = os.path.join(output_dir, 'model.pt')
        torch.save(self.denoiser.state_dict(), model_path)

        # Save model config
        config_path = os.path.join(output_dir, 'model_config.yaml')
        with open(config_path, 'w') as f:
            yaml.dump(vars(self.model_config), f)

        # Save data info
        data_info_path = os.path.join(output_dir, 'data.meta.json')
        data_info_dict = {
            "active_atoms": self.data_info.active_atoms,
            "max_node": self.data_info.max_n_nodes,
            "n_atoms_per_mol_dist": self.data_info.n_nodes.tolist(),
            "bond_type_dist": self.data_info.edge_types.tolist(),
            "transition_E": self.data_info.transition_E.tolist(),
            "atom_type_dist": self.data_info.node_types.tolist(),
            "valencies": self.data_info.valency_distribution.tolist()
        }
        with open(data_info_path, 'w') as f:
            json.dump(data_info_dict, f, indent=2)

        print('GraphDiT Model and configurations saved to:', output_dir)

    def disable_grads(self):
        """Freeze the denoiser's parameters (delegates to the Transformer)."""
        self.denoiser.disable_grads()

    def forward(
        self, x, edge_index, edge_attr, graph_batch, properties, text_embedding, no_label_index
    ):
        """Training step: noise a batch of graphs and return the denoising loss.

        Args:
            x: integer atom-type indices per node (flat sparse batch).
            edge_index: edge connectivity of the sparse batch.
            edge_attr: integer bond-type ids (one-hot over 5 classes below).
            graph_batch: graph-membership index for each node.
            properties: conditioning properties; entries equal to
                ``no_label_index`` are replaced with NaN (treated as unlabeled).
            text_embedding: per-graph text condition for the denoiser.
            no_label_index: sentinel value marking a missing property label.

        Returns:
            Scalar training loss from ``TrainLossDiscrete``.
        """
        # NaN marks "no label" for the conditioning pathway downstream.
        properties = torch.where(properties == no_label_index, float("nan"), properties)
        # One-hot over all 118 elements, then restrict to the active atom types.
        data_x = F.one_hot(x, num_classes=118).to(self.model_dtype)[
            :, self.active_index
        ]
        data_edge_attr = F.one_hot(edge_attr, num_classes=5).to(self.model_dtype)

        dense_data, node_mask = utils.to_dense(
            data_x, edge_index, data_edge_attr, graph_batch, self.max_n_nodes
        )
        X, E = dense_data.X, dense_data.E

        # NOTE(review): the masked result is assigned but never used; X/E captured
        # above are the pre-mask tensors. Confirm whether masking was intended here.
        dense_data = dense_data.mask(node_mask)
        noisy_data = self.apply_noise(X, E, properties, node_mask)
        pred = self._forward(noisy_data, text_embedding)
        loss = self.train_loss(
            masked_pred_X=pred.X,
            masked_pred_E=pred.E,
            true_X=X,
            true_E=E,
            node_mask=node_mask,
        )
        return loss

    def _forward(self, noisy_data, text_embedding, unconditioned=False):
        """Run the denoiser on a noisy graph state.

        Args:
            noisy_data: dict with keys "X_t", "E_t", "y_t", "t", "node_mask"
                (as produced by ``apply_noise`` / ``sample_p_zs_given_zt``).
            text_embedding: text condition tensor, cast to the model dtype.
            unconditioned: if True, the denoiser drops the conditioning
                (used for classifier-free guidance).
        """
        noisy_x, noisy_e, properties = (
            noisy_data["X_t"].to(self.model_dtype),
            noisy_data["E_t"].to(self.model_dtype),
            # Cloned so the denoiser cannot mutate the caller's property tensor.
            noisy_data["y_t"].to(self.model_dtype).clone(),
        )
        node_mask, timestep, text_embedding = (
            noisy_data["node_mask"],
            noisy_data["t"],
            text_embedding.to(self.model_dtype),
        )

        pred = self.denoiser(
            noisy_x,
            noisy_e,
            node_mask,
            properties,
            text_embedding,
            timestep,
            unconditioned=unconditioned,
        )
        return pred

    def apply_noise(self, X, E, y, node_mask):
        """Sample noise and apply it to the data."""

        # Sample a timestep t.
        # When evaluating, the loss for t=0 is computed separately
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
        ).to(
            self.model_dtype
        )  # (bs, 1)
        s_int = t_int - 1

        t_float = t_int / self.T
        s_float = s_int / self.T

        # beta_t and alpha_s_bar are used for denoising/loss computation
        beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

        Qtb = self.transition_model.get_Qt_bar(
            alpha_t_bar, X.device
        )  # (bs, dx_in, dx_out), (bs, de_in, de_out)

        bs, n, d = X.shape
        # Nodes and flattened edges share one transition matrix, so concatenate
        # them along the feature axis before multiplying by Qtb.X.
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
        prob_all = X_all @ Qtb.X
        probX = prob_all[:, :, : self.Xdim_output]
        probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)

        # Draw discrete categories for every node and edge from the marginals.
        sampled_t = utils.sample_discrete_features(
            probX=probX, probE=probE, node_mask=node_mask
        )

        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        y_t = y
        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)

        noisy_data = {
            "t_int": t_int,
            "t": t_float,
            "beta_t": beta_t,
            "alpha_s_bar": alpha_s_bar,
            "alpha_t_bar": alpha_t_bar,
            "X_t": z_t.X,
            "E_t": z_t.E,
            "y_t": z_t.y,
            "node_mask": node_mask,
        }
        return noisy_data

    @torch.no_grad()
    def generate(
        self,
        properties,
        text_embedding,
        no_label_index,
    ):
        """Sample molecules by running the reverse diffusion chain.

        Args:
            properties: (bs, ydim) conditioning tensor; entries equal to
                ``no_label_index`` are treated as unlabeled (NaN).
            text_embedding: per-sample text condition.
            no_label_index: sentinel value marking a missing property label.

        Returns:
            List of SMILES strings, one per sample (decoded by graph_to_smiles).
        """
        properties = torch.where(properties == no_label_index, float("nan"), properties)
        batch_size = properties.size(0)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Sample graph sizes, then mask out padding positions beyond each size.
        n_nodes = self.node_dist.sample_n(batch_size, device)
        arange = (
            torch.arange(self.max_n_nodes, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        node_mask = arange < n_nodes.unsqueeze(1)

        # z_T ~ limit distribution (fully-noised graphs).
        z_T = utils.sample_discrete_feature_noise(
            limit_dist=self.limit_dist, node_mask=node_mask
        )
        X, E = z_T.X, z_T.E

        assert (E == torch.transpose(E, 1, 2)).all()

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        y = properties
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s
            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
                s_norm, t_norm, X, E, y, text_embedding, node_mask
            )
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        # Sample
        # collapse=True converts one-hot tensors back to integer class indices.
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        molecule_list = []
        for i in range(batch_size):
            n = n_nodes[i]
            atom_types = X[i, :n].cpu()
            edge_types = E[i, :n, :n].cpu()
            molecule_list.append([atom_types, edge_types])

        smiles_list = graph_to_smiles(molecule_list, self.atom_decoder)

        return smiles_list

    def check_valid(self, smiles):
        """Validity check for a generated SMILES (delegates to molecule_utils)."""
        return check_valid(smiles)

    def sample_p_zs_given_zt(
        self, s, t, X_t, E_t, properties, text_embedding, node_mask
    ):
        """Samples from zs ~ p(zs | zt). Only used during sampling.
        if last_step, return the graph prediction as well"""
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Neural net predictions
        noisy_data = {
            "X_t": X_t,
            "E_t": E_t,
            "y_t": properties,
            "t": t,
            "node_mask": node_mask,
        }

        def get_prob(noisy_data, text_embedding, unconditioned=False):
            # Denoise, then convert the network output into the reverse-step
            # posterior p(z_s | z_t) for nodes and edges jointly.
            pred = self._forward(noisy_data, text_embedding, unconditioned=unconditioned)

            # Normalize predictions
            pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
            pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

            device = text_embedding.device
            # Retrieve transitions matrix
            Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
            Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
            Qt = self.transition_model.get_Qt(beta_t, device)

            # Nodes and flattened edges are concatenated so one matrix product
            # covers both (mirrors apply_noise).
            Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
            predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)

            unnormalized_probX_all = utils.reverse_diffusion(
                predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
            )

            unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
            unnormalized_prob_E = unnormalized_probX_all[
                :, :, self.Xdim_output :
            ].reshape(bs, n * n, -1)

            # Avoid zero rows so the normalization below cannot divide by zero.
            unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

            prob_X = unnormalized_prob_X / torch.sum(
                unnormalized_prob_X, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = unnormalized_prob_E / torch.sum(
                unnormalized_prob_E, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

            return prob_X, prob_E

        prob_X, prob_E = get_prob(noisy_data, text_embedding)

        ### Guidance
        if self.guide_scale is not None and self.guide_scale != 1:
            # Classifier-free guidance: sharpen the conditional distribution
            # against the unconditional one, then renormalize.
            uncon_prob_X, uncon_prob_E = get_prob(
                noisy_data, text_embedding, unconditioned=True
            )
            prob_X = (
                uncon_prob_X
                * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** self.guide_scale
            )
            prob_E = (
                uncon_prob_E
                * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** self.guide_scale
            )
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)

        sampled_s = utils.sample_discrete_features(
            prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
        )

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)

        return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
            node_mask, collapse=True
        ).type_as(properties)
class TrainLossDiscrete(nn.Module):
    """Weighted cross-entropy training loss over node and edge predictions.

    The total loss is::

        lambda_train[0] * CE(node logits, node targets)
      + lambda_train[1] * CE(edge logits, edge targets)

    computed only over rows whose one-hot target is non-zero, i.e. real
    (unmasked) nodes and edges.
    """

    def __init__(self, lambda_train):
        super().__init__()
        # Pair of weights: lambda_train[0] for nodes, lambda_train[1] for edges.
        self.lambda_train = lambda_train

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, node_mask):
        """Compute the combined node/edge cross-entropy.

        Args:
            masked_pred_X: node logits, reshaped internally to (bs * n, dx).
            masked_pred_E: edge logits, reshaped internally to (bs * n * n, de).
            true_X: one-hot node targets; all-zero rows are treated as padding
                and excluded from the loss.
            true_E: one-hot edge targets; all-zero rows are excluded likewise.
            node_mask: unused here (padding is inferred from the all-zero
                target rows); kept for interface compatibility with callers.

        Returns:
            Scalar tensor with the weighted total loss. An empty node or edge
            selection contributes 0 rather than NaN.
        """
        true_X = torch.reshape(true_X, (-1, true_X.size(-1)))  # (bs * n, dx)
        true_E = torch.reshape(true_E, (-1, true_E.size(-1)))  # (bs * n * n, de)
        masked_pred_X = torch.reshape(
            masked_pred_X, (-1, masked_pred_X.size(-1))
        )  # (bs * n, dx)
        masked_pred_E = torch.reshape(
            masked_pred_E, (-1, masked_pred_E.size(-1))
        )  # (bs * n * n, de)

        # Remove masked rows: padding positions have an all-zero one-hot target.
        mask_X = (true_X != 0.0).any(dim=-1)
        mask_E = (true_E != 0.0).any(dim=-1)

        flat_true_X = true_X[mask_X, :]
        flat_pred_X = masked_pred_X[mask_X, :]
        flat_true_E = true_E[mask_E, :]
        flat_pred_E = masked_pred_E[mask_E, :]

        # Guard the empty-selection case: F.cross_entropy with reduction="mean"
        # returns NaN on zero rows (e.g. a batch of single-node graphs can have
        # no valid edge targets), which would poison the training loss.
        if flat_true_X.size(0) > 0:
            target_X = torch.argmax(flat_true_X, dim=-1)
            loss_X = F.cross_entropy(flat_pred_X, target_X, reduction="mean")
        else:
            loss_X = torch.zeros(
                (), dtype=masked_pred_X.dtype, device=masked_pred_X.device
            )

        if flat_true_E.size(0) > 0:
            target_E = torch.argmax(flat_true_E, dim=-1)
            loss_E = F.cross_entropy(flat_pred_E, target_E, reduction="mean")
        else:
            loss_E = torch.zeros(
                (), dtype=masked_pred_E.dtype, device=masked_pred_E.device
            )

        total_loss = self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E
        return total_loss