# Residue from the hosting page (not part of the configuration):
# danielzuegner-ms's picture
# Upload folder using huggingface_hub
# 17e1388 verified
# Fine-tuning adapter settings.
# NOTE(review): the nesting in this stanza was rebuilt from a dump whose
# indentation had been stripped. Parent/child relationships were inferred from
# the alphabetical key order of the dump — verify against the original file.
adapter:
  # Model definition; each `_target_` holds the dotted path of the object to
  # build (presumably consumed by Hydra-style instantiation — confirm).
  adapter:
    _target_: mattergen.adapter.GemNetTAdapter
    atom_type_diffusion: mask
    denoise_atom_types: true
    gemnet:
      _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
      atom_embedding:
        _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
        emb_size: 512
        with_mask_type: true
      # Properties the adapted network is conditioned on.
      condition_on_adapt:
        - dft_band_gap
      cutoff: 7.0
      emb_size_atom: 512
      emb_size_edge: 512
      latent_dim: 512
      max_cell_images_per_dim: 5
      max_neighbors: 50
      num_blocks: 4
      num_targets: 1
      otf_graph: true
      regress_stress: true
      # NOTE(review): absolute, machine-specific path — confirm it resolves in
      # the environment where this config is loaded.
      scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
    hidden_dim: 512
    property_embeddings: {}
    # One embedding definition per property listed in `condition_on_adapt`.
    property_embeddings_adapt:
      dft_band_gap:
        _target_: mattergen.property_embeddings.PropertyEmbedding
        conditional_embedding_module:
          _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
          d_model: 512
        name: dft_band_gap
        scaler:
          _target_: mattergen.common.utils.data_utils.StandardScalerTorch
        unconditional_embedding_module:
          _target_: mattergen.property_embeddings.EmbeddingVector
          hidden_dim: 512
  full_finetuning: true
  # Checkpoint selection for the model found at `model_path`.
  load_epoch: last
  model_path: checkpoints/mattergen_base
# Data module settings: dataset locations, batch sizes and transforms.
# NOTE(review): the nesting in this stanza was rebuilt from a dump whose
# indentation had been stripped. Parent/child relationships were inferred from
# the alphabetical key order of the dump — verify against the original file.
data_module:
  _recursive_: true
  _target_: mattergen.common.data.datamodule.CrystDataModule
  average_density: 0.05771451654022283
  batch_size:
    train: 64
    val: 64
  dataset_transforms:
    - _partial_: true
      _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
  # NOTE(review): differs from the `max_epochs: 200` set under `trainer`;
  # confirm which value is intended to take effect.
  max_epochs: 2200
  num_workers:
    train: 0
    val: 0
  properties:
    - dft_band_gap
  root_dir: datasets/cache/alex_mp_20
  # Training split; repeats the module-level properties and transforms.
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/train
    dataset_transforms:
      - _partial_: true
        _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
    properties:
      - dft_band_gap
    transforms:
      - _partial_: true
        _target_: mattergen.common.data.transform.symmetrize_lattice
      - _partial_: true
        _target_: mattergen.common.data.transform.set_chemical_system_string
  transforms:
    - _partial_: true
      _target_: mattergen.common.data.transform.symmetrize_lattice
    - _partial_: true
      _target_: mattergen.common.data.transform.set_chemical_system_string
  # Validation split; same shape as `train_dataset` with a different cache path.
  val_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/val
    dataset_transforms:
      - _partial_: true
        _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
    properties:
      - dft_band_gap
    transforms:
      - _partial_: true
        _target_: mattergen.common.data.transform.symmetrize_lattice
      - _partial_: true
        _target_: mattergen.common.data.transform.set_chemical_system_string
# Lightning module settings: diffusion module (corruption, loss, model),
# optimizer and learning-rate scheduler.
# NOTE(review): the nesting in this stanza was rebuilt from a dump whose
# indentation had been stripped. Parent/child relationships were inferred from
# the alphabetical key order of the dump — verify against the original file.
lightning_module:
  _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
  diffusion_module:
    _target_: mattergen.diffusion.diffusion_module.DiffusionModule
    corruption:
      _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
      # Discrete corruption, keyed by the field it applies to.
      discrete_corruptions:
        atomic_numbers:
          _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
          d3pm:
            _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
            dim: 101
            schedule:
              _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
              kind: standard
              num_steps: 1000
          # NOTE(review): key order alone would also allow `offset` to sit
          # under `schedule`; placed here as a sibling of `d3pm` — confirm
          # against the D3PMCorruption constructor.
          offset: 1
      # Continuous corruption processes, keyed by the field they apply to.
      sdes:
        cell:
          _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
          vpsde_config:
            beta_max: 20
            beta_min: 0.1
            # Same value as `data_module.average_density`.
            limit_density: 0.05771451654022283
            limit_var_scaling_constant: 0.25
        pos:
          _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
          limit_info_key: num_atoms
          sigma_max: 5.0
          wrapping_boundary: 1.0
    loss_fn:
      _target_: mattergen.common.loss.MaterialsLoss
      d3pm_hybrid_lambda: 0.01
      include_atomic_numbers: true
      include_cell: true
      include_pos: true
      reduce: sum
      # Per-field loss weights.
      weights:
        atomic_numbers: 1.0
        cell: 1.0
        pos: 0.1
    # Carries the same values as the `adapter.adapter` entry of this file.
    model:
      _target_: mattergen.adapter.GemNetTAdapter
      atom_type_diffusion: mask
      denoise_atom_types: true
      gemnet:
        _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
        atom_embedding:
          _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
          emb_size: 512
          with_mask_type: true
        condition_on_adapt:
          - dft_band_gap
        cutoff: 7.0
        emb_size_atom: 512
        emb_size_edge: 512
        latent_dim: 512
        max_cell_images_per_dim: 5
        max_neighbors: 50
        num_blocks: 4
        num_targets: 1
        otf_graph: true
        regress_stress: true
        # NOTE(review): absolute, machine-specific path — confirm it resolves
        # in the environment where this config is loaded.
        scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
      hidden_dim: 512
      property_embeddings: {}
      property_embeddings_adapt:
        dft_band_gap:
          _target_: mattergen.property_embeddings.PropertyEmbedding
          conditional_embedding_module:
            _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
            d_model: 512
          name: dft_band_gap
          scaler:
            _target_: mattergen.common.utils.data_utils.StandardScalerTorch
          unconditional_embedding_module:
            _target_: mattergen.property_embeddings.EmbeddingVector
            hidden_dim: 512
    pre_corruption_fn:
      _target_: mattergen.property_embeddings.SetEmbeddingType
      dropout_fields_iid: false
      p_unconditional: 0.2
  optimizer_partial:
    _partial_: true
    _target_: torch.optim.Adam
    lr: 5.0e-06
  # Each list item pairs a scheduler with its stepping/monitoring options.
  scheduler_partials:
    - frequency: 1
      interval: epoch
      monitor: loss_train
      scheduler:
        _partial_: true
        _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
        factor: 0.6
        min_lr: 1.0e-06
        patience: 100
        verbose: true
      strict: true
# Trainer settings: hardware, callbacks, logging and distribution strategy.
# NOTE(review): the nesting in this stanza was rebuilt from a dump whose
# indentation had been stripped. Parent/child relationships were inferred from
# the alphabetical key order of the dump — verify against the original file.
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  accumulate_grad_batches: 1
  # Each `_target_` starts a new callback entry; the keys that follow it are
  # that callback's arguments.
  callbacks:
    - _target_: pytorch_lightning.callbacks.LearningRateMonitor
      log_momentum: false
      logging_interval: step
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      every_n_epochs: 1
      filename: '{epoch}-{loss_val:.2f}'
      mode: min
      monitor: loss_val
      save_last: true
      save_top_k: 1
      verbose: false
    - _target_: pytorch_lightning.callbacks.TQDMProgressBar
      refresh_rate: 50
    - _target_: mattergen.common.data.callback.SetPropertyScalers
  check_val_every_n_epoch: 5
  devices: 8
  gradient_clip_algorithm: value
  gradient_clip_val: 0.5
  logger:
    _target_: pytorch_lightning.loggers.WandbLogger
    job_type: train_finetune
    project: crystal-generation
    settings:
      _save_requirements: false
      _target_: wandb.Settings
      start_method: fork
  # NOTE(review): `data_module.max_epochs` is 2200 in this file; confirm which
  # value is intended to take effect.
  max_epochs: 200
  num_nodes: 1
  precision: 32
  strategy:
    _target_: pytorch_lightning.strategies.ddp.DDPStrategy
    find_unused_parameters: true