Upload 12 files
- config.yaml +127 -0
- diffusion.py +1434 -0
- dit.py +388 -0
- ema.py +97 -0
- esm_utils.py +15 -0
- generate.py +60 -0
- main.py +250 -0
- mdlm_motif_benchmarking.py +96 -0
- mlm_generate_utils.py +108 -0
- noise_schedule.py +153 -0
- pl_data_loader.py +819 -0
- utils.py +230 -0
config.yaml
ADDED
@@ -0,0 +1,127 @@
defaults:
  - _self_
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /model: small
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup

mode: sample_eval  # train / ppl_eval / sample_eval
diffusion: absorbing_state
backbone: membrane_esm_finetune  # dit / dimamba / ar / vanilla_esm_pretrain / membrane_esm_finetune
parameterization: subs  # subs / d3pm / sedd
time_conditioning: False
T: 0  # 0 (continuous time) / 1000
subs_masking: False

seed: 42

data:
  train:
    vanilla_esm_train_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/train.csv
    membrane_esm_train_path: /workspace/sg666/MDpLM/data/membrane/train.csv
    wrap: null
  test:
    vanilla_esm_test_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/test.csv
    membrane_esm_test_path: /workspace/sg666/MDpLM/data/membrane/test.csv
    wrap: null
  valid:
    vanilla_esm_valid_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/val.csv
    membrane_esm_valid_path: /workspace/sg666/MDpLM/data/membrane/val.csv
    wrap: null
  wrapping: True

loader:
  global_batch_size: 8
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per machine**
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True

sampling:
  predictor: ddpm_cache  # analytic, ddpm, ddpm_cache
  steps: 128
  noise_removal: True
  # TODO(yair): @subham, why aren't these params under `eval`?
  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_log: 2
  semi_ar: False
  stride_length: 1
  num_strides: 1

training:
  ema: 0.9999
  antithetic_sampling: True
  importance_sampling: False
  sampling_eps: 1e-3
  change_of_variables: False
  mlm_model_path: /workspace/sg666/MDpLM/benchmarks/MLM/model_ckpts_650M/best_model_epoch
  esm_model_path: facebook/esm2_t30_150M_UR50D
  focus_mask: False

eval:
  checkpoint_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/eos-wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/checkpoints/best.ckpt  # Used to evaluate a checkpoint after training.
  disable_ema: False
  compute_generative_perplexity: False
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: False
  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large, meta-llama/Llama-2-7b-hf
  generate_samples: True
  generation_model: /workspace/sg666/MDpLM/checkpoints/membrane_automodel/epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/

optim:
  weight_decay: 0.075
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

Model:
  hidden_size: 1280
  cond_dim: 256
  n_heads: 20
  n_blocks: 4
  dropout: 0.5
  length: null  # 512
  scale_by_sigma: True

trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: bf16
  num_sanity_val_steps: 2
  max_epochs: 60
  max_steps: 1_000_000
  log_every_n_steps: 10
  limit_train_batches: 1.0  # train on full dataset, can be used to toggle quick run
  limit_val_batches: 1.0  # validate on full dataset, can be used to toggle quick run
  val_check_interval: 955

wandb:
  project: MDpLM_finetune_membrane_200k-seqs
  notes: null
  group: programmablebio
  job_type: null
  name: dit_test  # dit_wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16
  id: ${.name}_${seed}

hydra:
  run:
    dir: /workspace/sg666/MDpLM/outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true

checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  save_dir: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  resume_from_ckpt: false
  resume_ckpt_path: ${.save_dir}/epochs30_lr3e-4_bsz8_gradclip1_beta-one0.9_beta-two0.999_bf16_all-params_no-compile/checkpoints/last.ckpt  # /checkpoints/last.ckpt
  pretrained_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/vanilla_esm_pretrained_automodel/epochs10_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/
  finetuned_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
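The interpolations in this config (`${div_up: ...}`, `${eval: ...}`, `${device_count:}`) are not built into OmegaConf/Hydra; they rely on custom resolvers that must be registered before the config is composed, presumably in `utils.py` or `main.py` of this upload. The sketch below shows one plausible registration with the semantics the config appears to assume (ceiling division, Python-`eval` arithmetic, CUDA device count); the actual definitions in this repository may differ. With `global_batch_size: 8` on one node with 8 GPUs this would resolve `loader.batch_size` to `div_up(8, 8) = 1` and `trainer.accumulate_grad_batches` to `div_up(8, 8 * 1 * 1) = 1`.

    # Hypothetical resolver registration (not part of this upload); run before Hydra composes the config.
    from omegaconf import OmegaConf
    import torch

    OmegaConf.register_new_resolver(
        'div_up', lambda a, b: (int(a) + int(b) - 1) // int(b))    # ceiling division
    OmegaConf.register_new_resolver(
        'device_count', lambda: max(torch.cuda.device_count(), 1))
    OmegaConf.register_new_resolver('eval', eval)                  # evaluates the arithmetic expressions above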
diffusion.py
ADDED
@@ -0,0 +1,1434 @@
import itertools
import math
import os
import sys
import typing
from dataclasses import dataclass

import hydra.utils
import lightning as L
import numpy as np
import torch.nn as nn
import torch
# import dit
import ema
import time
import gc
import pl_data_loader as dataloader
import torch.nn.functional as F
import torchmetrics
import transformers
from torch import Tensor
from torch.optim.lr_scheduler import _LRScheduler
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer

import utils
import noise_schedule

LOG2 = math.log(2)


class CosineWarmup(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_ratio = eta_ratio  # The ratio of minimum to maximum learning rate
        super(CosineWarmup, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]

        progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio

        return [decayed_lr * base_lr for base_lr in self.base_lrs]


def _sample_categorical(categorical_probs):
    gumbel_norm = (
        1e-10
        - (torch.rand_like(categorical_probs) + 1e-10).log())
    return (categorical_probs / gumbel_norm).argmax(dim=-1)


def _unsqueeze(x, reference):
    return x.view(
        *x.shape,
        *((1,) * (len(reference.shape) - len(x.shape))))


@dataclass
class Loss:
    loss: torch.FloatTensor
    nlls: torch.FloatTensor
    token_mask: torch.FloatTensor


class NLL(torchmetrics.aggregation.MeanMetric):
    pass


class BPD(NLL):
    def compute(self) -> Tensor:
        """Computes the bits per dimension.

        Returns:
            bpd
        """
        return self.mean_value / self.weight / LOG2


class Perplexity(NLL):
    def compute(self) -> Tensor:
        """Computes the Perplexity.

        Returns:
            Perplexity
        """
        return torch.exp(self.mean_value / self.weight)


class WrapVanillaESM(nn.Module):
    def __init__(self, bert_model_path):
        super(WrapVanillaESM, self).__init__()
        # self.bert_model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path).to(self.bert_model_device)
        self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path, device_map='cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path)

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def unfreeze_attn_layers(self):
        model_layers = len(self.model.esm.encoder.layer)

        for i, layer in enumerate(self.model.esm.encoder.layer):
            if i >= model_layers - 5:  # fine-tune only last n layers
                for module in layer.attention.self.key.modules():
                    for param in module.parameters():
                        param.requires_grad = True
                for module in layer.attention.self.query.modules():
                    for param in module.parameters():
                        param.requires_grad = True
                for module in layer.attention.self.value.modules():
                    for param in module.parameters():
                        param.requires_grad = True

    def unfreeze_all_layers(self):
        for param in self.model.parameters():
            param.requires_grad = True

    def forward(self, inputs, sigma, attention_mask):
        logits = self.model(input_ids=inputs, attention_mask=attention_mask).logits
        return logits

    def save_model(self, save_dir):
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    def load_model(self, load_dir):
        self.model = AutoModel.from_pretrained(load_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)


class WrapMembraneESM(nn.Module):
    def __init__(self, bert_model_path):
        super(WrapMembraneESM, self).__init__()
        # self.bert_model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path).to(self.bert_model_device)
        self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path, device_map='cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path)

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def freeze_model(self):
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_all_layers(self):
        for param in self.model.parameters():
            param.requires_grad = True

    def unfreeze_attn_layers(self):
        model_layers = len(self.model.esm.encoder.layer)

        for i, layer in enumerate(self.model.esm.encoder.layer):
            if i >= model_layers - 11:  # fine-tune only last n layers
                for module in layer.attention.self.key.modules():
                    for param in module.parameters():
                        param.requires_grad = True
                for module in layer.attention.self.query.modules():
                    for param in module.parameters():
                        param.requires_grad = True
                for module in layer.attention.self.value.modules():
                    for param in module.parameters():
                        param.requires_grad = True

    def forward(self, inputs, sigma, attention_mask):
        logits = self.model(input_ids=inputs, attention_mask=attention_mask).logits
        return logits

    def save_model(self, save_dir):
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    def load_model(self, load_dir):
        self.model = AutoModel.from_pretrained(load_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)

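# The `_sample_categorical` helper above is a Gumbel-max sampler: dividing the class
# probabilities by Exponential noise (equivalently, adding Gumbel noise to the
# log-probabilities) and taking an argmax draws an index from the categorical
# distribution in a single vectorized pass. The function below is an illustrative,
# self-contained check of that property; it is a sketch added for exposition and is
# not referenced anywhere in this module.
def _demo_gumbel_max_sampling(num_draws: int = 100_000) -> torch.Tensor:
    probs = torch.tensor([0.1, 0.2, 0.7]).expand(num_draws, -1)
    draws = _sample_categorical(probs)
    # Empirical frequencies should be close to [0.1, 0.2, 0.7].
    return torch.bincount(draws, minlength=3).float() / num_draws
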
class Diffusion(L.LightningModule):
    def __init__(
        self,
        config,
        tokenizer: transformers.PreTrainedTokenizer):
        super().__init__()
        self.save_hyperparameters()
        self.config = config

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        self.sampler = self.config.sampling.predictor
        self.gen_ppl_eval_model_name_or_path = self.config.eval.\
            gen_ppl_eval_model_name_or_path
        self.antithetic_sampling = self.config.training.antithetic_sampling
        self.importance_sampling = self.config.training.importance_sampling
        self.change_of_variables = self.config.training.change_of_variables
        if (not hasattr(self.tokenizer, 'mask_token')
                or self.tokenizer.mask_token is None):
            self.mask_index = self.vocab_size
            self.vocab_size += 1
        else:
            self.mask_index = self.tokenizer.mask_token_id
        self.parameterization = self.config.parameterization

        # if self.config.backbone == 'dit':
        #     self.backbone = dit.DIT(
        #         self.config, vocab_size=self.vocab_size, mlm_model_path=config.training.mlm_model_path)
        if self.config.backbone == "vanilla_esm_pretrain":
            self.backbone = WrapVanillaESM(bert_model_path=self.config.training.esm_model_path)
            self.backbone.unfreeze_all_layers()
            self.backbone = torch.compile(self.backbone)
        elif self.config.backbone == 'membrane_esm_finetune':
            self.backbone = WrapMembraneESM(bert_model_path=self.config.checkpointing.pretrained_esm_mdlm_automodel_path)
            self.backbone.unfreeze_all_layers()
            # self.backbone = torch.compile(self.backbone)

        # elif self.config.backbone == 'dimamba':
        #     self.backbone = dimamba.DiMamba(
        #         self.config,
        #         vocab_size=self.vocab_size,
        #         pad_token_id=self.tokenizer.pad_token_id)
        # elif self.config.backbone == 'ar':
        #     self.backbone = autoregressive.AR(
        #         self.config,
        #         vocab_size=self.vocab_size,
        #         mask_index=self.mask_index)
        # elif self.config.backbone == 'hf_dit':
        #     self.backbone = transformers.AutoModelForMaskedLM.from_pretrained(
        #         config.eval.checkpoint_path, trust_remote_code=True)
        # else:
        #     raise ValueError(
        #         f'Unknown backbone: {self.config.backbone}')

        self.T = self.config.T
        self.subs_masking = self.config.subs_masking

        self.softplus = torch.nn.Softplus()
        # metrics are automatically reset at end of epoch
        metrics = torchmetrics.MetricCollection({
            'nll': NLL(),
            'bpd': BPD(),
            'ppl': Perplexity(),
        })
        metrics.set_dtype(torch.float64)
        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')

        # generative perplexity
        self.gen_ppl_metric = Perplexity()
        self.eval_model_tokenizer = transformers.AutoTokenizer.\
            from_pretrained(self.gen_ppl_eval_model_name_or_path)
        if self.eval_model_tokenizer.pad_token is None:
            self.eval_model_tokenizer.pad_token =\
                self.eval_model_tokenizer.eos_token
            self.eval_model_tokenizer.pad_token_id =\
                self.eval_model_tokenizer.eos_token_id

        self.noise = noise_schedule.get_noise(self.config,
                                              dtype=self.dtype)
        if self.config.training.ema > 0:
            self.ema = ema.ExponentialMovingAverage(
                itertools.chain(self.backbone.parameters(),
                                self.noise.parameters()),
                decay=self.config.training.ema)
        else:
            self.ema = None

        self.lr = self.config.optim.lr
        self.sampling_eps = self.config.training.sampling_eps
        self.time_conditioning = self.config.time_conditioning
        self.neg_infinity = -1000000.0
        self.fast_forward_epochs = None
        self.fast_forward_batches = None
        self._validate_configuration()

    def _validate_configuration(self):
        assert not (self.change_of_variables
                    and self.importance_sampling)
        if self.parameterization == 'sedd':
            assert not self.importance_sampling
            assert not self.change_of_variables
        if self.parameterization == 'd3pm':
            assert self.T > 0
        if self.T > 0:
            assert self.parameterization in {'d3pm', 'subs'}
        if self.subs_masking:
            assert self.parameterization == 'd3pm'

    def on_load_checkpoint(self, checkpoint):
        if self.ema:
            self.ema.load_state_dict(checkpoint['ema'])
        # Copied from:
        # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
        self.fast_forward_epochs = checkpoint['loops'][
            'fit_loop']['epoch_progress']['current']['completed']
        self.fast_forward_batches = checkpoint['loops'][
            'fit_loop']['epoch_loop.batch_progress'][
            'current']['completed']

    def on_save_checkpoint(self, checkpoint):
        if self.ema:
            checkpoint['ema'] = self.ema.state_dict()
        # Copied from:
        # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
        # ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
        # behind, so we're using the optimizer's progress.
        checkpoint['loops']['fit_loop'][
            'epoch_loop.batch_progress']['total'][
            'completed'] = checkpoint['loops']['fit_loop'][
            'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['total'][
            'completed'] * self.trainer.accumulate_grad_batches
        checkpoint['loops']['fit_loop'][
            'epoch_loop.batch_progress']['current'][
            'completed'] = checkpoint['loops']['fit_loop'][
            'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['current'][
            'completed'] * self.trainer.accumulate_grad_batches
        # _batches_that_stepped tracks the number of global steps, not the number
        # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
        checkpoint['loops']['fit_loop'][
            'epoch_loop.state_dict'][
            '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
            'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['total']['completed']
        if 'sampler' not in checkpoint.keys():
            checkpoint['sampler'] = {}
        if hasattr(self.trainer.train_dataloader.sampler,
                   'state_dict'):
            sampler_state_dict = self.trainer.\
                train_dataloader.sampler.state_dict()
            checkpoint['sampler'][
                'random_state'] = sampler_state_dict.get(
                'random_state', None)
        else:
            checkpoint['sampler']['random_state'] = None

        self.backbone.save_model(self.config.checkpointing.fine_tuned_esm_mdlm_ckpt_path)

    def on_train_start(self):
        torch.cuda.empty_cache()
        if self.ema:
            self.ema.move_shadow_params_to_device(self.device)

        # Adapted from:
        # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
        distributed = (
            self.trainer._accelerator_connector.use_distributed_sampler
            and self.trainer._accelerator_connector.is_distributed)
        if distributed:
            sampler_cls = dataloader.FaultTolerantDistributedSampler
        else:
            sampler_cls = dataloader.RandomFaultTolerantSampler
        updated_dls = []
        for dl in self.trainer.fit_loop._combined_loader.flattened:
            if hasattr(dl.sampler, 'shuffle'):
                dl_sampler = sampler_cls(
                    dl.dataset, shuffle=dl.sampler.shuffle)
            else:
                dl_sampler = sampler_cls(dl.dataset)
            if (distributed
                    and self.fast_forward_epochs is not None
                    and self.fast_forward_batches is not None):
                dl_sampler.load_state_dict({
                    'epoch': self.fast_forward_epochs,
                    'counter': (self.fast_forward_batches
                                * self.config.loader.batch_size)})

            from functools import partial
            from pl_data_loader import collate_fn
            collate_partial = partial(collate_fn, tokenizer=self.tokenizer)
            torch.cuda.empty_cache()

            updated_dls.append(
                torch.utils.data.DataLoader(
                    dl.dataset,
                    batch_size=self.config.loader.batch_size,
                    num_workers=self.config.loader.num_workers,
                    pin_memory=self.config.loader.pin_memory,
                    sampler=dl_sampler,
                    shuffle=False,
                    persistent_workers=False,
                    collate_fn=collate_partial))
        self.trainer.fit_loop._combined_loader.flattened = updated_dls

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)

        gc.collect()
        torch.cuda.empty_cache()

        if self.ema:
            self.ema.update(itertools.chain(
                self.backbone.parameters(),
                self.noise.parameters()))

        # optimizer_closure = kwargs.get('optimizer_closure', None)

        # params_with_grad = [p for p in itertools.chain(
        #     self.backbone.parameters(),
        #     self.noise.parameters()
        # ) if p.requires_grad and p.grad_fn is not None]

        # # if params_with_grad:
        # #     super().optimizer_step(closure=optimizer_closure)

        # if self.ema:
        #     self.ema.update(params_with_grad)

        # super().optimizer_step(*args, **kwargs)

    def _subs_parameterization(self, logits, xt):
        # log prob at the mask index = - infinity
        logits = logits.logits
        logits[:, :, self.mask_index] += self.neg_infinity
        # logits[:, :, self.tokenizer.eos_token_id] += self.neg_infinity
        # logits[:, :, self.tokenizer.cls_token_id] += self.neg_infinity

        # Normalize the logits such that x.exp() is
        # a probability distribution over vocab_size.
        logits = logits - torch.logsumexp(logits, dim=-1,
                                          keepdim=True)

        # Apply updates directly in the logits matrix.
        # For the logits of the unmasked tokens, set all values
        # to -infinity except for the indices corresponding to
        # the unmasked tokens.
        unmasked_indices = (xt != self.mask_index)
        logits[unmasked_indices] = self.neg_infinity
        logits[unmasked_indices, xt[unmasked_indices]] = 0
        return logits

    def _d3pm_parameterization(self, logits):
        if self.subs_masking:
            logits[:, :, self.mask_index] += self.neg_infinity
        logits = logits - torch.logsumexp(logits, dim=-1,
                                          keepdim=True)
        return logits

    def _sedd_parameterization(self, logits, xt, sigma):
        esigm1_log = torch.where(
            sigma < 0.5,
            torch.expm1(sigma),
            sigma.exp() - 1).log().to(logits.dtype)
        # logits shape
        # (batch_size, diffusion_model_input_length, vocab_size)
        logits = logits - esigm1_log[:, None, None] - np.log(
            logits.shape[-1] - 1)
        # The below scatter operation sets the log score
        # for the input word to 0.
        logits = torch.scatter(logits, -1, xt[..., None],
                               torch.zeros_like(logits[..., :1]))
        return logits

    def _process_sigma(self, sigma):
        if sigma is None:
            assert self.parameterization == 'ar'
            return sigma
        if sigma.ndim > 1:
            sigma = sigma.squeeze(-1)
        if not self.time_conditioning:
            sigma = torch.zeros_like(sigma)
        assert sigma.ndim == 1, sigma.shape
        return sigma

    def forward(self, x, sigma, attention_mask, print_logits=False):
        """Returns log score."""
        sigma = self._process_sigma(sigma)
        with torch.amp.autocast("cuda", dtype=torch.float32):
            logits = self.backbone(x, attention_mask)
        # if print_logits:
        #     torch.set_printoptions(profile="full")
        #     print(logits)
        #     torch.set_printoptions(profile="default")
        if self.parameterization == 'subs':
            return self._subs_parameterization(logits=logits, xt=x)
        return logits

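    # The SUBS parameterization used by `forward` above bakes in two properties of
    # absorbing-state masked diffusion: the mask token never receives probability mass,
    # and positions that are already unmasked keep their current token with probability
    # ~1 (carry-over unmasking). The helper below is an illustrative sanity-check sketch
    # only; it is not called anywhere in this module.
    def _check_subs_properties(self, xt, attention_mask):
        sigma = torch.zeros(xt.shape[0], device=xt.device)
        probs = self.forward(xt, sigma, attention_mask).exp()
        max_mask_prob = probs[:, :, self.mask_index].max()    # expected ~0
        unmasked = xt != self.mask_index
        min_carry_over = probs[unmasked, xt[unmasked]].min()  # expected ~1
        return max_mask_prob, min_carry_over
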
    def _d3pm_loss(self, model_output, xt, x0, t, attention_mask):
        dt = 1 / self.T

        if torch.is_tensor(t):
            t = t[:, None]
            assert t.ndim == 2
            t = t.clamp(0., 1. - 1e-4)
        alpha_t = 1 - t + torch.zeros_like(xt)
        alpha_s = 1 - (t - dt) + torch.zeros_like(xt)

        log_x_theta_at_x0 = torch.gather(
            model_output, -1, x0[:, :, None]).squeeze(-1)
        log_x_theta_at_m = model_output[:, :, self.mask_index]
        x_theta_at_m = log_x_theta_at_m.exp()

        term_1_coef = dt / t
        term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
        term_1_log_dr = log_x_theta_at_x0

        term_2_coef = 1 - dt / t
        term_2_log_nr = term_1_log_nr
        term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1)

        L_vb_masked = (
            term_1_coef * (term_1_log_nr - term_1_log_dr)
            + term_2_coef * (term_2_log_nr - term_2_log_dr))

        L_vb = L_vb_masked * (xt == self.mask_index)

        return self.T * L_vb

    def _compute_loss(self, batch, prefix):
        if 'attention_mask' in batch:
            attention_mask = batch['attention_mask']
        else:
            attention_mask = None
        if 'mask' in batch:
            mask = batch['mask']
        else:
            mask = None

        losses = self._loss(batch['input_ids'], attention_mask, mask)
        loss = losses.loss

        if prefix == 'train':
            self.train_metrics.update(losses.nlls, losses.token_mask)
            metrics = self.train_metrics
        elif prefix == 'val':
            self.valid_metrics.update(losses.nlls, losses.token_mask)
            metrics = self.valid_metrics
        elif prefix == 'test':
            self.test_metrics.update(losses.nlls, losses.token_mask)
            metrics = self.test_metrics
        else:
            raise ValueError(f'Invalid prefix: {prefix}')

        self.log_dict(metrics,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        return loss

    def on_train_epoch_start(self):
        self.backbone.train()
        self.noise.train()

    def training_step(self, batch, batch_idx):
        # Initialize throughput calculation
        start_time = time.time()

        loss = self._compute_loss(batch, prefix='train')
        self.log(name='trainer/loss',
                 value=loss.item(),
                 on_step=True,
                 on_epoch=False,
                 sync_dist=True)

        # Calculate throughput
        elapsed_time = time.time() - start_time
        total_tokens = batch['input_ids'].numel()
        throughput = total_tokens / elapsed_time

        self.log(name='trainer/throughput',
                 value=throughput,
                 on_step=True,
                 on_epoch=False,
                 sync_dist=True)

        return loss

    def on_validation_epoch_start(self):
        # params_with_grad = [p for p in itertools.chain(
        #     self.backbone.parameters(),
        #     self.noise.parameters()
        # ) if p.requires_grad]
        # if self.ema:
        #     self.ema.store(params_with_grad)
        #     self.ema.copy_to(params_with_grad)

        gc.collect()
        torch.cuda.empty_cache()
        if self.ema:
            self.ema.store(
                itertools.chain(
                    self.backbone.parameters(),
                    self.noise.parameters()))
            self.ema.copy_to(itertools.chain(
                self.backbone.parameters(),
                self.noise.parameters()))
        self.backbone.eval()
        self.noise.eval()
        assert self.valid_metrics.nll.mean_value == 0
        assert self.valid_metrics.nll.weight == 0

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch, prefix='val')
        self.log(name='trainer/val_loss',
                 value=loss.item(),
                 on_step=True,
                 on_epoch=False,
                 prog_bar=True,
                 sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        # params_with_grad = [p for p in itertools.chain(
        #     self.backbone.parameters(),
        #     self.noise.parameters()
        # ) if p.requires_grad]
        # if ((self.config.eval.compute_perplexity_on_sanity
        #      or not self.trainer.sanity_checking)
        #     and self.config.eval.generate_samples
        #     and not self.parameterization == 'ar'):
        #     # (justin): implement sampling and kv cache for AR
        #     samples, text_samples = None, None
        #     for _ in range(
        #         self.config.sampling.num_sample_batches):
        #         samples = self._sample()
        #         # Decode the samples to be re-tokenized by eval model
        #         text_samples = self.tokenizer.batch_decode(samples)
        #         if self.config.eval.compute_generative_perplexity:
        #             self.compute_generative_perplexity(text_samples)
        #     if self.trainer.global_rank == 0 and hasattr(
        #         self.trainer.logger, 'log_table'):
        #         # Log the last generated samples
        #         text_samples = text_samples[
        #             : self.config.sampling.num_sample_log]
        #         self.trainer.logger.log_table(
        #             key=f'samples@global_step{self.global_step}',
        #             columns=['Generated Samples'],
        #             data=[[s] for s in text_samples])
        #     if self.config.eval.compute_generative_perplexity:
        #         self.log('val/gen_ppl',
        #                  self.gen_ppl_metric,
        #                  on_epoch=True,
        #                  on_step=False,
        #                  sync_dist=True)

        gc.collect()
        torch.cuda.empty_cache()
        if self.ema:
            self.ema.restore(
                itertools.chain(
                    self.backbone.parameters(),
                    self.noise.parameters()))

    def test_step(self, batch, batch_idx):
        loss = self._compute_loss(batch, prefix='test')
        self.log('test/loss',
                 value=loss.item(),
                 on_step=False,
                 on_epoch=True,
                 sync_dist=True)

        if self.config.eval.compute_generative_perplexity:
            samples, text_samples = None, None
            for _ in range(
                    self.config.sampling.num_sample_batches):
                samples = self._sample()
                # Decode the samples to be re-tokenized by eval model
                text_samples = self.tokenizer.batch_decode(samples)
                if self.config.eval.compute_generative_perplexity:
                    self.compute_generative_perplexity(text_samples)
            if self.trainer.global_rank == 0 and hasattr(
                    self.trainer.logger, 'log_table'):
                # Log the last generated samples
                text_samples = text_samples[
                    : self.config.sampling.num_sample_log]
                self.trainer.logger.log_table(
                    key=f'samples@global_step{self.global_step}',
                    columns=['Generated Samples'],
                    data=[[s] for s in text_samples])
            if self.config.eval.compute_generative_perplexity:
                self.log('test/gen_ppl',
                         self.gen_ppl_metric,
                         on_epoch=False,
                         on_step=True,
                         sync_dist=True)

    def on_test_epoch_start(self):
        # params_with_grad = [p for p in itertools.chain(
        #     self.backbone.parameters(),
        #     self.noise.parameters()
        # ) if p.requires_grad]

        if self.ema:
            self.ema.store(itertools.chain(
                self.backbone.parameters(),
                self.noise.parameters()))
            self.ema.copy_to(itertools.chain(
                self.backbone.parameters(),
                self.noise.parameters()))

        self.backbone.eval()
        self.noise.eval()
        self.test_metrics.reset()

    def on_test_epoch_end(self):
        # params_with_grad = [p for p in itertools.chain(
        #     self.backbone.parameters(),
        #     self.noise.parameters()
        # ) if p.requires_grad]

        if self.ema:
            self.ema.restore(itertools.chain(
                self.backbone.parameters(),
                self.noise.parameters()))

        for metric_name, metric_value in self.test_metrics.compute().items():
            self.log(metric_name, metric_value, sync_dist=True)

    def configure_optimizers(self):
        # (yair): Lightning currently giving this warning when using `fp16`:
        # "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
        # Not clear if this is a problem or not.
        # See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558

        # params_with_grad = [p for p in itertools.chain(
        #     self.backbone.parameters(),
        #     self.noise.parameters()
        # ) if p.requires_grad]

        optimizer = torch.optim.AdamW(
            itertools.chain(self.backbone.parameters(),
                            self.noise.parameters()),
            lr=self.config.optim.lr,
            betas=(self.config.optim.beta1,
                   self.config.optim.beta2),
            eps=self.config.optim.eps,
            weight_decay=self.config.optim.weight_decay
        )

        # scheduler = hydra.utils.instantiate(
        #     self.config.lr_scheduler, optimizer=optimizer)
        # scheduler_dict = {
        #     'scheduler': scheduler,
        #     'interval': 'step',
        #     'monitor': 'val/loss',
        #     'name': 'trainer/lr',
        # }

        self.total_steps = self.config.trainer.max_steps
        scheduler = CosineWarmup(optimizer,
                                 warmup_steps=self.config.lr_scheduler.num_warmup_steps,
                                 total_steps=self.total_steps)

        scheduler_dict = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1,
            'monitor': 'val/loss',
            'name': 'trainer/lr'
        }

        return [optimizer], [scheduler_dict]

    @torch.no_grad()
    def eval_retokenize(self, text_samples, max_length):
        """Retokenizes samples for the eval model.

        Args:
            text_samples: List of sentences generated by the model.
        Returns:
            samples: Samples re-tokenized for the eval model
            attn_mask: Attention mask for the eval model
            eval_context_size: Size of the context for the eval model
        """
        if 'llama2' in self.gen_ppl_eval_model_name_or_path:
            tokenizer_kwargs = {
                'text_samples': text_samples,
                'return_tensors': 'pt',
                'return_token_type_ids': False,
                'return_attention_mask': True,
                'truncation': True,
                'padding': True,
                'max_length': max_length,
            }
            eval_context_size = 4096
        else:
            tokenizer_kwargs = {
                'return_tensors': 'pt',
                'return_token_type_ids': False,
                'return_attention_mask': True,
                'truncation': True,
                'padding': True,
                'max_length': max_length,
            }
            eval_context_size = 1024
        samples = self.eval_model_tokenizer(
            text_samples, **tokenizer_kwargs)
        attn_mask = samples['attention_mask']
        samples = samples['input_ids']
        if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
            attn_mask = attn_mask.to(self.device)
            samples = samples.to(self.device)
        return samples, attn_mask, eval_context_size

    # @torch.no_grad()
    # def compute_generative_perplexity(
    #     self,
    #     text_samples: typing.List[str],
    #     retokenize: bool = True,
    #     max_length: typing.Optional[int] = None) -> None:
    #     """Compute the generative perplexity of the model.

    #     Args:
    #         text_samples: List of sentences generated by the model.

    #     Returns:
    #         Perplexity of the generated text under a different
    #         pre-trained AR model (e.g., GPT2).
    #     """
    #     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    #     eval_model = transformers.AutoModelForCausalLM.from_pretrained(
    #         self.gen_ppl_eval_model_name_or_path).eval()
    #     if max_length is None:
    #         max_length = self.config.model.length
    #     if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
    #         eval_model = eval_model.to(self.device)
    #     # Re-tokenize using eval model's tokenizer
    #     if retokenize:
    #         (samples, attn_mask,
    #          eval_context_size) = self.eval_retokenize(
    #             text_samples, max_length=max_length)
    #     else:
    #         samples = text_samples
    #         attn_mask = torch.ones(samples.shape).to(self.device)
    #         eval_context_size = samples.shape[-1]
    #     batch_size = min(
    #         self.config.eval.perplexity_batch_size,
    #         samples.shape[0])
    #     num_batches = samples.shape[0] // batch_size
    #     for i in range(num_batches):
    #         _samples = torch.split(
    #             samples[i * batch_size: (i + 1) * batch_size],
    #             eval_context_size,
    #             dim=-1)
    #         _attn_mask = torch.split(
    #             attn_mask[i * batch_size: (i + 1) * batch_size],
    #             eval_context_size,
    #             dim=-1)
    #         for (sample_chunk, attn_mask_chunk) in zip(
    #             _samples, _attn_mask):
    #             logits = eval_model(
    #                 sample_chunk, attention_mask=attn_mask_chunk)[0]
    #             logits = logits.transpose(-1, -2)

    #             nlls = F.cross_entropy(logits[..., :-1],
    #                                    sample_chunk[..., 1:],
    #                                    reduction='none')
    #             first_eos = (sample_chunk == self.eval_model_tokenizer\
    #                          .eos_token_id).cumsum(-1) == 1
    #             token_mask = (
    #                 sample_chunk
    #                 != self.eval_model_tokenizer.eos_token_id)
    #             self.gen_ppl_metric.update(
    #                 nlls, first_eos[..., 1:] + token_mask[..., 1:])

    @torch.no_grad()
    def compute_masked_perplexity(self, sequences, masked):
        """Compute the pseudo-perplexity of the generated protein sequences."""
        total_nll = 0
        total_tokens = 0

        for sequence in sequences:
            # Tokenize the sequence
            input_ids = self.tokenizer(masked, return_tensors="pt").input_ids.to(self.device)
            gt_ids = self.tokenizer(sequence.upper(), return_tensors="pt").input_ids.to(self.device)

            # print(input_ids.shape)
            # print(gt_ids.shape)

            # Forward pass through the ESM model
            attention_mask = torch.ones_like(input_ids)
            if self.config.mode in ['train', 'ppl_eval']:
                outputs = self.backbone.model.forward(input_ids=input_ids, attention_mask=attention_mask)
            elif self.config.mode == "sample_eval":
                outputs = self.backbone.model.forward(input_ids)
            logits = outputs[-1]  # B, L, V

            # Compute loss
            # shift_logits = logits[:, :-1, :].contiguous() # remove eos
            # shift_labels = input_ids[:, 1:].contiguous()
            # print(masked)
            # print(gt_ids.where(input_ids==32, torch.full_like(input_ids, -100)).view(-1))
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   gt_ids.where(input_ids==32, torch.full_like(input_ids, -100)).view(-1),
                                   reduction='sum')

            total_nll += loss.item()
            # total_tokens += (input_ids != self.tokenizer.pad_token_id).sum().item() - 1 # -1 for the first token
            total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item()  # count in bos and eos
        # Compute pseudo-perplexity
        # print(total_nll, ",;,", total_tokens)
        pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
        self.gen_ppl_metric.update(pseudo_perplexity)

        return pseudo_perplexity.item()

    @torch.no_grad()
    def compute_generative_perplexity(
        self,
        text_samples: typing.List[str],
        retokenize: bool = True,
        max_length: typing.Optional[int] = None) -> None:
        """Compute the generative perplexity of the model.

        Args:
            text_samples: List of sentences generated by the model.

        Returns:
            Perplexity of the generated text under a different
            pre-trained AR model (e.g., GPT2).
        """
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        eval_model = transformers.AutoModelForCausalLM.from_pretrained(
            self.gen_ppl_eval_model_name_or_path).eval()
        if max_length is None:
            max_length = self.config.model.length
        if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
            eval_model = eval_model.to(self.device)
        # Re-tokenize using eval model's tokenizer
        if retokenize:
            (samples, attn_mask,
             eval_context_size) = self.eval_retokenize(
                text_samples, max_length=max_length)
        else:
            samples = text_samples
            attn_mask = torch.ones(samples.shape).to(self.device)
            eval_context_size = samples.shape[-1]
        batch_size = min(
            self.config.eval.perplexity_batch_size,
            samples.shape[0])
        num_batches = samples.shape[0] // batch_size
        for i in range(num_batches):
            _samples = torch.split(
                samples[i * batch_size: (i + 1) * batch_size],
                eval_context_size,
                dim=-1)
            _attn_mask = torch.split(
                attn_mask[i * batch_size: (i + 1) * batch_size],
                eval_context_size,
                dim=-1)
            for (sample_chunk, attn_mask_chunk) in zip(
                    _samples, _attn_mask):
                logits = eval_model(
                    sample_chunk, attention_mask=attn_mask_chunk)[0]
                logits = logits.transpose(-1, -2)

                nlls = F.cross_entropy(logits[..., :-1],
                                       sample_chunk[..., 1:],
                                       reduction='none')
                first_eos = (sample_chunk == self.eval_model_tokenizer\
                             .eos_token_id).cumsum(-1) == 1
                token_mask = (
                    sample_chunk
                    != self.eval_model_tokenizer.eos_token_id)
                self.gen_ppl_metric.update(
                    nlls, first_eos[..., 1:] + token_mask[..., 1:])

| 962 |
+
def q_xt(self, x, move_chance):
|
| 963 |
+
"""Computes the noisy sample xt.
|
| 964 |
+
|
| 965 |
+
Args:
|
| 966 |
+
x: int torch.Tensor with shape (batch_size,
|
| 967 |
+
diffusion_model_input_length), input.
|
| 968 |
+
move_chance: float torch.Tensor with shape (batch_size, 1).
|
| 969 |
+
"""
|
| 970 |
+
|
| 971 |
+
actual_seq_length = (x != 1).sum(dim=1, keepdim=True)
|
| 972 |
+
|
| 973 |
+
max_mask_length = (actual_seq_length * 0.75).long()
|
| 974 |
+
|
| 975 |
+
move_indices = torch.rand(*x.shape, device=x.device) < move_chance
|
| 976 |
+
|
| 977 |
+
restricted_move_indices = torch.zeros_like(move_indices, dtype=torch.bool)
|
| 978 |
+
|
| 979 |
+
for i in range(x.shape[0]):
|
| 980 |
+
true_positions = torch.where(move_indices[i])[0]
|
| 981 |
+
if len(true_positions) > max_mask_length[i]:
|
| 982 |
+
selected_positions = true_positions[:max_mask_length[i].item()]
|
| 983 |
+
restricted_move_indices[i, selected_positions] = True
|
| 984 |
+
else:
|
| 985 |
+
restricted_move_indices[i] = move_indices[i]
|
| 986 |
+
xt = torch.where(restricted_move_indices, self.mask_index, x)
|
| 987 |
+
|
| 988 |
+
return xt
|
| 989 |
+
|
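# The masking above is a capped variant of the absorbing-state forward process:
# each token flips to the mask token with probability `move_chance`, but never
# more than 75% of a sequence's non-pad positions. An uncapped sketch of the
# same corruption step, for reference (illustrative only):
#
#     move_indices = torch.rand(*x.shape, device=x.device) < move_chance
#     xt = torch.where(move_indices, self.mask_index, x)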
| 990 |
+
def _sample_prior(self, *batch_dims):
|
| 991 |
+
return self.mask_index * torch.ones(* batch_dims, dtype=torch.int64)
|
| 992 |
+
|
| 993 |
+
def _ddpm_caching_update(self, x, t, dt, p_x0=None, attention_mask=None):
|
| 994 |
+
assert self.config.noise.type == 'loglinear'
|
| 995 |
+
sigma_t, _ = self.noise(t)
|
| 996 |
+
if t.ndim > 1:
|
| 997 |
+
t = t.squeeze(-1)
|
| 998 |
+
assert t.ndim == 1
|
| 999 |
+
move_chance_t = t[:, None, None]
|
| 1000 |
+
move_chance_s = (t - dt)[:, None, None]
|
| 1001 |
+
assert move_chance_t.ndim == 3, move_chance_t.shape
|
| 1002 |
+
if p_x0 is None:
|
| 1003 |
+
p_x0 = self.forward(x, sigma_t, attention_mask).exp()
|
| 1004 |
+
|
| 1005 |
+
assert move_chance_t.ndim == p_x0.ndim
|
| 1006 |
+
q_xs = p_x0 * (move_chance_t - move_chance_s)
|
| 1007 |
+
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
|
| 1008 |
+
_x = _sample_categorical(q_xs)
|
| 1009 |
+
|
| 1010 |
+
copy_flag = (x != self.mask_index).to(x.dtype)
|
| 1011 |
+
return p_x0, copy_flag * x + (1 - copy_flag) * _x
|
| 1012 |
+
|
| 1013 |
+
def _ddpm_update(self, x, t, dt, attention_mask):
|
| 1014 |
+
sigma_t, _ = self.noise(t)
|
| 1015 |
+
sigma_s, _ = self.noise(t - dt)
|
| 1016 |
+
if sigma_t.ndim > 1:
|
| 1017 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 1018 |
+
if sigma_s.ndim > 1:
|
| 1019 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 1020 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 1021 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 1022 |
+
move_chance_t = 1 - torch.exp(-sigma_t)
|
| 1023 |
+
move_chance_s = 1 - torch.exp(-sigma_s)
|
| 1024 |
+
move_chance_t = move_chance_t[:, None, None]
|
| 1025 |
+
move_chance_s = move_chance_s[:, None, None]
|
| 1026 |
+
unet_conditioning = sigma_t
|
| 1027 |
+
log_p_x0 = self.forward(x, unet_conditioning, attention_mask)
|
| 1028 |
+
assert move_chance_t.ndim == log_p_x0.ndim
|
| 1029 |
+
# Technically, this isn't q_xs since there's a division
|
| 1030 |
+
# term that is missing. This division term doesn't affect
|
| 1031 |
+
# the samples.
|
| 1032 |
+
q_xs = log_p_x0.exp() * (move_chance_t
|
| 1033 |
+
- move_chance_s)
|
| 1034 |
+
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
|
| 1035 |
+
_x = _sample_categorical(q_xs)
|
| 1036 |
+
|
| 1037 |
+
copy_flag = (x != self.mask_index).to(x.dtype)
|
| 1038 |
+
return copy_flag * x + (1 - copy_flag) * _x
|
| 1039 |
+
|
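# Reverse-step intuition for the update above: with move_chance_t = 1 - exp(-sigma_t),
# the (unnormalized) transition places probability proportional to
#     p_theta(x0 | xt) * (move_chance_t - move_chance_s)
# on unmasking to each vocabulary item and keeps mass move_chance_s on the mask
# token; positions that are already unmasked are copied through via `copy_flag`.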
| 1040 |
+
def _ar_sampler(self, bsz):
|
| 1041 |
+
# precompute token buffer
|
| 1042 |
+
num_pred_tokens = self.config.model.length - 1
|
| 1043 |
+
x = torch.zeros(
|
| 1044 |
+
(bsz, num_pred_tokens + 1),
|
| 1045 |
+
dtype=torch.long,
|
| 1046 |
+
device=self.device)
|
| 1047 |
+
x[:, 0] = self.tokenizer.bos_token_id
|
| 1048 |
+
# precompute noise
|
| 1049 |
+
noise = (torch.distributions.Gumbel(0, 1)
|
| 1050 |
+
.sample((bsz, num_pred_tokens, self.vocab_size))
|
| 1051 |
+
.to(self.device))
|
| 1052 |
+
for i in range(num_pred_tokens):
|
| 1053 |
+
next_logits = self.forward(x[:, :i + 1], None)[:, -1]
|
| 1054 |
+
y = (next_logits + noise[:, i]).argmax(-1)
|
| 1055 |
+
x[:, i + 1] = y
|
| 1056 |
+
return x
|
| 1057 |
+
|
| 1058 |
+
@torch.no_grad()
|
| 1059 |
+
def _sample(self, num_steps=None, eps=1e-5, x_input = None):
|
| 1060 |
+
"""Generate samples from the model."""
|
| 1061 |
+
batch_size_per_gpu = self.config.eval.perplexity_batch_size
|
| 1062 |
+
if self.parameterization == 'ar':
|
| 1063 |
+
return self._ar_sampler(batch_size_per_gpu)
|
| 1064 |
+
# Lightning auto-casting is not working in this method for some reason
|
| 1065 |
+
if num_steps is None:
|
| 1066 |
+
num_steps = self.config.sampling.steps
|
| 1067 |
+
if x_input is not None:
|
| 1068 |
+
x = x_input.input_ids
|
| 1069 |
+
attention_mask = x_input.attention_mask
|
| 1070 |
+
else:
|
| 1071 |
+
x = self._sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
|
| 1072 |
+
attention_mask = torch.ones_like(x)
|
| 1073 |
+
timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
|
| 1074 |
+
dt = (1 - eps) / num_steps
|
| 1075 |
+
p_x0_cache = None
|
| 1076 |
+
|
| 1077 |
+
for i in range(num_steps):
|
| 1078 |
+
t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
|
| 1079 |
+
if self.sampler == 'ddpm':
|
| 1080 |
+
x = self._ddpm_update(x, t, dt, attention_mask)
|
| 1081 |
+
elif self.sampler == 'ddpm_cache':
|
| 1082 |
+
p_x0_cache, x_next = self._ddpm_caching_update(x, t, dt, p_x0=p_x0_cache, attention_mask=attention_mask)
|
| 1083 |
+
if (not torch.allclose(x_next, x) or self.time_conditioning):
|
| 1084 |
+
# Disable caching
|
| 1085 |
+
p_x0_cache = None
|
| 1086 |
+
x = x_next
|
| 1087 |
+
# print(self.tokenizer.decode(x.squeeze()))
|
| 1088 |
+
else:
|
| 1089 |
+
x = self._analytic_update(x, t, dt, attention_mask)
|
| 1090 |
+
|
| 1091 |
+
if self.config.sampling.noise_removal:
|
| 1092 |
+
t = timesteps[-1] * torch.ones(x.shape[0], 1,
|
| 1093 |
+
device=self.device)
|
| 1094 |
+
if self.sampler == 'analytic':
|
| 1095 |
+
x = self._denoiser_update(x, t)
|
| 1096 |
+
else:
|
| 1097 |
+
unet_conditioning = self.noise(t)[0]
|
| 1098 |
+
x = self.forward(x, unet_conditioning, attention_mask, print_logits=True).argmax(dim=-1)
|
| 1099 |
+
# print(self.tokenizer.decode(x.squeeze()))
|
| 1100 |
+
return x
|
| 1101 |
+
|
| 1102 |
+
def restore_model_and_sample(self, num_steps, eps=1e-5):
|
| 1103 |
+
"""Generate samples from the model."""
|
| 1104 |
+
# Lightning auto-casting is not working in this method for some reason
|
| 1105 |
+
# params_with_grad = [p for p in itertools.chain(
|
| 1106 |
+
# self.backbone.parameters(),
|
| 1107 |
+
# self.noise.parameters()
|
| 1108 |
+
# ) if p.requires_grad]
|
| 1109 |
+
|
| 1110 |
+
if self.ema:
|
| 1111 |
+
self.ema.store(itertools.chain(self.backbone.parameters(),
|
| 1112 |
+
self.noise.parameters()))
|
| 1113 |
+
self.ema.copy_to(itertools.chain(self.backbone.parameters(),
|
| 1114 |
+
self.noise.parameters()))
|
| 1115 |
+
self.backbone.eval()
|
| 1116 |
+
self.noise.eval()
|
| 1117 |
+
samples = self._sample(num_steps=num_steps, eps=eps)
|
| 1118 |
+
if self.ema:
|
| 1119 |
+
self.ema.restore(itertools.chain(self.backbone.parameters(),
|
| 1120 |
+
self.noise.parameters()))
|
| 1121 |
+
self.backbone.train()
|
| 1122 |
+
self.noise.train()
|
| 1123 |
+
return samples
|
| 1124 |
+
|
| 1125 |
+
def get_score(self, x, sigma, attention_mask=None):
|
| 1126 |
+
model_output = self.forward(x, sigma, attention_mask)
|
| 1127 |
+
if self.parameterization == 'subs':
|
| 1128 |
+
# score(x, t) = p_t(y) / p_t(x)
|
| 1129 |
+
# => log score(x, t) = log p_t(y) - log p_t(x)
|
| 1130 |
+
|
| 1131 |
+
# case 1: x = masked
|
| 1132 |
+
# (i) y = unmasked
|
| 1133 |
+
# log score(x, t) = log p_\theta(x)|_y + log k
|
| 1134 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
| 1135 |
+
# (ii) y = masked
|
| 1136 |
+
# log score(x, t) = 0
|
| 1137 |
+
|
| 1138 |
+
# case 2: x = unmasked
|
| 1139 |
+
# (i) y != masked, y != x
|
| 1140 |
+
# log score(x_i, t) = - inf
|
| 1141 |
+
# (ii) y = x
|
| 1142 |
+
# log score(x_i, t) = 0
|
| 1143 |
+
# (iii) y = masked token
|
| 1144 |
+
# log score(x_i, t) = - log k
|
| 1145 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
| 1146 |
+
|
| 1147 |
+
log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
|
| 1148 |
+
assert log_k.ndim == 1
|
| 1149 |
+
|
| 1150 |
+
masked_score = model_output + log_k[:, None, None]
|
| 1151 |
+
masked_score[:, :, self.mask_index] = 0
|
| 1152 |
+
|
| 1153 |
+
unmasked_score = self.neg_infinity * torch.ones_like(
|
| 1154 |
+
model_output)
|
| 1155 |
+
unmasked_score = torch.scatter(
|
| 1156 |
+
unmasked_score,
|
| 1157 |
+
-1,
|
| 1158 |
+
x[..., None],
|
| 1159 |
+
torch.zeros_like(unmasked_score[..., :1]))
|
| 1160 |
+
unmasked_score[:, :, self.mask_index] = - (
|
| 1161 |
+
log_k[:, None] * torch.ones_like(x))
|
| 1162 |
+
|
| 1163 |
+
masked_indices = (x == self.mask_index).to(
|
| 1164 |
+
model_output.dtype)[:, :, None]
|
| 1165 |
+
model_output = (
|
| 1166 |
+
masked_score * masked_indices
|
| 1167 |
+
+ unmasked_score * (1 - masked_indices))
|
| 1168 |
+
return model_output.exp()
|
| 1169 |
+
|
| 1170 |
+
def _staggered_score(self, score, dsigma):
|
| 1171 |
+
score = score.clone()
|
| 1172 |
+
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
|
| 1173 |
+
score *= dsigma.exp()[:, None]
|
| 1174 |
+
score[..., self.mask_index] += extra_const
|
| 1175 |
+
return score
|
| 1176 |
+
|
| 1177 |
+
def _analytic_update(self, x, t, step_size, attention_mask=None):
|
| 1178 |
+
curr_sigma, _ = self.noise(t)
|
| 1179 |
+
next_sigma, _ = self.noise(t - step_size)
|
| 1180 |
+
dsigma = curr_sigma - next_sigma
|
| 1181 |
+
score = self.get_score(x, curr_sigma, attention_mask)
|
| 1182 |
+
stag_score = self._staggered_score(score, dsigma)
|
| 1183 |
+
probs = stag_score * self._transp_transition(x, dsigma)
|
| 1184 |
+
return _sample_categorical(probs)
|
| 1185 |
+
|
| 1186 |
+
def _denoiser_update(self, x, t):
|
| 1187 |
+
sigma, _ = self.noise(t)
|
| 1188 |
+
score = self.get_score(x, sigma)
|
| 1189 |
+
stag_score = self._staggered_score(score, sigma)
|
| 1190 |
+
probs = stag_score * self._transp_transition(x, sigma)
|
| 1191 |
+
probs[..., self.mask_index] = 0
|
| 1192 |
+
samples = _sample_categorical(probs)
|
| 1193 |
+
return samples
|
| 1194 |
+
|
| 1195 |
+
def _transp_transition(self, i, sigma):
|
| 1196 |
+
sigma = _unsqueeze(sigma, reference=i[..., None])
|
| 1197 |
+
edge = torch.exp(-sigma) * F.one_hot(
|
| 1198 |
+
i, num_classes=self.vocab_size)
|
| 1199 |
+
edge += torch.where(i == self.mask_index,
|
| 1200 |
+
1 - torch.exp(-sigma).squeeze(-1),
|
| 1201 |
+
0)[..., None]
|
| 1202 |
+
return edge
|
| 1203 |
+
|
| 1204 |
+
def _sample_t(self, n, device):
|
| 1205 |
+
_eps_t = torch.rand(n, device=device)
|
| 1206 |
+
if self.antithetic_sampling:
|
| 1207 |
+
offset = torch.arange(n, device=device) / n
|
| 1208 |
+
_eps_t = (_eps_t / n + offset) % 1
|
| 1209 |
+
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
|
| 1210 |
+
if self.importance_sampling:
|
| 1211 |
+
return self.noise.importance_sampling_transformation(t)
|
| 1212 |
+
return t
|
| 1213 |
+
|
| 1214 |
+
def _maybe_sub_sample(self, x0, attention_mask):
|
| 1215 |
+
# seqlen = x0.shape[1]
|
| 1216 |
+
# if seqlen > self.config.model.length:
|
| 1217 |
+
# assert seqlen == 2 * self.config.model.length
|
| 1218 |
+
# # cropping is needed for text8-crop dataset
|
| 1219 |
+
# # try the same starting point for now
|
| 1220 |
+
# start = np.random.choice(self.config.model.length)
|
| 1221 |
+
# end = start + self.config.model.length
|
| 1222 |
+
# input_tokens = x0[:, start: end]
|
| 1223 |
+
# output_tokens = x0[:, start + 1: end + 1]
|
| 1224 |
+
# new_attention_mask = attention_mask[:, start: end]
|
| 1225 |
+
|
| 1226 |
+
# # Helps with validation PPL, since the val
|
| 1227 |
+
# # examples will all start and end with BOS/EOS
|
| 1228 |
+
# input_tokens[:, 0] = self.tokenizer.bos_token_id
|
| 1229 |
+
# output_tokens[:, -1] = self.tokenizer.eos_token_id
|
| 1230 |
+
# elif self.parameterization == 'ar':
|
| 1231 |
+
# input_tokens = x0[:, :-1]
|
| 1232 |
+
# output_tokens = x0[:, 1:]
|
| 1233 |
+
# new_attention_mask = attention_mask[:, 1:]
|
| 1234 |
+
# else:
|
| 1235 |
+
input_tokens = x0
|
| 1236 |
+
output_tokens = None
|
| 1237 |
+
new_attention_mask = attention_mask
|
| 1238 |
+
return input_tokens, output_tokens, new_attention_mask
|
| 1239 |
+
|
| 1240 |
+
def _reconstruction_loss(self, x0, attention_mask):
|
| 1241 |
+
t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
|
| 1242 |
+
device=self.device)
|
| 1243 |
+
assert self.config.noise.type == 'loglinear'
|
| 1244 |
+
# The above assert is for d3pm parameterization
|
| 1245 |
+
unet_conditioning = self.noise(t0)[0][:, None]
|
| 1246 |
+
model_output_t0 = self.forward(x0, unet_conditioning, attention_mask)
|
| 1247 |
+
return - torch.gather(input=model_output_t0,
|
| 1248 |
+
dim=-1,
|
| 1249 |
+
index=x0[:, :, None]).squeeze(-1)
|
| 1250 |
+
|
| 1251 |
+
def _forward_pass_diffusion(self, x0, attention_mask, mask=None):
|
| 1252 |
+
t = self._sample_t(x0.shape[0], x0.device)
|
| 1253 |
+
if self.T > 0:
|
| 1254 |
+
t = (t * self.T).to(torch.int)
|
| 1255 |
+
t = t / self.T
|
| 1256 |
+
# t \in {1/T, 2/T, ..., 1}
|
| 1257 |
+
t += (1 / self.T)
|
| 1258 |
+
|
| 1259 |
+
if self.change_of_variables:
|
| 1260 |
+
unet_conditioning = t[:, None]
|
| 1261 |
+
f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
|
| 1262 |
+
f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
|
| 1263 |
+
move_chance = torch.exp(f_0 + t * (f_T - f_0))
|
| 1264 |
+
move_chance = move_chance[:, None]
|
| 1265 |
+
else:
|
| 1266 |
+
sigma, dsigma = self.noise(t)
|
| 1267 |
+
unet_conditioning = sigma[:, None]
|
| 1268 |
+
move_chance = 1 - torch.exp(-sigma[:, None])
|
| 1269 |
+
|
| 1270 |
+
if mask is None: xt = self.q_xt(x0, move_chance)
|
| 1271 |
+
else: xt = x0.where(mask==1, torch.full_like(x0, self.tokenizer.mask_token_id))
|
| 1272 |
+
model_output = self.forward(xt, unet_conditioning, attention_mask)
|
| 1273 |
+
# print(self.tokenizer.decode(torch.argmax(model_output[0], dim=-1)))
|
| 1274 |
+
|
| 1275 |
+
utils.print_nans(model_output, 'model_output')
|
| 1276 |
+
|
| 1277 |
+
if self.parameterization == 'sedd':
|
| 1278 |
+
return dsigma[:, None] * self._score_entropy(
|
| 1279 |
+
model_output, sigma[:, None], xt, x0)
|
| 1280 |
+
|
| 1281 |
+
if self.T > 0:
|
| 1282 |
+
diffusion_loss = self._d3pm_loss(
|
| 1283 |
+
model_output=model_output, xt=xt, x0=x0, t=t)
|
| 1284 |
+
if self.parameterization == 'd3pm':
|
| 1285 |
+
reconstruction_loss = self._reconstruction_loss(x0, attention_mask)
|
| 1286 |
+
elif self.parameterization == 'subs':
|
| 1287 |
+
reconstruction_loss = 0
|
| 1288 |
+
return reconstruction_loss + diffusion_loss
|
| 1289 |
+
|
| 1290 |
+
# SUBS parameterization, continuous time.
|
| 1291 |
+
log_p_theta = torch.gather(
|
| 1292 |
+
input=model_output,
|
| 1293 |
+
dim=-1,
|
| 1294 |
+
index=x0[:, :, None]).squeeze(-1)
|
| 1295 |
+
|
| 1296 |
+
if self.change_of_variables or self.importance_sampling:
|
| 1297 |
+
return log_p_theta * torch.log1p(
|
| 1298 |
+
- torch.exp(- self.noise.sigma_min))
|
| 1299 |
+
|
| 1300 |
+
return - log_p_theta * (
|
| 1301 |
+
dsigma / torch.expm1(sigma))[:, None]
|
| 1302 |
+
|
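# Continuous-time SUBS objective used above: with sigma(t) from the noise
# schedule, the per-token loss is
#     - log p_theta(x0 | xt) * dsigma / (exp(sigma) - 1),
# i.e. cross-entropy against the clean tokens re-weighted by the schedule; the
# attention mask applied in _loss() then restricts it to real tokens.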
| 1303 |
+
def _loss(self, x0, attention_mask, mask=None):
|
| 1304 |
+
(input_tokens, output_tokens,
|
| 1305 |
+
attention_mask) = self._maybe_sub_sample(
|
| 1306 |
+
x0, attention_mask)
|
| 1307 |
+
|
| 1308 |
+
if self.parameterization == 'ar':
|
| 1309 |
+
logprobs = self.backbone(input_tokens, None, attention_mask)
|
| 1310 |
+
loss = - logprobs.gather(
|
| 1311 |
+
-1, output_tokens[:, :, None])[:, :, 0]
|
| 1312 |
+
else:
|
| 1313 |
+
loss = self._forward_pass_diffusion(input_tokens, attention_mask, mask)
|
| 1314 |
+
|
| 1315 |
+
nlls = loss * attention_mask
|
| 1316 |
+
count = attention_mask.sum()
|
| 1317 |
+
|
| 1318 |
+
batch_nll = nlls.sum()
|
| 1319 |
+
token_nll = batch_nll / count
|
| 1320 |
+
|
| 1321 |
+
return Loss(loss=token_nll,
|
| 1322 |
+
nlls=nlls,
|
| 1323 |
+
token_mask=attention_mask)
|
| 1324 |
+
|
| 1325 |
+
def _score_entropy(self, log_score, sigma, xt, x0):
|
| 1326 |
+
"""Computes the SEDD loss.
|
| 1327 |
+
|
| 1328 |
+
Args:
|
| 1329 |
+
log_score: float torch.Tensor with shape (batch_size,
|
| 1330 |
+
diffusion_model_input_length, vocab_size),
|
| 1331 |
+
log score, output of the denoising network.
|
| 1332 |
+
xt: int torch.Tensor with shape (batch_size,
|
| 1333 |
+
diffusion_model_input_length), input.
|
| 1334 |
+
x0: int torch.Tensor with shape (batch_size,
|
| 1335 |
+
diffusion_model_input_length), input.
|
| 1336 |
+
sigma: float torch.Tensor with shape (batch_size, 1).
|
| 1337 |
+
|
| 1338 |
+
Returns:
|
| 1339 |
+
loss with shape (batch_size, diffusion_model_input_length)
|
| 1340 |
+
"""
|
| 1341 |
+
masked_indices = xt == self.mask_index
|
| 1342 |
+
|
| 1343 |
+
expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
|
| 1344 |
+
q_ratio = 1 / expsig_minus_1[masked_indices]
|
| 1345 |
+
|
| 1346 |
+
words_that_were_masked = x0[masked_indices]
|
| 1347 |
+
|
| 1348 |
+
neg_term = q_ratio * torch.gather(
|
| 1349 |
+
log_score[masked_indices],
|
| 1350 |
+
-1,
|
| 1351 |
+
words_that_were_masked[..., None]).squeeze(-1)
|
| 1352 |
+
score = log_score[masked_indices].exp()
|
| 1353 |
+
if self.mask_index == self.vocab_size - 1:
|
| 1354 |
+
pos_term = score[:, :-1].sum(dim=-1)
|
| 1355 |
+
else:
|
| 1356 |
+
pos_term = score[:, : self.mask_index].sum(
|
| 1357 |
+
dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
|
| 1358 |
+
const = q_ratio * (q_ratio.log() - 1)
|
| 1359 |
+
|
| 1360 |
+
entropy = torch.zeros(* xt.shape, device=xt.device)
|
| 1361 |
+
entropy[masked_indices] += pos_term - neg_term + const
|
| 1362 |
+
return entropy
|
| 1363 |
+
|
| 1364 |
+
@torch.no_grad()
|
| 1365 |
+
def sample_subs_guidance(
|
| 1366 |
+
self, n_samples, stride_length, num_strides, dt=0.001):
|
| 1367 |
+
ones = torch.ones(n_samples, dtype=self.dtype,
|
| 1368 |
+
device=self.device)
|
| 1369 |
+
|
| 1370 |
+
num_steps = int(1 / dt)
|
| 1371 |
+
sampling_steps = 0
|
| 1372 |
+
intermediate_tokens = []
|
| 1373 |
+
target = None
|
| 1374 |
+
for _ in range(num_strides + 1):
|
| 1375 |
+
p_x0_cache = None
|
| 1376 |
+
x = self._sample_prior(
|
| 1377 |
+
n_samples,
|
| 1378 |
+
self.config.model.length).to(self.device)
|
| 1379 |
+
if target is not None:
|
| 1380 |
+
x[:, : -stride_length] = target
|
| 1381 |
+
for i in range(num_steps + 1):
|
| 1382 |
+
p_x0_cache, x_next = self._ddpm_caching_update(
|
| 1383 |
+
x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
|
| 1384 |
+
if (not torch.allclose(x_next, x)
|
| 1385 |
+
or self.time_conditioning):
|
| 1386 |
+
p_x0_cache = None
|
| 1387 |
+
sampling_steps += 1
|
| 1388 |
+
x = x_next
|
| 1389 |
+
x = self.forward(x, 0 * ones).argmax(dim=-1)
|
| 1390 |
+
intermediate_tokens.append(
|
| 1391 |
+
x[:, :stride_length].cpu().numpy())
|
| 1392 |
+
target = x[:, stride_length:]
|
| 1393 |
+
|
| 1394 |
+
intermediate_tokens.append(target.cpu().numpy())
|
| 1395 |
+
intermediate_text_samples = []
|
| 1396 |
+
sequence_lengths = ((
|
| 1397 |
+
np.concatenate(intermediate_tokens, axis=1)[:, 1:]
|
| 1398 |
+
== self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
|
| 1399 |
+
for i in range(2, len(intermediate_tokens) + 1):
|
| 1400 |
+
intermediate_text_samples.append(
|
| 1401 |
+
self.tokenizer.batch_decode(
|
| 1402 |
+
np.concatenate(intermediate_tokens[:i], axis=1)))
|
| 1403 |
+
return (sampling_steps, intermediate_text_samples,
|
| 1404 |
+
sequence_lengths)
|
| 1405 |
+
|
| 1406 |
+
def restore_model_and_semi_ar_sample(
|
| 1407 |
+
self, stride_length, num_strides, dt=0.001):
|
| 1408 |
+
"""Generate samples from the model."""
|
| 1409 |
+
# Lightning auto-casting is not working in this method for some reason
|
| 1410 |
+
|
| 1411 |
+
# params_with_grad = [p for p in itertools.chain(
|
| 1412 |
+
# self.backbone.parameters(),
|
| 1413 |
+
# self.noise.parameters()
|
| 1414 |
+
# ) if p]
|
| 1415 |
+
|
| 1416 |
+
if self.ema:
|
| 1417 |
+
self.ema.store(itertools.chain(self.backbone.parameters(),
|
| 1418 |
+
self.noise.parameters()))
|
| 1419 |
+
self.ema.copy_to(itertools.chain(self.backbone.parameters(),
|
| 1420 |
+
self.noise.parameters()))
|
| 1421 |
+
self.backbone.eval()
|
| 1422 |
+
self.noise.eval()
|
| 1423 |
+
(sampling_steps, samples,
|
| 1424 |
+
sequence_lengths) = self.sample_subs_guidance(
|
| 1425 |
+
n_samples=self.config.loader.eval_batch_size,
|
| 1426 |
+
stride_length=stride_length,
|
| 1427 |
+
num_strides=num_strides,
|
| 1428 |
+
dt=dt)
|
| 1429 |
+
if self.ema:
|
| 1430 |
+
self.ema.restore(itertools.chain(self.backbone.parameters(),
|
| 1431 |
+
self.noise.parameters()))
|
| 1432 |
+
self.backbone.train()
|
| 1433 |
+
self.noise.train()
|
| 1434 |
+
return sampling_steps, samples, sequence_lengths
|
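A minimal sampling sketch for the Diffusion module above, assuming a Hydra/OmegaConf config is available as `config` with `eval.checkpoint_path` and `sampling.steps` set (the same pattern generate.py below follows); the checkpoint path and device handling are placeholders:

import torch
from transformers import AutoTokenizer
from diffusion import Diffusion

# `config` is the Hydra/OmegaConf config this repo is driven by (config.yaml).
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
model = Diffusion.load_from_checkpoint(
    config.eval.checkpoint_path, config=config, tokenizer=tokenizer)
model.eval()
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Unconditional sampling: start from an all-mask prior and run the reverse process.
tokens = model.restore_model_and_sample(num_steps=config.sampling.steps)
print(tokenizer.batch_decode(tokens))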
dit.py
ADDED
|
@@ -0,0 +1,388 @@
|
| 1 |
+
import math
|
| 2 |
+
import typing
|
| 3 |
+
|
| 4 |
+
import flash_attn
|
| 5 |
+
import flash_attn.layers.rotary
|
| 6 |
+
import huggingface_hub
|
| 7 |
+
import omegaconf
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
|
| 13 |
+
from transformers import AutoModel
|
| 14 |
+
|
| 15 |
+
# Flags required to enable jit fusion kernels
|
| 16 |
+
torch._C._jit_set_profiling_mode(False)
|
| 17 |
+
torch._C._jit_set_profiling_executor(False)
|
| 18 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
| 19 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def bias_dropout_add_scale(
|
| 23 |
+
x: torch.Tensor,
|
| 24 |
+
bias: typing.Optional[torch.Tensor],
|
| 25 |
+
scale: torch.Tensor,
|
| 26 |
+
residual: typing.Optional[torch.Tensor],
|
| 27 |
+
prob: float,
|
| 28 |
+
training: bool) -> torch.Tensor:
|
| 29 |
+
if bias is not None:
|
| 30 |
+
out = scale * F.dropout(x + bias, p=prob, training=training)
|
| 31 |
+
else:
|
| 32 |
+
out = scale * F.dropout(x, p=prob, training=training)
|
| 33 |
+
|
| 34 |
+
if residual is not None:
|
| 35 |
+
out = residual + out
|
| 36 |
+
return out
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_bias_dropout_add_scale(training):
|
| 40 |
+
def _bias_dropout_add(x, bias, scale, residual, prob):
|
| 41 |
+
return bias_dropout_add_scale(
|
| 42 |
+
x, bias, scale, residual, prob, training)
|
| 43 |
+
|
| 44 |
+
return _bias_dropout_add
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# function overload
|
| 48 |
+
def modulate(x: torch.Tensor,
|
| 49 |
+
shift: torch.Tensor,
|
| 50 |
+
scale: torch.Tensor) -> torch.Tensor:
|
| 51 |
+
return x * (1 + scale) + shift
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@torch.jit.script
|
| 55 |
+
def bias_dropout_add_scale_fused_train(
|
| 56 |
+
x: torch.Tensor,
|
| 57 |
+
bias: typing.Optional[torch.Tensor],
|
| 58 |
+
scale: torch.Tensor,
|
| 59 |
+
residual: typing.Optional[torch.Tensor],
|
| 60 |
+
prob: float) -> torch.Tensor:
|
| 61 |
+
return bias_dropout_add_scale(
|
| 62 |
+
x, bias, scale, residual, prob, True)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@torch.jit.script
|
| 66 |
+
def bias_dropout_add_scale_fused_inference(
|
| 67 |
+
x: torch.Tensor,
|
| 68 |
+
bias: typing.Optional[torch.Tensor],
|
| 69 |
+
scale: torch.Tensor,
|
| 70 |
+
residual: typing.Optional[torch.Tensor],
|
| 71 |
+
prob: float) -> torch.Tensor:
|
| 72 |
+
return bias_dropout_add_scale(
|
| 73 |
+
x, bias, scale, residual, prob, False)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@torch.jit.script
|
| 77 |
+
def modulate_fused(x: torch.Tensor,
|
| 78 |
+
shift: torch.Tensor,
|
| 79 |
+
scale: torch.Tensor) -> torch.Tensor:
|
| 80 |
+
return modulate(x, shift, scale)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Rotary(torch.nn.Module):
|
| 84 |
+
def __init__(self, dim, base=10_000):
|
| 85 |
+
super().__init__()
|
| 86 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 87 |
+
self.register_buffer('inv_freq', inv_freq)
|
| 88 |
+
self.seq_len_cached = None
|
| 89 |
+
self.cos_cached = None
|
| 90 |
+
self.sin_cached = None
|
| 91 |
+
|
| 92 |
+
def forward(self, x, seq_dim=1):
|
| 93 |
+
seq_len = x.shape[seq_dim]
|
| 94 |
+
if seq_len != self.seq_len_cached:
|
| 95 |
+
self.seq_len_cached = seq_len
|
| 96 |
+
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
|
| 97 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
|
| 98 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
| 99 |
+
# dims are: batch, seq_len, qkv, head, dim
|
| 100 |
+
self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)
|
| 101 |
+
self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)
|
| 102 |
+
# This makes the transformation on v an identity.
|
| 103 |
+
self.cos_cached[:,:,2,:,:].fill_(1.)
|
| 104 |
+
self.sin_cached[:,:,2,:,:].fill_(0.)
|
| 105 |
+
|
| 106 |
+
return self.cos_cached, self.sin_cached
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def rotate_half(x):
|
| 110 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
| 111 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def apply_rotary_pos_emb(qkv, cos, sin):
|
| 115 |
+
cos = cos[0,:,0,0,:cos.shape[-1]//2]
|
| 116 |
+
sin = sin[0,:,0,0,:sin.shape[-1]//2]
|
| 117 |
+
return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
|
| 118 |
+
|
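# Rotary position embedding applied above: for each channel pair at position m
# with frequency theta, queries and keys are rotated as
#     x_rot = x * cos(m * theta) + rotate_half(x) * sin(m * theta).
# The cached cos/sin for the value slot are filled with (1, 0) in Rotary.forward,
# so the fused flash_attn kernel leaves v untouched.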
| 119 |
+
|
| 120 |
+
# function overload
|
| 121 |
+
def modulate(x, shift, scale):
|
| 122 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
#################################################################################
|
| 126 |
+
# Layers #
|
| 127 |
+
#################################################################################
|
| 128 |
+
class LayerNorm(nn.Module):
|
| 129 |
+
def __init__(self, dim):
|
| 130 |
+
super().__init__()
|
| 131 |
+
self.weight = nn.Parameter(torch.ones([dim]))
|
| 132 |
+
self.dim = dim
|
| 133 |
+
def forward(self, x):
|
| 134 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 135 |
+
x = F.layer_norm(x.float(), [self.dim])
|
| 136 |
+
return x * self.weight[None,None,:]
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def residual_linear(x, W, x_skip, residual_scale):
|
| 140 |
+
"""x_skip + residual_scale * W @ x"""
|
| 141 |
+
dim_out, dim_in = W.shape[0], W.shape[1]
|
| 142 |
+
return torch.addmm(
|
| 143 |
+
x_skip.view(-1, dim_out),
|
| 144 |
+
x.view(-1, dim_in),
|
| 145 |
+
W.T,
|
| 146 |
+
alpha=residual_scale).view(*x.shape[:-1], dim_out)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
#################################################################################
|
| 150 |
+
# Embedding Layers for Timesteps and Class Labels #
|
| 151 |
+
#################################################################################
|
| 152 |
+
class TimestepEmbedder(nn.Module):
|
| 153 |
+
"""
|
| 154 |
+
Embeds scalar timesteps into vector representations.
|
| 155 |
+
"""
|
| 156 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 157 |
+
super().__init__()
|
| 158 |
+
self.mlp = nn.Sequential(
|
| 159 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 160 |
+
nn.SiLU(),
|
| 161 |
+
nn.Linear(hidden_size, hidden_size, bias=True))
|
| 162 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 163 |
+
|
| 164 |
+
@staticmethod
|
| 165 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 166 |
+
"""
|
| 167 |
+
Create sinusoidal timestep embeddings.
|
| 168 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 169 |
+
These may be fractional.
|
| 170 |
+
:param dim: the dimension of the output.
|
| 171 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 172 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 173 |
+
"""
|
| 174 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 175 |
+
half = dim // 2
|
| 176 |
+
freqs = torch.exp(- math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
| 177 |
+
|
| 178 |
+
if t.ndim == 1:
|
| 179 |
+
t = t.unsqueeze(1)
|
| 180 |
+
|
| 181 |
+
args = t.float() * freqs[None, :]
|
| 182 |
+
#args = t[:, None].float() * freqs[None]
|
| 183 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 184 |
+
if dim % 2:
|
| 185 |
+
embedding = torch.cat(
|
| 186 |
+
[embedding,
|
| 187 |
+
torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 188 |
+
return embedding
|
| 189 |
+
|
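# Sinusoidal embedding computed above, for timestep t and output dim D:
#     freqs_i = exp(-log(max_period) * i / (D // 2)),  i = 0, ..., D//2 - 1
#     emb(t)  = concat(cos(t * freqs), sin(t * freqs))  # zero-padded if D is odd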
| 190 |
+
def forward(self, t):
|
| 191 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 192 |
+
t_emb = self.mlp(t_freq)
|
| 193 |
+
return t_emb
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class LabelEmbedder(nn.Module):
|
| 197 |
+
"""Embeds class labels into vector representations.
|
| 198 |
+
|
| 199 |
+
Also handles label dropout for classifier-free guidance.
|
| 200 |
+
"""
|
| 201 |
+
def __init__(self, num_classes, cond_size):
|
| 202 |
+
super().__init__()
|
| 203 |
+
self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
|
| 204 |
+
self.num_classes = num_classes
|
| 205 |
+
|
| 206 |
+
# TODO think of initializing with 0.02 std deviation like in original DiT paper
|
| 207 |
+
|
| 208 |
+
def forward(self, labels):
|
| 209 |
+
embeddings = self.embedding_table(labels)
|
| 210 |
+
return embeddings
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
#################################################################################
|
| 214 |
+
# Core Model #
|
| 215 |
+
#################################################################################
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class DDiTBlock(nn.Module):
|
| 219 |
+
def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
|
| 220 |
+
super().__init__()
|
| 221 |
+
self.n_heads = n_heads
|
| 222 |
+
|
| 223 |
+
self.norm1 = LayerNorm(dim)
|
| 224 |
+
self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
|
| 225 |
+
self.attn_out = nn.Linear(dim, dim, bias=False)
|
| 226 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 227 |
+
|
| 228 |
+
self.norm2 = LayerNorm(dim)
|
| 229 |
+
self.mlp = nn.Sequential(
|
| 230 |
+
nn.Linear(dim, mlp_ratio * dim, bias=True),
|
| 231 |
+
nn.GELU(approximate='tanh'),
|
| 232 |
+
nn.Linear(mlp_ratio * dim, dim, bias=True))
|
| 233 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 234 |
+
self.dropout = dropout
|
| 235 |
+
|
| 236 |
+
self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
|
| 237 |
+
self.adaLN_modulation.weight.data.zero_()
|
| 238 |
+
self.adaLN_modulation.bias.data.zero_()
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _get_bias_dropout_scale(self):
|
| 242 |
+
if self.training:
|
| 243 |
+
return bias_dropout_add_scale_fused_train
|
| 244 |
+
else:
|
| 245 |
+
return bias_dropout_add_scale_fused_inference
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def forward(self, x, rotary_cos_sin, c, seqlens=None):
|
| 249 |
+
batch_size, seq_len = x.shape[0], x.shape[1]
|
| 250 |
+
|
| 251 |
+
bias_dropout_scale_fn = self._get_bias_dropout_scale()
|
| 252 |
+
|
| 253 |
+
(shift_msa, scale_msa, gate_msa, shift_mlp,
|
| 254 |
+
scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
|
| 255 |
+
|
| 256 |
+
# attention operation
|
| 257 |
+
x_skip = x
|
| 258 |
+
x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
|
| 259 |
+
|
| 260 |
+
qkv = self.attn_qkv(x)
|
| 261 |
+
qkv = rearrange(qkv,
|
| 262 |
+
'b s (three h d) -> b s three h d',
|
| 263 |
+
three=3,
|
| 264 |
+
h=self.n_heads)
|
| 265 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 266 |
+
cos, sin = rotary_cos_sin
|
| 267 |
+
qkv = apply_rotary_pos_emb(
|
| 268 |
+
qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
|
| 269 |
+
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
|
| 270 |
+
if seqlens is None:
|
| 271 |
+
cu_seqlens = torch.arange(
|
| 272 |
+
0, (batch_size + 1) * seq_len, step=seq_len,
|
| 273 |
+
dtype=torch.int32, device=qkv.device)
|
| 274 |
+
else:
|
| 275 |
+
cu_seqlens = seqlens.cumsum(-1)
|
| 276 |
+
x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
|
| 277 |
+
qkv, cu_seqlens, seq_len, 0., causal=False)
|
| 278 |
+
|
| 279 |
+
x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)
|
| 280 |
+
|
| 281 |
+
x = bias_dropout_scale_fn(self.attn_out(x),
|
| 282 |
+
None,
|
| 283 |
+
gate_msa,
|
| 284 |
+
x_skip,
|
| 285 |
+
self.dropout)
|
| 286 |
+
|
| 287 |
+
# mlp operation
|
| 288 |
+
x = bias_dropout_scale_fn(
|
| 289 |
+
self.mlp(modulate_fused(
|
| 290 |
+
self.norm2(x), shift_mlp, scale_mlp)),
|
| 291 |
+
None, gate_mlp, x, self.dropout)
|
| 292 |
+
return x
|
| 293 |
+
|
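# adaLN-Zero conditioning used above: the conditioning vector c is projected to
# six per-block vectors, (shift, scale, gate) for attention and for the MLP, and
# each sub-layer computes  x <- x + gate * Sublayer(modulate(norm(x), shift, scale)).
# Since adaLN_modulation is zero-initialized, every block starts out as an identity map.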
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class EmbeddingLayer(nn.Module):
|
| 297 |
+
def __init__(self, dim, vocab_dim):
|
| 298 |
+
super().__init__()
|
| 299 |
+
self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
|
| 300 |
+
torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))
|
| 301 |
+
|
| 302 |
+
def forward(self, x):
|
| 303 |
+
return self.embedding[x]
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class DDitFinalLayer(nn.Module):
|
| 307 |
+
def __init__(self, hidden_size, out_channels, cond_dim):
|
| 308 |
+
super().__init__()
|
| 309 |
+
self.norm_final = LayerNorm(hidden_size)
|
| 310 |
+
self.linear = nn.Linear(hidden_size, out_channels)
|
| 311 |
+
self.linear.weight.data.zero_()
|
| 312 |
+
self.linear.bias.data.zero_()
|
| 313 |
+
|
| 314 |
+
self.adaLN_modulation = nn.Linear(cond_dim,
|
| 315 |
+
2 * hidden_size,
|
| 316 |
+
bias=True)
|
| 317 |
+
self.adaLN_modulation.weight.data.zero_()
|
| 318 |
+
self.adaLN_modulation.bias.data.zero_()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def forward(self, x, c):
|
| 322 |
+
shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
|
| 323 |
+
x = modulate_fused(self.norm_final(x), shift, scale)
|
| 324 |
+
x = self.linear(x)
|
| 325 |
+
return x
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
|
| 329 |
+
def __init__(self, config, vocab_size: int, mlm_model_path):
|
| 330 |
+
super().__init__()
|
| 331 |
+
if type(config) == dict:
|
| 332 |
+
config = omegaconf.OmegaConf.create(config)
|
| 333 |
+
|
| 334 |
+
self.config = config
|
| 335 |
+
self.vocab_size = vocab_size
|
| 336 |
+
|
| 337 |
+
self.vocab_embed = EmbeddingLayer(config.model.hidden_size,
|
| 338 |
+
vocab_size)
|
| 339 |
+
self.sigma_map = TimestepEmbedder(config.model.cond_dim)
|
| 340 |
+
self.rotary_emb = Rotary(
|
| 341 |
+
config.model.hidden_size // config.model.n_heads)
|
| 342 |
+
|
| 343 |
+
blocks = []
|
| 344 |
+
for _ in range(config.model.n_blocks):
|
| 345 |
+
blocks.append(DDiTBlock(config.model.hidden_size,
|
| 346 |
+
config.model.n_heads,
|
| 347 |
+
config.model.cond_dim,
|
| 348 |
+
dropout=config.model.dropout))
|
| 349 |
+
self.blocks = nn.ModuleList(blocks)
|
| 350 |
+
|
| 351 |
+
self.output_layer = DDitFinalLayer(
|
| 352 |
+
config.model.hidden_size,
|
| 353 |
+
vocab_size,
|
| 354 |
+
config.model.cond_dim)
|
| 355 |
+
self.scale_by_sigma = config.model.scale_by_sigma
|
| 356 |
+
|
| 357 |
+
self.mlm_model = AutoModel.from_pretrained(mlm_model_path, device_map='cpu')
|
| 358 |
+
|
| 359 |
+
def _get_bias_dropout_scale(self):
|
| 360 |
+
if self.training:
|
| 361 |
+
return bias_dropout_add_scale_fused_train
|
| 362 |
+
else:
|
| 363 |
+
return bias_dropout_add_scale_fused_inference
|
| 364 |
+
|
| 365 |
+
def forward(self, indices, sigma):
|
| 366 |
+
x = self.vocab_embed(indices)
|
| 367 |
+
c_sigma = F.silu(self.sigma_map(sigma))
|
| 368 |
+
|
| 369 |
+
rotary_cos_sin = self.rotary_emb(x)
|
| 370 |
+
|
| 371 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 372 |
+
for i in range(len(self.blocks)):
|
| 373 |
+
x = self.blocks[i](x, rotary_cos_sin, c_sigma, seqlens=None)
|
| 374 |
+
x = self.output_layer(x, c_sigma)
|
| 375 |
+
|
| 376 |
+
# Extract membrane-specific embeddings from final encoder layer
|
| 377 |
+
# of fine-tuned ESM model
|
| 378 |
+
# with torch.no_grad():
|
| 379 |
+
# membrane_embedding = self.mlm_model(input_ids=, attention_mask=).last_hidden_state.squeeze(0)
|
| 380 |
+
|
| 381 |
+
# Fuse MLM embeddings with conditioning vector
|
| 382 |
+
# c = torch.cat([c_sigma, membrane_embedding], dim=-1)
|
| 383 |
+
|
| 384 |
+
# print(membrane_embedding.size())
|
| 385 |
+
# print(c_sigma.size())
|
| 386 |
+
|
| 387 |
+
return x
|
| 388 |
+
|
ema.py
ADDED
|
@@ -0,0 +1,97 @@
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ExponentialMovingAverage:
|
| 5 |
+
"""
|
| 6 |
+
Maintains (exponential) moving average of a set of parameters.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
def __init__(self, parameters, decay, use_num_updates=True):
|
| 10 |
+
"""
|
| 11 |
+
Args:
|
| 12 |
+
parameters: Iterable of `torch.nn.Parameter`; usually the result of
|
| 13 |
+
`model.parameters()`.
|
| 14 |
+
decay: The exponential decay.
|
| 15 |
+
use_num_updates: Whether to use number of updates when computing
|
| 16 |
+
averages.
|
| 17 |
+
"""
|
| 18 |
+
if decay < 0.0 or decay > 1.0:
|
| 19 |
+
raise ValueError('Decay must be between 0 and 1')
|
| 20 |
+
self.decay = decay
|
| 21 |
+
self.num_updates = 0 if use_num_updates else None
|
| 22 |
+
self.shadow_params = [p.clone().detach()
|
| 23 |
+
for p in parameters if p.requires_grad]
|
| 24 |
+
self.collected_params = []
|
| 25 |
+
|
| 26 |
+
def move_shadow_params_to_device(self, device):
|
| 27 |
+
self.shadow_params = [i.to(device) for i in self.shadow_params]
|
| 28 |
+
|
| 29 |
+
def update(self, parameters):
|
| 30 |
+
"""
|
| 31 |
+
Update currently maintained parameters.
|
| 32 |
+
|
| 33 |
+
Call this every time the parameters are updated, such as the result of
|
| 34 |
+
the `optimizer.step()` call.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
parameters: Iterable of `torch.nn.Parameter`; usually the same set of
|
| 38 |
+
parameters used to initialize this object.
|
| 39 |
+
"""
|
| 40 |
+
decay = self.decay
|
| 41 |
+
if self.num_updates is not None:
|
| 42 |
+
self.num_updates += 1
|
| 43 |
+
decay = min(decay, (1 + self.num_updates) /
|
| 44 |
+
(10 + self.num_updates))
|
| 45 |
+
one_minus_decay = 1.0 - decay
|
| 46 |
+
with torch.no_grad():
|
| 47 |
+
parameters = [p for p in parameters if p.requires_grad]
|
| 48 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
| 49 |
+
s_param.sub_(one_minus_decay * (s_param - param))
|
| 50 |
+
|
| 51 |
+
def copy_to(self, parameters):
|
| 52 |
+
"""
|
| 53 |
+
Copy current parameters into given collection of parameters.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 57 |
+
updated with the stored moving averages.
|
| 58 |
+
"""
|
| 59 |
+
parameters = [p for p in parameters if p.requires_grad]
|
| 60 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
| 61 |
+
if param.requires_grad:
|
| 62 |
+
param.data.copy_(s_param.data)
|
| 63 |
+
|
| 64 |
+
def store(self, parameters):
|
| 65 |
+
"""
|
| 66 |
+
Save the current parameters for restoring later.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 70 |
+
temporarily stored.
|
| 71 |
+
"""
|
| 72 |
+
self.collected_params = [param.clone() for param in parameters]
|
| 73 |
+
|
| 74 |
+
def restore(self, parameters):
|
| 75 |
+
"""
|
| 76 |
+
Restore the parameters stored with the `store` method.
|
| 77 |
+
Useful to validate the model with EMA parameters without affecting the
|
| 78 |
+
original optimization process. Store the parameters before the
|
| 79 |
+
`copy_to` method. After validation (or model saving), use this to
|
| 80 |
+
restore the former parameters.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 84 |
+
updated with the stored parameters.
|
| 85 |
+
"""
|
| 86 |
+
for c_param, param in zip(self.collected_params, parameters):
|
| 87 |
+
param.data.copy_(c_param.data)
|
| 88 |
+
|
| 89 |
+
def state_dict(self):
|
| 90 |
+
return dict(decay=self.decay,
|
| 91 |
+
num_updates=self.num_updates,
|
| 92 |
+
shadow_params=self.shadow_params)
|
| 93 |
+
|
| 94 |
+
def load_state_dict(self, state_dict):
|
| 95 |
+
self.decay = state_dict['decay']
|
| 96 |
+
self.num_updates = state_dict['num_updates']
|
| 97 |
+
self.shadow_params = state_dict['shadow_params']
|
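A short usage sketch for ExponentialMovingAverage, mirroring how diffusion.py swaps the averaged weights in for evaluation (the toy model and optimizer below are illustrative only):

import torch
from ema import ExponentialMovingAverage

model = torch.nn.Linear(16, 16)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = ExponentialMovingAverage(model.parameters(), decay=0.9999)

for _ in range(10):                      # training steps
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    ema.update(model.parameters())       # track the averaged weights

ema.store(model.parameters())            # stash the raw weights
ema.copy_to(model.parameters())          # evaluate with EMA weights
# ... run validation / sampling here ...
ema.restore(model.parameters())          # put the raw weights back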
esm_utils.py
ADDED
|
@@ -0,0 +1,15 @@
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
|
| 4 |
+
|
| 5 |
+
def load_esm2_model(model_name):
|
| 6 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 7 |
+
masked_model = AutoModelForMaskedLM.from_pretrained(model_name)
|
| 8 |
+
embedding_model = AutoModel.from_pretrained(model_name)
|
| 9 |
+
return tokenizer, masked_model, embedding_model
|
| 10 |
+
|
| 11 |
+
def get_latents(model, tokenizer, sequence, device):
|
| 12 |
+
inputs = tokenizer(sequence, return_tensors="pt").to(device)
|
| 13 |
+
with torch.no_grad():
|
| 14 |
+
outputs = model(**inputs).last_hidden_state.squeeze(0)
|
| 15 |
+
return outputs
|
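A minimal usage sketch for the helpers above (the checkpoint name matches the one used in generate.py; the protein sequence is a made-up example):

import torch
from esm_utils import load_esm2_model, get_latents

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer, masked_model, embedding_model = load_esm2_model("facebook/esm2_t30_150M_UR50D")
embedding_model = embedding_model.to(device).eval()

latents = get_latents(embedding_model, tokenizer, "MKTLLVAGLLFAAAVSA", device)
print(latents.shape)  # (sequence length + special tokens, hidden size)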
generate.py
ADDED
|
@@ -0,0 +1,60 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
import random
|
| 5 |
+
import sys
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from mlm_generate_utils import mask_for_de_novo, calculate_cosine_sim, calculate_hamming_dist
|
| 8 |
+
from diffusion import Diffusion
|
| 9 |
+
import hydra
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from transformers import AutoTokenizer, AutoModel, pipeline
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@torch.no_grad()
|
| 15 |
+
def generate_sequence(sequence_length: int, tokenizer, mdlm: Diffusion):
|
| 16 |
+
global masked_sequence
|
| 17 |
+
masked_sequence = mask_for_de_novo(sequence_length)
|
| 18 |
+
inputs = tokenizer(masked_sequence, return_tensors="pt").to(mdlm.device)
|
| 19 |
+
logits = mdlm._sample(x_input=inputs) # using sample, change config.sampling.steps to determine robustness
|
| 20 |
+
generated_sequence = tokenizer.decode(logits.squeeze())
|
| 21 |
+
|
| 22 |
+
return generated_sequence
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@hydra.main(version_base=None, config_path='configs', config_name='config')
|
| 26 |
+
def mdlm_motif_benchmark(config):
|
| 27 |
+
path = "/workspace/sg666/MDpLM"
|
| 28 |
+
|
| 29 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
|
| 30 |
+
mdlm_model = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer)
|
| 31 |
+
|
| 32 |
+
mdlm_model.eval()
|
| 33 |
+
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
| 34 |
+
mdlm_model.to(device)
|
| 35 |
+
|
| 36 |
+
print("loaded models...")
|
| 37 |
+
|
| 38 |
+
# Get 100 random sequence lengths to generate
|
| 39 |
+
sequence_lengths = [random.randint(50, 1000) for _ in range(100)]
|
| 40 |
+
|
| 41 |
+
generation_results = []
|
| 42 |
+
for seq_length in tqdm(sequence_lengths, desc="Generating sequences"):
|
| 43 |
+
generated_sequence = generate_sequence(seq_length, tokenizer, mdlm_model)
|
| 44 |
+
generated_sequence = generated_sequence[5:-5].replace(" ", "") # Remove bos/eos tokens
|
| 45 |
+
|
| 46 |
+
perplexity = mdlm_model.compute_masked_perplexity([generated_sequence], masked_sequence)
|
| 47 |
+
perplexity = round(perplexity, 4)
|
| 48 |
+
|
| 49 |
+
generation_results.append([generated_sequence, perplexity])
|
| 50 |
+
|
| 51 |
+
print(f"perplexity: {perplexity} | length: {seq_length} | generated sequence: {generated_sequence}")
|
| 52 |
+
sys.stdout.flush()
|
| 53 |
+
|
| 54 |
+
df = pd.DataFrame(generation_results, columns=['Generated Sequence', 'Perplexity'])
|
| 55 |
+
df.to_csv(path + f'/benchmarks/mdlm_de-novo_generation_results.csv', index=False)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if __name__ == "__main__":
|
| 60 |
+
mdlm_motif_benchmark()
|
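# Note on generate.py above: mask_for_de_novo (defined in mlm_generate_utils.py)
# is assumed to return a fully masked query of the requested length, roughly
# "<mask>" * sequence_length in ESM-tokenizer terms, which _sample() then
# denoises into a de novo sequence.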
main.py
ADDED
|
@@ -0,0 +1,250 @@
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import wandb
|
| 4 |
+
import fsspec
|
| 5 |
+
import hydra
|
| 6 |
+
import lightning as L
|
| 7 |
+
import omegaconf
|
| 8 |
+
import rich.syntax
|
| 9 |
+
import rich.tree
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import pl_data_loader as dataloader
|
| 13 |
+
from diffusion import Diffusion
|
| 14 |
+
import utils
|
| 15 |
+
|
| 16 |
+
from lightning.pytorch.strategies import DDPStrategy
|
| 17 |
+
from transformers import AutoTokenizer
|
| 18 |
+
from datasets import load_from_disk, load_dataset
|
| 19 |
+
|
| 20 |
+
#wandb.login(key="2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f")
|
| 21 |
+
omegaconf.OmegaConf.register_new_resolver(
|
| 22 |
+
'cwd', os.getcwd)
|
| 23 |
+
omegaconf.OmegaConf.register_new_resolver(
|
| 24 |
+
'device_count', torch.cuda.device_count)
|
| 25 |
+
omegaconf.OmegaConf.register_new_resolver(
|
| 26 |
+
'eval', eval)
|
| 27 |
+
omegaconf.OmegaConf.register_new_resolver(
|
| 28 |
+
'div_up', lambda x, y: (x + y - 1) // y)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _load_from_checkpoint(config, tokenizer):
|
| 32 |
+
if 'hf' in config.backbone:
|
| 33 |
+
return Diffusion(
|
| 34 |
+
config, tokenizer=tokenizer).to('cuda')
|
| 35 |
+
else:
|
| 36 |
+
model= Diffusion.load_from_checkpoint(
|
| 37 |
+
config.eval.checkpoint_path,
|
| 38 |
+
tokenizer=tokenizer,
|
| 39 |
+
config=config)
|
| 40 |
+
|
| 41 |
+
return model
|
| 42 |
+
|
| 43 |
+
@L.pytorch.utilities.rank_zero_only
|
| 44 |
+
def _print_config(
|
| 45 |
+
config: omegaconf.DictConfig,
|
| 46 |
+
resolve: bool = True,
|
| 47 |
+
save_cfg: bool = True) -> None:
|
| 48 |
+
"""Prints content of DictConfig using Rich library and its tree structure.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
config (DictConfig): Configuration composed by Hydra.
|
| 52 |
+
resolve (bool): Whether to resolve reference fields of DictConfig.
|
| 53 |
+
save_cfg (bool): Whether to save the configuration tree to a file.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
style = 'dim'
|
| 57 |
+
tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)
|
| 58 |
+
|
| 59 |
+
fields = config.keys()
|
| 60 |
+
for field in fields:
|
| 61 |
+
branch = tree.add(field, style=style, guide_style=style)
|
| 62 |
+
|
| 63 |
+
config_section = config.get(field)
|
| 64 |
+
branch_content = str(config_section)
|
| 65 |
+
if isinstance(config_section, omegaconf.DictConfig):
|
| 66 |
+
branch_content = omegaconf.OmegaConf.to_yaml(
|
| 67 |
+
config_section, resolve=resolve)
|
| 68 |
+
|
| 69 |
+
branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
|
| 70 |
+
rich.print(tree)
|
| 71 |
+
if save_cfg:
|
| 72 |
+
with fsspec.open(
|
| 73 |
+
'{}/config_tree.txt'.format(
|
| 74 |
+
config.checkpointing.save_dir), 'w') as fp:
|
| 75 |
+
rich.print(tree, file=fp)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@L.pytorch.utilities.rank_zero_only
|
| 79 |
+
def _print_batch(train_ds, valid_ds, tokenizer, k=64):
|
| 80 |
+
#for dl_type, dl in [
|
| 81 |
+
#('train', train_ds), ('valid', valid_ds)]:
|
| 82 |
+
for dl_type, dl in [
|
| 83 |
+
('train', train_ds)]:
|
| 84 |
+
print(f'Printing {dl_type} dataloader batch.')
|
| 85 |
+
batch = next(iter(dl))
|
| 86 |
+
print('Batch input_ids.shape', batch['input_ids'].shape)
|
| 87 |
+
first = batch['input_ids'][0, :k]
|
| 88 |
+
last = batch['input_ids'][0, -k:]
|
| 89 |
+
print(f'First {k} tokens:', tokenizer.decode(first))
|
| 90 |
+
print('ids:', first)
|
| 91 |
+
print(f'Last {k} tokens:', tokenizer.decode(last))
|
| 92 |
+
print('ids:', last)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def generate_samples(config, logger, tokenizer):
|
| 96 |
+
logger.info('Generating samples.')
|
| 97 |
+
model = _load_from_checkpoint(config=config,
|
| 98 |
+
tokenizer=tokenizer)
|
| 99 |
+
model.gen_ppl_metric.reset()
|
| 100 |
+
if config.eval.disable_ema:
|
| 101 |
+
logger.info('Disabling EMA.')
|
| 102 |
+
model.ema = None
|
| 103 |
+
stride_length = config.sampling.stride_length
|
| 104 |
+
num_strides = config.sampling.num_strides
|
| 105 |
+
for _ in range(config.sampling.num_sample_batches):
|
| 106 |
+
if config.sampling.semi_ar:
|
| 107 |
+
_, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(
|
| 108 |
+
stride_length=stride_length,
|
| 109 |
+
num_strides=num_strides,
|
| 110 |
+
dt=1 / config.sampling.steps)
|
| 111 |
+
text_samples = intermediate_samples[-1]
|
| 112 |
+
# Note: Samples generated using semi-ar method
|
| 113 |
+
# need to be processed before computing generative perplexity
|
| 114 |
+
# since these samples contain numerous <|endoftext|> tokens
|
| 115 |
+
# and diffusion.compute_generative_perplexity() discards
|
| 116 |
+
# any text after the first EOS token.
|
| 117 |
+
else:
|
| 118 |
+
samples = model.restore_model_and_sample(
|
| 119 |
+
num_steps=config.sampling.steps)
|
| 120 |
+
text_samples = model.tokenizer.batch_decode(samples)
|
| 121 |
+
model.compute_generative_perplexity(text_samples)
|
| 122 |
+
print('Text samples:', text_samples)
|
| 123 |
+
if not config.sampling.semi_ar:
|
| 124 |
+
print('Generative perplexity:',
|
| 125 |
+
model.gen_ppl_metric.compute())
|
| 126 |
+
return text_samples
|
| 127 |
+
|
| 128 |
+
def _ppl_eval(config, logger, tokenizer, data_module):
|
| 129 |
+
logger.info('Starting Zero Shot Eval.')
|
| 130 |
+
|
| 131 |
+
model = _load_from_checkpoint(config=config,
|
| 132 |
+
    tokenizer=tokenizer)
  if config.eval.disable_ema:
    logger.info('Disabling EMA.')
    model.ema = None

  wandb_logger = None
  if config.get('wandb', None) is not None:
    wandb_logger = L.pytorch.loggers.WandbLogger(
      config=omegaconf.OmegaConf.to_object(config),
      **config.wandb)
  callbacks = []
  if 'callbacks' in config:
    for _, callback in config.callbacks.items():
      callbacks.append(hydra.utils.instantiate(callback))
  trainer = hydra.utils.instantiate(
    config.trainer,
    default_root_dir=os.getcwd(),
    callbacks=callbacks,
    strategy=DDPStrategy(find_unused_parameters=True),
    logger=wandb_logger)
  # _, valid_ds = dataloader.get_dataloaders(
  #   config, tokenizer, skip_train=True, valid_seed=config.seed)
  trainer.test(model, data_module)


def _train(config, logger, tokenizer, data_module):
  logger.info('Starting Training.')
  wandb_logger = None
  if config.get('wandb', None) is not None:
    wandb_logger = L.pytorch.loggers.WandbLogger(
      config=omegaconf.OmegaConf.to_object(config),
      **config.wandb)

  if (config.checkpointing.resume_from_ckpt
      and config.checkpointing.resume_ckpt_path is not None
      and utils.fsspec_exists(
        config.checkpointing.resume_ckpt_path)):
    ckpt_path = config.checkpointing.resume_ckpt_path
  else:
    ckpt_path = None

  # Lightning callbacks
  callbacks = []
  if 'callbacks' in config:
    for _, callback in config.callbacks.items():
      callbacks.append(hydra.utils.instantiate(callback))
  '''
  train_ds, valid_ds = dataloader.get_dataloaders(
    config, tokenizer)
  _print_batch(train_ds, valid_ds, tokenizer)

  model = diffusion.Diffusion(
    config, tokenizer=valid_ds.tokenizer)
  '''
  trainer = hydra.utils.instantiate(
    config.trainer,
    default_root_dir=os.getcwd(),
    callbacks=callbacks,
    accelerator='cuda',
    strategy=DDPStrategy(find_unused_parameters=True),
    logger=wandb_logger)

  model = Diffusion(
    config, tokenizer=tokenizer)

  trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)

  '''
  trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)
  '''


@hydra.main(version_base=None, config_path='configs', config_name='config')
def main(config):
  """Main entry point for training."""
  L.seed_everything(config.seed)
  _print_config(config, resolve=True, save_cfg=True)

  logger = utils.get_logger(__name__)
  tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

  if config.backbone == "vanilla_esm_pretrain":
    train_dataset = load_dataset('csv', data_files=config.data.train.vanilla_esm_train_path)
    val_dataset = load_dataset('csv', data_files=config.data.valid.vanilla_esm_valid_path)
    test_dataset = load_dataset('csv', data_files=config.data.test.vanilla_esm_test_path)
  elif config.backbone == "membrane_esm_finetune" or config.backbone == "dit":
    train_dataset = load_dataset('csv', data_files=config.data.train.membrane_esm_train_path)
    val_dataset = load_dataset('csv', data_files=config.data.valid.membrane_esm_valid_path)
    test_dataset = load_dataset('csv', data_files=config.data.test.membrane_esm_test_path)

  lst = [i for i in range(1, 200)]

  train_dataset = train_dataset['train']  # .select(lst)
  val_dataset = val_dataset['train']  # .select(lst)
  test_dataset = test_dataset['train']  # .select(lst)

  if config.training.focus_mask:
    collator = dataloader.membrane_collate_fn
  elif config.data.wrapping:
    collator = dataloader.wrap_collate_fn
  else:
    collator = collate_fn

  data_module = dataloader.CustomDataModule(
    train_dataset, val_dataset, test_dataset,
    tokenizer,
    batch_size=config.loader.batch_size,
    collate_fn=collator
  )

  if config.mode == 'sample_eval':
    generate_samples(config, logger, tokenizer)
  elif config.mode == 'ppl_eval':
    _ppl_eval(config, logger, tokenizer, data_module)
  else:
    _train(config, logger, tokenizer, data_module)


if __name__ == '__main__':
  main()
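As a quick aside (not part of the uploaded files), the Hydra config that drives this entry point can be composed in-process for a smoke test before launching a run; this sketch assumes the `configs/config` directory and name referenced by the `@hydra.main` decorator above, and only touches keys that resolve without the project's custom OmegaConf resolvers.

# Minimal sketch: compose the Hydra config without launching training.
from hydra import compose, initialize

with initialize(version_base=None, config_path='configs'):
  cfg = compose(config_name='config',
                overrides=['mode=ppl_eval', 'backbone=membrane_esm_finetune'])
print(cfg.mode, cfg.backbone)  # ppl_eval membrane_esm_finetune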
mdlm_motif_benchmarking.py
ADDED
@@ -0,0 +1,96 @@
import torch
import torch.nn.functional as F
import math
import random
import sys
import pandas as pd
from mlm_generate_utils import mask_for_scaffold, calculate_cosine_sim, calculate_hamming_dist
from diffusion import Diffusion
import hydra
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, pipeline


def masking_test(sequence: str, generate_case: str, tokenizer, mask_prob: float = 0.50):
  """
  Masks a fraction (mask_prob, default 50%) of the tokens in the sequence.
  """
  tokens = list(sequence.upper())
  num_tokens_to_mask = int(mask_prob * len(tokens))  # Select some fraction of the tokens
  print(num_tokens_to_mask, len(tokens))

  # Get random indices to mask
  mask_indices = random.sample(range(len(tokens)), num_tokens_to_mask)

  for idx in mask_indices:
    tokens[idx] = tokenizer.mask_token  # Replace with mask token

  return ''.join(tokens)


@torch.no_grad()
def generate_scaffold_mdlm(sequence: str, generate_case: str, tokenizer, mdlm: Diffusion):
  # # Mask soluble or transmembrane domains
  # masked_sequence = mask_for_scaffold(sequence, generate_case)

  # # Test out different masking rates
  # masked_sequence = masking_test(sequence, generate_case, tokenizer)

  # 100% masking rate, de novo generation
  masked_sequence = len(sequence) * "<mask>"

  print(masked_sequence)

  inputs = tokenizer(masked_sequence, return_tensors="pt").to(mdlm.device)

  logits = mdlm._sample(x_input=inputs)  # using sample, change config.sampling.steps to determine robustness
  # logits = mdlm.forward(inputs)
  # print(tokenizer.decode(logits.squeeze(), skip_special_tokens=True))

  return tokenizer.decode(logits.squeeze()), masked_sequence


@hydra.main(version_base=None, config_path='configs', config_name='config')
def mdlm_motif_benchmark(config):
  path = "/workspace/sg666/MDpLM"

  test_sequences = pd.read_csv(path + "/data/membrane/test.csv")['Sequence'].tolist()
  tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

  mdlm_model = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer)
  esm_model = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D")  # model used for functionality testing

  mdlm_model.eval()
  esm_model.eval()

  print("loaded models...")

  device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
  mdlm_model.to(device)
  esm_model.to(device)

  for generate_case in ['uppercase', 'lowercase']:
    case_results = []
    for original_sequence in tqdm(test_sequences, desc=f"scaffolding ({generate_case}): "):

      generated_sequence, masked_input = generate_scaffold_mdlm(original_sequence, generate_case, tokenizer, mdlm_model)
      generated_sequence = generated_sequence[5:-5].replace(" ", "")  # Remove bos/eos tokens

      perplexity = mdlm_model.compute_masked_perplexity([original_sequence], masked_input)
      cos_sim = calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, esm_model, device)
      hamming_distance = calculate_hamming_dist(original_sequence, generated_sequence)

      case_results.append([original_sequence, generated_sequence, perplexity, cos_sim, hamming_distance])

      print("perplexity: ", perplexity, "cos sim: ", cos_sim, "hamming: ", hamming_distance)
      print(f"generated sequence: {generated_sequence}")
      print(f"original sequence: {original_sequence.upper()}")
      sys.stdout.flush()

    df = pd.DataFrame(case_results, columns=['Original Sequence', 'Generated Sequence', 'Perplexity', 'Cosine Similarity', 'Hamming Distance'])
    df.to_csv(path + f'/benchmarks/MLM/mlm_{generate_case}_results.csv', index=False)


if __name__ == "__main__":
  mdlm_motif_benchmark()
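For orientation (not part of the upload), `masking_test` above simply replaces a random fraction of residues with the tokenizer's mask token; with the ESM-2 tokenizer used in this script that token is `<mask>`. The toy sequence below is made up.

import random
from transformers import AutoTokenizer
from mdlm_motif_benchmarking import masking_test  # assumes this file is importable as a module

random.seed(0)
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
masked = masking_test("MKTLLVAGGW", "uppercase", tok, mask_prob=0.5)
print(masked.count(tok.mask_token))  # 5 of the 10 residues are masked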
mlm_generate_utils.py
ADDED
@@ -0,0 +1,108 @@
import torch
import math
import config
import sys
import pandas as pd
from esm_utils import get_latents
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer


def mask_for_de_novo(sequence_length):
  return "<mask>" * sequence_length


def generate_de_novo(sequence_length, tokenizer, model):
  masked_sequence = mask_for_de_novo(sequence_length)
  inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)

  with torch.no_grad():
    logits = model(**inputs).logits
  mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
  logits_at_masks = logits[0, mask_token_indices]

  pred_tokens = []
  for i in range(len(mask_token_indices)):
    topk_logits, topk_indices = logits_at_masks[i].topk(k=3, dim=-1)
    probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
    predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
    predicted_token_id = topk_indices[predicted_index].item()
    predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
    pred_tokens.append(predicted_token)

  generated_sequence = ''.join(pred_tokens)
  perplexity = calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices)

  return (generated_sequence, perplexity)


def mask_for_scaffold(sequence, generate_type):
  if generate_type == "uppercase":
    sequence = ''.join(["<mask>" if residue.isupper() else residue.upper() for residue in sequence])
  elif generate_type == "lowercase":
    sequence = ''.join(["<mask>" if residue.islower() else residue for residue in sequence])
  return sequence


def generate_scaffold(sequence, generate_type, tokenizer, model):
  masked_sequence = mask_for_scaffold(sequence, generate_type)
  inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)

  with torch.no_grad():
    logits = model(**inputs).logits
  mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
  logits_at_masks = logits[0, mask_token_indices]

  pred_tokens = []
  for i in range(len(mask_token_indices)):
    topk_logits, topk_indices = logits_at_masks[i].topk(k=3, dim=-1)
    probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
    predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
    predicted_token_id = topk_indices[predicted_index].item()
    predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)

    pred_tokens.append('G' if predicted_token == '' else predicted_token)

  generated_sequence = masked_sequence
  for token in pred_tokens:
    generated_sequence = generated_sequence.replace("<mask>", token, 1)

  return generated_sequence, mask_token_indices


def calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices):
  total_loss = 0.0
  tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)

  for i in mask_token_indices:
    masked_input = tensor_input.clone()
    masked_input[0, i] = tokenizer.mask_token_id

    labels = torch.full(tensor_input.shape, -100).to(model.device)
    labels[0, i] = tensor_input[0, i]

    with torch.no_grad():
      outputs = model(masked_input, labels=labels)
      total_loss += outputs.loss.item()

  num_mask_tokens = len(mask_token_indices)
  if num_mask_tokens == 0:
    perplexity = 10000
  else:
    avg_loss = total_loss / num_mask_tokens
    perplexity = math.exp(avg_loss)

  return perplexity


def calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, esm_model, device):
  og_embeddings = get_latents(esm_model, tokenizer, original_sequence.upper(), device)
  new_embeddings = get_latents(esm_model, tokenizer, generated_sequence, device)

  sequence_similarity = torch.nn.functional.cosine_similarity(og_embeddings, new_embeddings, dim=-1)
  cosine_similarity = torch.mean(sequence_similarity).item()
  return cosine_similarity


def calculate_hamming_dist(original_sequence, generated_sequence):
  generated_sequence = generated_sequence.upper()
  original_sequence = original_sequence.upper()
  return sum(1 if original_sequence[i] != generated_sequence[i] else 0 for i in range(len(original_sequence)))
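A small illustration (not part of the upload) of the scaffold-masking convention used throughout these utilities: the letter case of a residue decides whether it is masked for regeneration or kept as context. The toy sequences below are made up.

from mlm_generate_utils import mask_for_scaffold, calculate_hamming_dist

seq = "MKTllvAGG"  # uppercase = one structural region, lowercase = the other
print(mask_for_scaffold(seq, "uppercase"))   # <mask><mask><mask>LLV<mask><mask><mask>
print(mask_for_scaffold(seq, "lowercase"))   # MKT<mask><mask><mask>AGG
print(calculate_hamming_dist("MKTLLV", "MKTLAV"))  # 1, positions compared after upper-casing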
noise_schedule.py
ADDED
@@ -0,0 +1,153 @@
import abc

import torch
import torch.nn as nn

# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def get_noise(config, dtype=torch.float32):
  return LogLinearNoise()

  if config.noise.type == 'geometric':
    return GeometricNoise(config.noise.sigma_min,
                          config.noise.sigma_max)
  elif config.noise.type == 'loglinear':
    return LogLinearNoise()
  elif config.noise.type == 'cosine':
    return CosineNoise()
  elif config.noise.type == 'cosinesqr':
    return CosineSqrNoise()
  elif config.noise.type == 'linear':
    return Linear(config.noise.sigma_min,
                  config.noise.sigma_max,
                  dtype)
  else:
    raise ValueError(f'{config.noise.type} is not a valid noise')


def binary_discretization(z):
  z_hard = torch.sign(z)
  z_soft = z / torch.norm(z, dim=-1, keepdim=True)
  return z_soft + (z_hard - z_soft).detach()


class Noise(abc.ABC, nn.Module):
  """
  Baseline forward method to get the total + rate of noise at a timestep
  """
  def forward(self, t):
    # Assume time goes from 0 to 1
    return self.total_noise(t), self.rate_noise(t)

  @abc.abstractmethod
  def rate_noise(self, t):
    """
    Rate of change of noise ie g(t)
    """
    pass

  @abc.abstractmethod
  def total_noise(self, t):
    """
    Total noise ie \int_0^t g(t) dt + g(0)
    """
    pass


class CosineNoise(Noise):
  def __init__(self, eps=1e-3):
    super().__init__()
    self.eps = eps

  def rate_noise(self, t):
    cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
    sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
    scale = torch.pi / 2
    return scale * sin / (cos + self.eps)

  def total_noise(self, t):
    cos = torch.cos(t * torch.pi / 2)
    return - torch.log(self.eps + (1 - self.eps) * cos)


class CosineSqrNoise(Noise):
  def __init__(self, eps=1e-3):
    super().__init__()
    self.eps = eps

  def rate_noise(self, t):
    cos = (1 - self.eps) * (
      torch.cos(t * torch.pi / 2) ** 2)
    sin = (1 - self.eps) * torch.sin(t * torch.pi)
    scale = torch.pi / 2
    return scale * sin / (cos + self.eps)

  def total_noise(self, t):
    cos = torch.cos(t * torch.pi / 2) ** 2
    return - torch.log(self.eps + (1 - self.eps) * cos)


class Linear(Noise):
  def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
    super().__init__()
    self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
    self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

  def rate_noise(self, t):
    return self.sigma_max - self.sigma_min

  def total_noise(self, t):
    return self.sigma_min + t * (self.sigma_max - self.sigma_min)

  def importance_sampling_transformation(self, t):
    f_T = torch.log1p(- torch.exp(- self.sigma_max))
    f_0 = torch.log1p(- torch.exp(- self.sigma_min))
    sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
    return (sigma_t - self.sigma_min) / (
      self.sigma_max - self.sigma_min)


class GeometricNoise(Noise):
  def __init__(self, sigma_min=1e-3, sigma_max=1):
    super().__init__()
    self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

  def rate_noise(self, t):
    return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
      self.sigmas[1].log() - self.sigmas[0].log())

  def total_noise(self, t):
    return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t


class LogLinearNoise(Noise):
  """Log Linear noise schedule.

  Built such that 1 - 1/e^(n(t)) interpolates between 0 and
  ~1 when t varies from 0 to 1. Total noise is
  -log(1 - (1 - eps) * t), so the sigma will be
  (1 - eps) * t.
  """
  def __init__(self, eps=1e-3):
    super().__init__()
    self.eps = eps
    self.sigma_max = self.total_noise(torch.tensor(1.0))
    self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

  def rate_noise(self, t):
    return (1 - self.eps) / (1 - (1 - self.eps) * t)

  def total_noise(self, t):
    return -torch.log1p(-(1 - self.eps) * t)

  def importance_sampling_transformation(self, t):
    f_T = torch.log1p(- torch.exp(- self.sigma_max))
    f_0 = torch.log1p(- torch.exp(- self.sigma_min))
    sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
    t = - torch.expm1(- sigma_t) / (1 - self.eps)
    return t
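A short sanity check of the log-linear schedule defined above (a sketch, not part of the upload): total_noise(t) is -log1p(-(1 - eps) * t), so 1 - exp(-sigma(t)) grows linearly from 0 to roughly 1 as t goes from 0 to 1, matching the class docstring.

import torch
from noise_schedule import LogLinearNoise

noise = LogLinearNoise(eps=1e-3)
t = torch.tensor([0.0, 0.5, 1.0])
sigma, dsigma = noise(t)          # forward() returns (total_noise, rate_noise)
print(sigma)                      # ~[0.0000, 0.6921, 6.9078]
print(1 - torch.exp(-sigma))      # ~[0.0000, 0.4995, 0.9990], i.e. (1 - eps) * t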
pl_data_loader.py
ADDED
@@ -0,0 +1,819 @@
import functools
import itertools
import json
import math
import os
import re
import shutil
import typing
import urllib
import zipfile

import datasets
import fsspec
import requests
import tokenizers
import torch
import transformers

import utils

LOGGER = utils.get_logger(__name__)


def wt_detokenizer(string):
  # contractions
  string = string.replace("s '", "s'")
  string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
  # number separators
  string = string.replace(" @-@ ", "-")
  string = string.replace(" @,@ ", ",")
  string = string.replace(" @.@ ", ".")
  # punctuation
  string = string.replace(" : ", ": ")
  string = string.replace(" ; ", "; ")
  string = string.replace(" . ", ". ")
  string = string.replace(" ! ", "! ")
  string = string.replace(" ? ", "? ")
  string = string.replace(" , ", ", ")
  # double brackets
  string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
  string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
  string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
  string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
  string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
  # miscellaneous
  string = string.replace("= = = =", "====")
  string = string.replace("= = =", "===")
  string = string.replace("= =", "==")
  string = string.replace(" " + chr(176) + " ", chr(176))
  string = string.replace(" \n", "\n")
  string = string.replace("\n ", "\n")
  string = string.replace(" N ", " 1 ")
  string = string.replace(" 's", "'s")
  return string


def ptb_detokenizer(x):
  x = x.replace(" 's", "'s")
  x = x.replace("s ' ", "s' ")
  x = x.replace(" n't", "n't")
  x = x.replace(" \n ", "\n")
  x = x.replace("\\/", "/")
  for _ in range(10):
    x = x.replace(" N ", " 1 ")
  x = x.replace("$ 1", "$1")
  x = x.replace("# 1", "#1")
  x = x.replace("<unk>", "?")
  return x


def lm1b_detokenizer(x):
  x = x.replace('http : / / ', 'http://')
  x = x.replace('https : / / ', 'https://')
  x = re.sub(r' \'(\w+)', r"'\1", x)
  x = re.sub(r' (\w+) \. ', r' \1. ', x)
  x = re.sub(r' (\w+) \.$', r' \1.', x)
  x = x.replace(' ? ', '? ')
  x = re.sub(r' \?$', '?', x)
  x = x.replace(' ! ', '! ')
  x = re.sub(r' \!$', '!', x)
  x = x.replace(' , ', ', ')
  x = x.replace(' : ', ': ')
  x = x.replace(' ; ', '; ')
  x = x.replace(' / ', '/')
  x = re.sub(r'\" ([^\"]+) \"', r'"\1"', x)
  x = re.sub(r'\' ([^\']+) \'', r"'\1'", x)
  x = re.sub(r'\( ([^\(\)]+) \)', r"(\1)", x)
  x = re.sub(r'\[ ([^\[\]]+) \]', r"[\1]", x)
  x = x.replace('$ ', '$')
  x = x.replace('£ ', '£')
  return x


def lambada_detokenizer(text):
  text = text.replace("“", '"')
  text = text.replace("”", '"')
  return '\n'+text.strip()


def scientific_papers_detokenizer(x):
  x = wt_detokenizer(x)
  x = lm1b_detokenizer(x)
  return x


class Text8Tokenizer(transformers.PreTrainedTokenizer):
  def __init__(
    self,
    bos_token='[BOS]',
    eos_token='[EOS]',
    sep_token='[SEP]',
    cls_token='[CLS]',
    pad_token='[PAD]',
    mask_token='[MASK]',
    unk_token='[UNK]',
    **kwargs):
    self.characters = list('abcdefghijklmnopqrstuvwxyz ')
    self._vocab_str_to_int = {
      '[CLS]': 0,
      '[SEP]': 1,
      '[BOS]': 2,
      '[EOS]': 3,
      '[MASK]': 4,
      '[PAD]': 5,
      '[RESERVED]': 6,
      '[UNK]': 7,
      **{ch: i + 8 for i, ch in enumerate(self.characters)}}
    self._vocab_int_to_str = {
      v: k for k, v in self._vocab_str_to_int.items()}
    super().__init__(
      bos_token=bos_token,
      eos_token=eos_token,
      sep_token=sep_token,
      cls_token=cls_token,
      pad_token=pad_token,
      mask_token=mask_token,
      unk_token=unk_token,
      **kwargs)

  @property
  def vocab_size(self) -> int:
    return len(self._vocab_str_to_int)

  def _tokenize(self, text: str, **kwargs) -> typing.List[str]:
    return list(text.lower())

  def _convert_token_to_id(self, token: str) -> int:
    return self._vocab_str_to_int.get(
      token, self._vocab_str_to_int['[UNK]'])

  def _convert_id_to_token(self, index: int) -> str:
    return self._vocab_int_to_str[index]

  def convert_tokens_to_string(self, tokens):
    return ''.join(tokens)

  def get_vocab(self) -> typing.Dict[str, int]:
    return self._vocab_str_to_int


def get_lambada_test_dataset():
  url = "https://openaipublic.blob.core.windows.net/gpt-2/data/lambada_test.jsonl"

  def read_jsonl_to_list(url):
    response = requests.get(url, stream=True)
    data_list = []

    # Process each line in the response content
    for line in response.iter_lines(decode_unicode=True):
      if line:
        data = json.loads(line)
        data_list.append(data)

    return data_list

  lambada_data = read_jsonl_to_list(url)
  dataset = datasets.Dataset.from_list(lambada_data)
  return dataset


def get_text8_dataset(cache_dir, max_seq_length=256,
                      drop_last=True, crop_train=False):
  """Adapted from:
  https://github.com/google-research/google-research/blob/master/d3pm/text/datasets.py#L344

  Args:
    cache_dir: str, path to cache directory.
    max_seq_length: int, maximum length of sequences.
      (default: 256, as in D3PM codebase.)
    drop_last: bool, whether to drop the last incomplete
      batch. (default: True, as in D3PM codebase.)
    crop_train: bool, whether to subsample contiguous
      subsequences from training example. serves to
      make sure transformer models with absolute position
      embeddings do not have incorrect position-wise
      marginals. (default: False, but necessary to match D3PM AR)

  Returns:
    dataset: dataset.DatasetDict, with keys 'train',
      'valid', 'test'.
  """
  url = 'http://mattmahoney.net/dc/text8.zip'
  if not crop_train:
    cache_dir = f'{cache_dir}/text8'
  else:
    cache_dir = f'{cache_dir}/text8-crop-train'
  split_names = ['train', 'validation', 'test']
  if not all([
    utils.fsspec_exists(os.path.join(cache_dir, split))
    for split in split_names
  ]):
    # Check if raw data exists
    raw_cache_dir = os.path.join(cache_dir, 'raw_data')
    if not all([
      utils.fsspec_exists(
        os.path.join(raw_cache_dir, f'text8.{split}.txt'))
      for split in split_names
    ]):
      if not utils.fsspec_exists(
        os.path.join(raw_cache_dir, 'text8.zip')):
        utils.fsspec_mkdirs(raw_cache_dir, exist_ok=True)
        LOGGER.info('Downloading text8 from URL {}.'.format(url))
        with urllib.request.urlopen(url) as in_stream:
          with open(os.path.join(raw_cache_dir, 'text8.zip'), 'wb') as out_file:
            shutil.copyfileobj(in_stream, out_file)

      with fsspec.open(
        os.path.join(raw_cache_dir, 'text8.zip'),
        'rb') as f:
        rawdata = zipfile.ZipFile(f).read(
          'text8').decode('utf-8')

      # Splits taken from D3PM codebase
      splits = {
        'train': rawdata[:90000000],
        'validation': rawdata[90000000: 95000000],
        'test': rawdata[95000000:],
      }

      for split, data in splits.items():
        _path = os.path.join(raw_cache_dir,
                             f'text8.{split}.txt')
        with fsspec.open(_path, 'w') as f:
          f.write(data)
    else:
      splits = {}
      for split in split_names:
        _path = os.path.join(raw_cache_dir,
                             f'text8.{split}.txt')
        with fsspec.open(_path, 'r') as f:
          splits[split] = f.read()

    # Chunk and save as datasets.DatasetDict
    def chunks(lst, n):
      """Yield successive n-sized chunks from lst."""
      for i in range(0, len(lst), n):
        yield lst[i:i + n]

    dataset_dict = {}
    for k, v in splits.items():
      if k == 'train' and crop_train == True:
        chunk_size = 2 * max_seq_length
      else:
        chunk_size = max_seq_length
      text = list(chunks(v, chunk_size))
      if drop_last and len(text[-1]) < chunk_size:
        text = text[:-1]
      dataset_dict[k] = datasets.Dataset.from_dict({'text': text})
    dataset = datasets.DatasetDict(dataset_dict)
    dataset.save_to_disk(cache_dir)
  else:
    dataset = datasets.load_from_disk(cache_dir)

  return dataset


def _group_texts(examples, block_size, bos, eos):
  # Concatenate all texts.
  concatenated_examples = list(itertools.chain(*examples['input_ids']))
  total_length = len(concatenated_examples)
  # TODO(yair): look into not dropping the remainder but rather padding it.
  # We drop the small remainder, and if the total_length < block_size - 2
  # we exclude this batch and return an empty dict.
  # We could add padding if the model supported it instead of
  # this drop, you can customize this part to your needs.
  new_block_size = block_size - 2  # [BOS] and [EOS] to be added
  total_length = (total_length // new_block_size) * new_block_size
  # Split by chunks of max_len.
  result = {}
  _values = []
  _attn_masks = []
  for i in range(0, total_length, new_block_size):
    _values.append(
      [bos]
      + concatenated_examples[i : i + new_block_size]
      + [eos])
    _attn_masks.append(torch.ones(block_size))
  result['input_ids'] = _values
  result['attention_mask'] = _attn_masks
  return result


def get_dataset(
    dataset_name, tokenizer, wrap, mode, cache_dir,
    block_size=1024, num_proc=len(os.sched_getaffinity(0)), streaming=False):
  if wrap:
    filename = f'{dataset_name}_{mode}_bs{block_size}_wrapped.dat'
  else:
    filename = f'{dataset_name}_{mode}_bs{block_size}_unwrapped.dat'
  _path = os.path.join(cache_dir, filename)

  if utils.fsspec_exists(_path):
    LOGGER.info(f'Loading data from: {_path}')
    return datasets.load_from_disk(_path).with_format('torch')
  LOGGER.info(f'Generating new data at: {_path}')

  crop_train = dataset_name == 'text8-crop'
  if mode == 'train' and crop_train:
    # double block size for sub-sampling
    block_size *= 2

  if dataset_name == 'wikitext103':
    dataset = datasets.load_dataset(
      'wikitext',
      name='wikitext-103-raw-v1',
      cache_dir=cache_dir)
  elif dataset_name == 'wikitext2':
    dataset = datasets.load_dataset(
      'wikitext',
      name='wikitext-2-raw-v1',
      cache_dir=cache_dir)
  elif dataset_name == 'ptb':
    dataset = datasets.load_dataset(
      'ptb_text_only', cache_dir=cache_dir)
  elif dataset_name == 'lambada':
    dataset = get_lambada_test_dataset()
  elif dataset_name == 'text8':
    assert wrap
    dataset = get_text8_dataset(
      cache_dir, max_seq_length=block_size)
  elif dataset_name == 'text8-crop':
    dataset = get_text8_dataset(
      cache_dir, max_seq_length=block_size, crop_train=True)
  elif dataset_name == 'openwebtext-train':
    dataset = datasets.load_dataset(
      'openwebtext',
      split='train[:-100000]',
      cache_dir=cache_dir,
      streaming=streaming)
  elif dataset_name == 'openwebtext-valid':
    dataset = datasets.load_dataset(
      'openwebtext',
      split='train[-100000:]',
      cache_dir=cache_dir,
      streaming=streaming)
  elif dataset_name == 'scientific_papers_arxiv':
    dataset = datasets.load_dataset(
      'scientific_papers', 'arxiv',
      trust_remote_code=True,
      cache_dir=cache_dir,
      streaming=streaming)
  elif dataset_name == 'scientific_papers_pubmed':
    dataset = datasets.load_dataset(
      'scientific_papers', 'pubmed',
      trust_remote_code=True,
      cache_dir=cache_dir,
      streaming=streaming)
  elif dataset_name == 'ag_news':
    dataset = datasets.load_dataset(
      'ag_news',
      cache_dir=cache_dir,
      streaming=streaming)
  else:
    dataset = datasets.load_dataset(
      dataset_name,
      cache_dir=cache_dir,
      streaming=streaming)

  if dataset_name in ['lambada', 'openwebtext-train',
                      'openwebtext-valid']:
    data = dataset
  else:
    data = dataset[mode]

  if dataset_name.startswith('wikitext'):
    detokenizer = wt_detokenizer
  elif dataset_name == 'ptb':
    detokenizer = ptb_detokenizer
  elif dataset_name == 'lm1b':
    detokenizer = lm1b_detokenizer
  elif dataset_name == 'lambada':
    detokenizer = lambada_detokenizer
  elif dataset_name.startswith('scientific_papers'):
    detokenizer = scientific_papers_detokenizer
  else:
    detokenizer = None

  def _apply_detokenizer(detokenizer):
    def detok(text):
      for i, t in enumerate(text, 0):
        text[i] = detokenizer(t)
      return text
    return detok

  EOS = tokenizer.encode(tokenizer.eos_token)[0]
  BOS = tokenizer.encode(tokenizer.bos_token)[0]

  def preprocess_and_tokenize(example):
    if dataset_name == 'ptb':
      text = example['sentence']
    elif 'scientific_papers' in dataset_name:
      text = example['article']
    else:
      text = example['text']

    if detokenizer is not None:
      text = _apply_detokenizer(detokenizer)(text)

    tokenizer.padding_side = 'right'
    tokenizer.truncation_side = 'right'

    if wrap:
      tokens = tokenizer(text,
                         add_special_tokens=False,
                         return_attention_mask=False,
                         return_token_type_ids=False)
      tokens = {'input_ids':
                [t + [EOS] for t in tokens['input_ids']]}
      # Still missing BOS, but will be added in group_texts
    else:
      tokens = tokenizer(text,
                         max_length=block_size,
                         padding='max_length',
                         truncation=True,
                         add_special_tokens=True,
                         return_attention_mask=True,
                         return_token_type_ids=True)
    return tokens

  if streaming:
    tokenized_dataset = data.map(
      preprocess_and_tokenize,
      batched=True,
      desc='Tokenizing')
  else:
    tokenized_dataset = data.map(
      preprocess_and_tokenize,
      batched=True,
      num_proc=num_proc,
      load_from_cache_file=True,
      desc='Tokenizing')
  if dataset_name == 'ptb':
    tokenized_dataset = tokenized_dataset.remove_columns(
      'sentence')
  elif 'scientific_papers' in dataset_name:
    tokenized_dataset = tokenized_dataset.remove_columns([
      'article', 'abstract', 'section_names'])
  elif dataset_name == 'ag_news':
    tokenized_dataset = tokenized_dataset.remove_columns(
      ['text', 'label'])
  else:
    tokenized_dataset = tokenized_dataset.remove_columns(
      'text')

  if not wrap:
    tokenized_dataset.save_to_disk(_path)
    return tokenized_dataset.with_format('torch')

  group_texts = functools.partial(
    _group_texts, block_size=block_size, bos=BOS, eos=EOS)
  if streaming:
    chunked_dataset = tokenized_dataset.map(
      group_texts,
      batched=True,
      desc='Grouping')
  else:
    chunked_dataset = tokenized_dataset.map(
      group_texts,
      batched=True,
      num_proc=num_proc,
      load_from_cache_file=True,
      desc='Grouping')
    chunked_dataset.save_to_disk(_path)
  chunked_dataset = chunked_dataset.with_format('torch')
  return chunked_dataset


def get_tokenizer(config):
  if config.data.tokenizer_name_or_path == 'text8':
    tokenizer = Text8Tokenizer()
  elif config.data.tokenizer_name_or_path == 'bert-base-uncased':
    tokenizer = transformers.BertTokenizer.\
      from_pretrained('bert-base-uncased')
  else:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
      config.data.tokenizer_name_or_path)

  if (isinstance(tokenizer, transformers.GPT2TokenizerFast)
      or isinstance(tokenizer, transformers.GPT2Tokenizer)):
    tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(
      (tokenizer.bos_token, tokenizer.bos_token_id),
      (tokenizer.eos_token, tokenizer.eos_token_id))

  # For wrapped batches:
  #  [BOS] sent1 [EOS] sent2-fragment [EOS]
  #  [BOS] sent2-fragment [EOS] sent3 [EOS]
  if tokenizer.bos_token is None:
    if tokenizer.cls_token is None:
      raise AttributeError(
        'Tokenizer must have a bos_token or '
        f'cls_token: {tokenizer}')
    tokenizer.bos_token = tokenizer.cls_token
  if tokenizer.eos_token is None:
    if tokenizer.sep_token is None:
      raise AttributeError(
        'Tokenizer must have a eos_token '
        f'or sep_token: {tokenizer}')
    tokenizer.eos_token = tokenizer.sep_token
  if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

  return tokenizer


def get_dataloaders(config, tokenizer, skip_train=False,
                    skip_valid=False, valid_seed=None):
  num_gpus = torch.cuda.device_count()
  assert (config.loader.global_batch_size
          == (config.loader.batch_size
              * config.trainer.num_nodes
              * num_gpus
              * config.trainer.accumulate_grad_batches))
  if config.loader.global_batch_size % (
    num_gpus * config.trainer.accumulate_grad_batches) != 0:
    raise ValueError(
      f'Train Batch Size {config.training.batch_size}'
      f'not divisible by {num_gpus} gpus with accumulation '
      f'{config.trainer.accumulate_grad_batches}.')
  if config.loader.eval_global_batch_size % num_gpus != 0:
    raise ValueError(
      f'Eval Batch Size for {config.eval.batch_size} '
      f'not divisible by {num_gpus}.')
  if skip_train:
    train_set = None
  else:
    train_set = get_dataset(
      config.data.train,
      tokenizer,
      mode='train',
      wrap=config.data.wrap,
      #cache_dir=config.data.cache_dir,
      block_size=config.model.length)

  if config.data.valid in ['text8', 'lm1b', 'ag_news']:
    validation_split = 'test'
  else:
    validation_split = 'validation'
  if skip_valid:
    valid_set = None
  else:
    valid_set = get_dataset(
      config.data.valid,
      tokenizer,
      wrap=config.data.wrap,
      mode=validation_split,
      #cache_dir=config.data.cache_dir,
      block_size=config.model.length,
      streaming=False)

  if skip_train:
    train_loader = None
  else:
    train_loader = torch.utils.data.DataLoader(
      train_set,
      batch_size=config.loader.batch_size,
      num_workers=config.loader.num_workers,
      pin_memory=config.loader.pin_memory,
      shuffle=not config.data.streaming,
      persistent_workers=True)
    train_loader.tokenizer = tokenizer
  if skip_valid:
    valid_loader = None
  else:
    if valid_seed is None:
      shuffle_valid = False
      generator = None
    else:
      shuffle_valid = True
      generator = torch.Generator().manual_seed(valid_seed)
    valid_loader = torch.utils.data.DataLoader(
      valid_set,
      batch_size=config.loader.eval_batch_size,
      num_workers=config.loader.num_workers,
      pin_memory=config.loader.pin_memory,
      shuffle=shuffle_valid,
      generator=generator)
    # Will be used in generative perplexity calculation
    valid_loader.tokenizer = tokenizer

  return train_loader, valid_loader


# Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py


class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):

  def __init__(self, *args, generator=None, **kwargs):
    # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
    # which should be reproducible if pl.seed_everything was called beforehand.
    # This means that changing the seed of the experiment will also change the
    # sampling order.
    if generator is None:
      seed = int(torch.empty((), dtype=torch.int64).random_().item())
      generator = torch.Generator().manual_seed(seed)
    kwargs.pop('shuffle', None)
    super().__init__(*args, generator=generator, **kwargs)
    self.counter = 0
    self.restarting = False

  def state_dict(self):
    return {'random_state': self.generator.get_state(),
            'counter': self.counter}

  def load_state_dict(self, state_dict):
    self.generator.set_state(state_dict.get('random_state'))
    self.counter = state_dict['counter']
    # self.start_counter = self.counter
    self.restarting = True

  # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
  # epoch, and subsequent epoch will have very few batches.

  def __iter__(self) -> typing.Iterator[int]:
    n = len(self.data_source)

    self.state = self.generator.get_state()
    indices = torch.randperm(n, generator=self.generator).tolist()

    if not self.restarting:
      self.counter = 0
    else:
      indices = indices[self.counter:]
      self.restarting = False

    for index in indices:
      self.counter += 1
      yield index

    self.counter = 0


class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.counter = 0
    self.restarting = False

  def state_dict(self):
    return {'epoch': self.epoch, 'counter': self.counter}

  def load_state_dict(self, state_dict):
    self.epoch = state_dict['epoch']
    self.counter = state_dict['counter']
    self.restarting = True

  # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
  # epoch, and subsequent epoch will have very few batches.
  def __iter__(self):
    if self.shuffle:
      # deterministically shuffle based on epoch and seed
      g = torch.Generator()
      g.manual_seed(self.seed + self.epoch)
      indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
    else:
      indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

    if not self.drop_last:
      # add extra samples to make it evenly divisible
      padding_size = self.total_size - len(indices)
      if padding_size <= len(indices):
        indices += indices[:padding_size]
      else:
        indices += (indices * math.ceil(
          padding_size / len(indices)))[:padding_size]
    else:
      # remove tail of data to make it evenly divisible.
      indices = indices[:self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples

    if not self.restarting:
      self.counter = 0
    else:
      indices = indices[self.counter:]
      self.restarting = False

    for index in indices:
      self.counter += 1
      yield index

    self.counter = 0


from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl
from functools import partial
import sys


class CustomDataset(torch.utils.data.Dataset):
  def __init__(self, dataset, indices):
    self.dataset = dataset
    self.indices = indices

  def __len__(self):
    return len(self.indices)

  def __getitem__(self, idx):
    actual_idx = int(self.indices[idx])
    item = self.dataset[actual_idx]
    return item


def membrane_collate_fn(batch, tokenizer):
  """Custom data collator that masks TM/soluble residues for focused training"""
  MAX_LENGTH = 1024
  sequences = [item['Sequence'].upper() for item in batch]

  masks = []
  for item in batch:
    if item["Label"] == 0:
      mask = [1 if i.isupper() else 0 for i in item["Sequence"]]
    else:
      mask = [0 if i.isupper() else 1 for i in item["Sequence"]]
    mask = [1] + mask
    if len(mask) > MAX_LENGTH:  # Truncate
      mask = mask[:MAX_LENGTH]
    elif len(mask) < MAX_LENGTH:  # Pad
      mask += [1] * (MAX_LENGTH - len(mask))

    masks.append(torch.as_tensor(mask))

  mask_t = torch.stack(masks, dim=0)
  tokens = tokenizer(sequences, return_tensors='pt', padding='max_length', truncation=True, max_length=MAX_LENGTH)

  return {
    'input_ids': tokens['input_ids'],
    'attention_mask': tokens['attention_mask'],
    'mask': mask_t
  }


def wrap_collate_fn(batch, tokenizer):
  """Standard data collator that wraps sequences instead of padding them"""
  # Define sequence size
  chunk_size = 1024
  eos_placeholder = "k"
  eos = "<eos>"

  # Wrap sequences by collecting and splitting them into chunks
  # From MDLM paper: insert <eos> at start/end of chunks and in between sequences
  sequences = eos_placeholder.join([item['Sequence'].upper() for item in batch])
  sequences = eos_placeholder + sequences + eos_placeholder
  wrapped_sequences = [sequences[i:i+chunk_size] for i in range(0, len(sequences), chunk_size)]
  for idx, seq in enumerate(wrapped_sequences):
    wrapped_sequences[idx] = seq.replace(eos_placeholder, eos)

  # Tokenize for input ids and attention masks
  tokens = tokenizer(wrapped_sequences, return_tensors='pt', padding=True)

  return {
    "input_ids": tokens['input_ids'],
    "attention_mask": tokens['attention_mask']
  }


def collate_fn(batch, tokenizer):
  """Standard data collator that truncates/pads sequences based on max_length"""
  sequences = [item['Sequence'].upper() for item in batch]
  max_len = max([len(seq) for seq in sequences])
  #labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32)

  tokens = tokenizer(sequences, return_tensors='pt', padding='max_length', truncation=True, max_length=1024)

  #attention_masks = torch.ones(tokens.size()[:2], dtype=torch.bool)

  return {
    'input_ids': tokens['input_ids'],
    'attention_mask': tokens['attention_mask']
  }


class CustomDataModule(pl.LightningDataModule):
  def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size: int=8, collate_fn=collate_fn):
    super().__init__()
    self.train_dataset = train_dataset
    self.val_dataset = val_dataset
    self.test_dataset = test_dataset
    self.batch_size = batch_size
    self.tokenizer = tokenizer
    self.collate_fn = collate_fn

  def train_dataloader(self):
    return DataLoader(self.train_dataset, batch_size=self.batch_size,
                      collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                      num_workers=8, pin_memory=True)

  def val_dataloader(self):
    return DataLoader(self.val_dataset, batch_size=self.batch_size,
                      collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                      num_workers=8, pin_memory=True)

  def test_dataloader(self):
    return DataLoader(self.test_dataset, batch_size=self.batch_size,
                      collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                      num_workers=8, pin_memory=True)
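To show how the collators above plug into the Lightning data module, here is a rough sketch (not part of the upload). It assumes the module is importable under its uploaded filename `pl_data_loader` (main.py refers to it as `dataloader`), and the two sequences are made up.

from transformers import AutoTokenizer
from pl_data_loader import CustomDataModule, wrap_collate_fn

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
batch = [{'Sequence': 'MKTLLVAGG'}, {'Sequence': 'ACDEFGHIK'}]

out = wrap_collate_fn(batch, tokenizer)  # joins sequences with <eos> and splits into 1024-char chunks
print(out['input_ids'].shape)            # (num_chunks, chunk_token_length)

# During training the same collator is handed to the data module, e.g.:
# dm = CustomDataModule(train_ds, val_ds, test_ds, tokenizer, batch_size=8, collate_fn=wrap_collate_fn)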
utils.py
ADDED
@@ -0,0 +1,230 @@
"""Console logger utilities.

Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
"""

import logging
import math

import fsspec
import lightning
import torch
from timm.scheduler import CosineLRScheduler


def fsspec_exists(filename):
    """Check if a file exists using fsspec."""
    fs, _ = fsspec.core.url_to_fs(filename)
    return fs.exists(filename)


def fsspec_listdir(dirname):
    """Listdir in manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    return fs.ls(dirname)


def fsspec_mkdirs(dirname, exist_ok=True):
    """Mkdirs in manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    fs.makedirs(dirname, exist_ok=exist_ok)


def print_nans(tensor, name):
    if torch.isnan(tensor).any():
        print(name, tensor)

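A short usage sketch: because these helpers resolve paths through fsspec, the same calls work for local and remote (e.g. s3://) locations; the paths below are placeholders.

if not fsspec_exists('checkpoints/last.ckpt'):
    fsspec_mkdirs('checkpoints', exist_ok=True)
print(fsspec_listdir('checkpoints'))
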
class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """Wrap timm.scheduler.CosineLRScheduler.
    Enables calling scheduler.step() without passing in epoch.
    Supports resuming as well.
    Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        if epoch is None:
            self._last_epoch += 1
        else:
            self._last_epoch = epoch
        # We call either step or step_update, depending on whether
        # we're using the scheduler every epoch or every step.
        # Otherwise, lightning will always call step (i.e., meant for
        # each epoch), and if we set the scheduler interval to "step",
        # then the learning rate update will be wrong.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)

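A minimal sketch (not from this repo) of how this wrapper could be hooked into a LightningModule so the schedule advances once per optimizer step; all hyperparameter values below are assumptions:

def configure_optimizers(self):  # hypothetical LightningModule method
    optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4)
    scheduler = CosineDecayWarmupLRScheduler(
        optimizer,
        t_initial=100_000,      # total number of training steps (assumed)
        warmup_t=2_500,         # warmup steps (assumed)
        lr_min=1e-6,
        warmup_lr_init=1e-6,
        t_in_epochs=False)      # step-based schedule, so step_update() is used internally
    return {
        'optimizer': optimizer,
        'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'},
    }
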
class LoggingContext:
    """Context manager for selective logging."""
    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
        if self.handler and self.close:
            self.handler.close()

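Usage sketch, following the logging-cookbook pattern this class is copied from: temporarily raise verbosity for one block and restore the previous level on exit.

logger = logging.getLogger(__name__)
with LoggingContext(logger, level=logging.DEBUG, handler=logging.StreamHandler()):
    logger.debug('visible only inside this block')
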
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger."""

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in ('debug', 'info', 'warning', 'error',
                  'exception', 'fatal', 'critical'):
        setattr(logger,
                level,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, level)))

    return logger

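Usage sketch: each level method is wrapped with rank_zero_only, so under multi-GPU training only rank 0 emits the message.

log = get_logger(__name__)
log.info('starting training')
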
class Sampler:
    def __init__(self, shape):
        self.shape = shape

    def _sampling_noise(self):
        pass

    def _hard_sample(self, logits):
        pass

    def _soft_sample(self, logits):
        return 0

    def sample(self, logits):
        noise = self._sampling_noise()
        noise = noise[: logits.shape[0], :]
        logits = logits + noise.to(dtype=logits.dtype, device=logits.device)
        hard_sample = self._hard_sample(logits)
        soft_sample = self._soft_sample(logits)
        # Straight-through estimator: the forward value is the hard sample,
        # gradients flow through the soft sample.
        return soft_sample + (hard_sample - soft_sample).detach()


class TopKSampler(Sampler):
    def __init__(self, k, shape, gamma_tau=1.0):
        super().__init__(shape)
        self.k = k
        self.gamma_tau = gamma_tau
        self.num_betas = 10
        self.sampler = torch.distributions.gamma.Gamma(
            1 / k * torch.ones(self.num_betas, *self.shape), 1.0)

    def _sampling_noise(self):
        noise = self.sampler.sample()
        beta = self.k / torch.arange(1, self.num_betas + 1, 1,
                                     dtype=torch.float32)
        beta = beta[:, None, None]
        assert beta.ndim == noise.ndim
        s = noise / beta
        s = torch.sum(s, axis=0)
        s = s - math.log(10.0)
        s = self.gamma_tau * (s / self.k)
        return s

    def _hard_sample(self, logits):
        assert logits.ndim == 2
        thresholds, _ = torch.sort(logits, dim=-1)
        thresholds = thresholds[:, -self.k][:, None]
        return (logits >= thresholds).type(logits.dtype)

    def _soft_sample(self, logits):
        soft_top_k = logits - torch.mean(logits, dim=-1, keepdim=True)
        return soft_top_k / torch.norm(soft_top_k, dim=-1, keepdim=True)


class DeterministicTopK(TopKSampler):
    def __init__(self, k):
        super().__init__(k, shape=(1, 1))

    def _sampling_noise(self):
        return 0

    def discreize(self, x):
        hard_sample = self._hard_sample(x)
        soft_sample = self._soft_sample(x)
        return soft_sample + (hard_sample - soft_sample).detach()

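A small usage sketch for the top-k sampler (shapes are illustrative): sample() returns a vector per row that is a k-hot threshold mask in the forward pass, while gradients flow through the soft normalization.

topk = TopKSampler(k=5, shape=(8, 32))
mask = topk.sample(torch.randn(8, 32))   # (8, 32), k-hot per row in the forward pass
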
class GumbelSampler(Sampler):

    def __init__(self, shape, temperature=1.0):
        super().__init__(shape)
        self.temperature = temperature

    def _sampling_noise(self):
        # Standard Gumbel(0, 1) noise: -log(-log(U)), with 1e-10 for numerical stability.
        return - (1e-10 - (torch.rand(*self.shape) + 1e-10).log()).log()

    def _hard_sample(self, logits):
        # The indexing below expects (batch, length, vocab) logits.
        assert logits.ndim == 3
        indices = torch.argmax(logits, dim=-1)
        zeros = logits * 0
        ones = torch.ones_like(logits[:, :, :1])
        return torch.scatter(zeros, -1, indices[:, :, None], ones)

    def _soft_sample(self, logits):
        return torch.nn.functional.softmax(
            logits / self.temperature, dim=-1)

class BinarySampler(GumbelSampler):

    def sample(self, probs):
        # TODO(subhamsahoo): use the temperature parameter.
        pos_noise = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        neg_noise = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        del_noise_exp = (neg_noise - pos_noise).exp()
        hard_sample = (probs * (1 + del_noise_exp) > 1).to(probs.dtype)
        soft_sample = probs / (probs + (1 - probs) * del_noise_exp)
        return soft_sample + (hard_sample - soft_sample).detach()


class GaussianSampler:
    def __init__(self):
        self.softplus = torch.nn.Softplus()

    def sample(self, x):
        assert x.ndim == 2
        n = x.shape[-1] // 2
        mu = x[:, :n]
        sigma = self.softplus(x[:, n:]).sqrt()
        return mu + sigma * torch.randn_like(mu)
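Two short usage sketches (shapes are illustrative only). GumbelSampler draws a straight-through one-hot sample over the vocabulary axis of (batch, length, vocab) logits; GaussianSampler implements the reparameterization trick, splitting the last dimension into means and pre-softplus variances.

gumbel = GumbelSampler(shape=(4, 10, 32), temperature=1.0)
one_hot = gumbel.sample(torch.randn(4, 10, 32))   # (4, 10, 32), one-hot in the forward pass

x = torch.randn(4, 16)                            # 8 means + 8 variance parameters per row
z = GaussianSampler().sample(x)                   # (4, 8), differentiable w.r.t. x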