|
adapter: |
|
adapter: |
|
_target_: mattergen.adapter.GemNetTAdapter |
|
atom_type_diffusion: mask |
|
denoise_atom_types: true |
|
gemnet: |
|
_target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl |
|
atom_embedding: |
|
_target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding |
|
emb_size: 512 |
|
with_mask_type: true |
|
condition_on_adapt: |
|
- dft_band_gap |
|
cutoff: 7.0 |
|
emb_size_atom: 512 |
|
emb_size_edge: 512 |
|
latent_dim: 512 |
|
max_cell_images_per_dim: 5 |
|
max_neighbors: 50 |
|
num_blocks: 4 |
|
num_targets: 1 |
|
otf_graph: true |
|
regress_stress: true |
|
scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json |
|
hidden_dim: 512 |
|
property_embeddings: {} |
|
property_embeddings_adapt: |
|
dft_band_gap: |
|
_target_: mattergen.property_embeddings.PropertyEmbedding |
|
conditional_embedding_module: |
|
_target_: mattergen.diffusion.model_utils.NoiseLevelEncoding |
|
d_model: 512 |
|
name: dft_band_gap |
|
scaler: |
|
_target_: mattergen.common.utils.data_utils.StandardScalerTorch |
|
unconditional_embedding_module: |
|
_target_: mattergen.property_embeddings.EmbeddingVector |
|
hidden_dim: 512 |
|
full_finetuning: true |
|
load_epoch: last |
|
model_path: checkpoints/mattergen_base |
|
data_module: |
|
_recursive_: true |
|
_target_: mattergen.common.data.datamodule.CrystDataModule |
|
average_density: 0.05771451654022283 |
|
batch_size: |
|
train: 64 |
|
val: 64 |
|
dataset_transforms: |
|
- _partial_: true |
|
_target_: mattergen.common.data.dataset_transform.filter_sparse_properties |
|
max_epochs: 2200 |
|
num_workers: |
|
train: 0 |
|
val: 0 |
|
properties: |
|
- dft_band_gap |
|
root_dir: datasets/cache/alex_mp_20 |
|
train_dataset: |
|
_target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path |
|
cache_path: datasets/cache/alex_mp_20/train |
|
dataset_transforms: |
|
- _partial_: true |
|
_target_: mattergen.common.data.dataset_transform.filter_sparse_properties |
|
properties: |
|
- dft_band_gap |
|
transforms: |
|
- _partial_: true |
|
_target_: mattergen.common.data.transform.symmetrize_lattice |
|
- _partial_: true |
|
_target_: mattergen.common.data.transform.set_chemical_system_string |
|
transforms: |
|
- _partial_: true |
|
_target_: mattergen.common.data.transform.symmetrize_lattice |
|
- _partial_: true |
|
_target_: mattergen.common.data.transform.set_chemical_system_string |
|
val_dataset: |
|
_target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path |
|
cache_path: datasets/cache/alex_mp_20/val |
|
dataset_transforms: |
|
- _partial_: true |
|
_target_: mattergen.common.data.dataset_transform.filter_sparse_properties |
|
properties: |
|
- dft_band_gap |
|
transforms: |
|
- _partial_: true |
|
_target_: mattergen.common.data.transform.symmetrize_lattice |
|
- _partial_: true |
|
_target_: mattergen.common.data.transform.set_chemical_system_string |
|
lightning_module: |
|
_target_: mattergen.diffusion.lightning_module.DiffusionLightningModule |
|
diffusion_module: |
|
_target_: mattergen.diffusion.diffusion_module.DiffusionModule |
|
corruption: |
|
_target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption |
|
discrete_corruptions: |
|
atomic_numbers: |
|
_target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption |
|
d3pm: |
|
_target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion |
|
dim: 101 |
|
schedule: |
|
_target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule |
|
kind: standard |
|
num_steps: 1000 |
|
offset: 1 |
|
sdes: |
|
cell: |
|
_target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config |
|
vpsde_config: |
|
beta_max: 20 |
|
beta_min: 0.1 |
|
limit_density: 0.05771451654022283 |
|
limit_var_scaling_constant: 0.25 |
|
pos: |
|
_target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE |
|
limit_info_key: num_atoms |
|
sigma_max: 5.0 |
|
wrapping_boundary: 1.0 |
|
loss_fn: |
|
_target_: mattergen.common.loss.MaterialsLoss |
|
d3pm_hybrid_lambda: 0.01 |
|
include_atomic_numbers: true |
|
include_cell: true |
|
include_pos: true |
|
reduce: sum |
|
weights: |
|
atomic_numbers: 1.0 |
|
cell: 1.0 |
|
pos: 0.1 |
|
model: |
|
_target_: mattergen.adapter.GemNetTAdapter |
|
atom_type_diffusion: mask |
|
denoise_atom_types: true |
|
gemnet: |
|
_target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl |
|
atom_embedding: |
|
_target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding |
|
emb_size: 512 |
|
with_mask_type: true |
|
condition_on_adapt: |
|
- dft_band_gap |
|
cutoff: 7.0 |
|
emb_size_atom: 512 |
|
emb_size_edge: 512 |
|
latent_dim: 512 |
|
max_cell_images_per_dim: 5 |
|
max_neighbors: 50 |
|
num_blocks: 4 |
|
num_targets: 1 |
|
otf_graph: true |
|
regress_stress: true |
|
scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json |
|
hidden_dim: 512 |
|
property_embeddings: {} |
|
property_embeddings_adapt: |
|
dft_band_gap: |
|
_target_: mattergen.property_embeddings.PropertyEmbedding |
|
conditional_embedding_module: |
|
_target_: mattergen.diffusion.model_utils.NoiseLevelEncoding |
|
d_model: 512 |
|
name: dft_band_gap |
|
scaler: |
|
_target_: mattergen.common.utils.data_utils.StandardScalerTorch |
|
unconditional_embedding_module: |
|
_target_: mattergen.property_embeddings.EmbeddingVector |
|
hidden_dim: 512 |
|
pre_corruption_fn: |
|
_target_: mattergen.property_embeddings.SetEmbeddingType |
|
dropout_fields_iid: false |
|
p_unconditional: 0.2 |
|
optimizer_partial: |
|
_partial_: true |
|
_target_: torch.optim.Adam |
|
lr: 5.0e-06 |
|
scheduler_partials: |
|
- frequency: 1 |
|
interval: epoch |
|
monitor: loss_train |
|
scheduler: |
|
_partial_: true |
|
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau |
|
factor: 0.6 |
|
min_lr: 1.0e-06 |
|
patience: 100 |
|
verbose: true |
|
strict: true |
|
trainer: |
|
_target_: pytorch_lightning.Trainer |
|
accelerator: gpu |
|
accumulate_grad_batches: 1 |
|
callbacks: |
|
- _target_: pytorch_lightning.callbacks.LearningRateMonitor |
|
log_momentum: false |
|
logging_interval: step |
|
- _target_: pytorch_lightning.callbacks.ModelCheckpoint |
|
every_n_epochs: 1 |
|
filename: '{epoch}-{loss_val:.2f}' |
|
mode: min |
|
monitor: loss_val |
|
save_last: true |
|
save_top_k: 1 |
|
verbose: false |
|
- _target_: pytorch_lightning.callbacks.TQDMProgressBar |
|
refresh_rate: 50 |
|
- _target_: mattergen.common.data.callback.SetPropertyScalers |
|
check_val_every_n_epoch: 5 |
|
devices: 8 |
|
gradient_clip_algorithm: value |
|
gradient_clip_val: 0.5 |
|
logger: |
|
_target_: pytorch_lightning.loggers.WandbLogger |
|
job_type: train_finetune |
|
project: crystal-generation |
|
settings: |
|
_save_requirements: false |
|
_target_: wandb.Settings |
|
start_method: fork |
|
max_epochs: 200 |
|
num_nodes: 1 |
|
precision: 32 |
|
strategy: |
|
_target_: pytorch_lightning.strategies.ddp.DDPStrategy |
|
find_unused_parameters: true |
|
|