# @package _global_
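# Hydra directive: merge this file at the root of the config tree rather than
# under its own config group, so every key below overrides global settings.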

defaults:
  - /model: small

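# Overrides applied on top of the composed /model: small base. Together these
# keys describe a unified text+image model: a MAGVIT tokenizer with an
# 8192-entry codebook at 16x spatial downscaling, 256 image tokens plus 128
# text tokens per sample, and flex attention with a 0.1 masking probability
# for each modality.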
model:
  downscale_ratio: 16
  image_vocab_size: 8192
  vae_type: magvit
  use_custom_vae_ckpt: null
  custom_vae_name: null
  img_length: 256
  txt_length: 128
  image_model: true
  text_model: true
  unified_model: true
  image_model_fid_eval: false
  force_argmax_valid_indices: true
  use_pretrained_img_emb: false
  codebook_embed_dim: 256
  qk_norm: true
  norm_type: rms
  sandwich_normalization: true
  zero_linear_init: false
  modality_embed: true
  rope_2d: false
  use_spda_attn: true
  force_optimized_native_attn: true
  freeze_txt_emb: false
  add_labels: null
  txt_dropout: null
  text_vocab_size: 32001
  use_flex_attention: true
  flex_attention_txt_masking_prob: 0.1
  flex_attention_img_masking_prob: 0.1
  linear_factor: 1
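# Data configuration. `valid: ${.train}` is an OmegaConf relative
# interpolation: the leading dot resolves against the sibling `train` key, so
# validation uses the same `combined_tokens` dataset spec as training.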
data:
  train: combined_tokens
  valid: ${.train}
  n_duplicate_train: null
  wrap: true
  streaming: false
  precache: false
  tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
  resolution: 256
  block_size: 128
  n_val_samples: null
  unpaired: false
  n_duplicate_val: null
  save_train_dataloader: true
  save_validation_dataloader: true
  iterable: false
  webdataset_iterable: false
  webdataset_indexed: false
  dataset_type: null
  tokens_flip_collate: false
  n_train_samples: null
  raw_data_dir: null
  tokenizers_parallelism: false
  token_data_dir: null
  force_disable_shuffle: false
  keep_tensordict_on_disk: true
  use_custom_tensordict_collate: true
  force_mp_spawn: false
  enable_cuda_in_tensordict_collate: false
  use_weighted_tensordict_sampler: true
  fraction_txt_data: 0.0
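# Token shards are located via the DIFFUSION_DATA_DIR environment variable;
# the oc.env resolver raises at resolution time if it is unset. The -1
# weights on the training dirs are presumably a sentinel for automatic
# weighting (e.g. proportional to dataset size), while the validation dirs
# use explicit equal weights of 1.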
  data_dir_train:
  - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
    weight: -1
    name: datacomp1b_8_magvit_train
  - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
    weight: -1
    name: cc12m_tokens_train_256
  - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
    weight: -1
    name: HPDv2_image_reward_v1_v2_v3_magvit
  - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
    weight: -1
    name: pick_score_sac_prompts_v1_v2_v3_magvit
  - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
    weight: -1
    name: datacomp1b_0_1_6_magvit
  - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
    weight: -1
    name: laion400m_magvit_part_0
  - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
    weight: -1
    name: laion400m_magvit_part_1
  data_dir_val:
  - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
    weight: 1
    name: datacomp1b_8_magvit_val
  - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
    weight: 1
    name: cc12m_tokens_val_256
  tokenize_vqvae_in_dataloader: false
  val:
    .train: null
  use_token_dataset: true
  image_dataset: tglcourse/lsun_church_train
  image_data_train: null
  image_data_val: null
  keep_hf_dataset_in_memory: true
  allow_label: false
  disable_text_modality: true
  force_raw_train_images: false
  aggressive_aug: true
  allow_aug_vqvae_dataloader: true
  move_tensordict_to_shm: false
  force_full_attention_mask: false
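# Evaluation settings. max_num_fid_batches_per_device uses a custom `eval`
# resolver to integer-divide 8192 by the global eval batch size
# (devices * per-device batch), targeting roughly 8192 total FID samples
# scored against the precomputed clean-fid lsun_church statistics below.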
eval:
  generate_samples: false
  compute_generative_perplexity: false
  log_every_n_evals: 10
  log_every_n_fid: 20
  limit_val_batches_manual: 16
  perplexity_batch_size: ${loader.eval_batch_size}
  num_masking_viz_batches: -1
  max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
  cfg: null
  class_conditional_fid: false
  force_cfg_value: true
  split_cfg_batches: true
  fid_mode: clean
  clean_fid_precomputed_name: lsun_church
  clean_fid_precomputed_split: trainfull
  clean_fid_precomputed_res: 256
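# Trainer settings: validation every 1000 steps, checkpoints every 10000
# steps (time-based checkpointing disabled via -1), and EMA disabled
# (ema: 0.0). mask_entire_modality: 0.1 and txt_dropout: 0.1 presumably drop
# a whole modality or the text condition 10% of the time, a common recipe
# for enabling classifier-free guidance at sampling time.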
trainer:
  log_every_n_steps: 10
  val_check_interval: 1000
  custom_ddp_bf16: true
  scale_lr_by_batch_size: false
  limit_val_batches: 16
  use_gradient_checkpointing: false
  log_seperate_modal_losses: true
  softmin_snr: 5
  text_loss_weight: 1.0
  img_loss_weight: null
  low_precision_loss: false
  compile: false
  multimodal_batches: true
  compile_fullgraph: false
  log_grad_norm_every_n_steps: 10
  mask_entire_modality: 0.1
  force_shift_image_batches: false
  ckpt_steps: 10000
  ckpt_every_n_minutes: -1
  ignore_text_in_unified: false
  disable_all_eval_generation: false
  eval_on_start: false
  ckpt_model_only: false
  ema: 0.0
  use_custom_ema: false
  log_flops: false
  disable_distributed_torchmetrics: true
  restart_on_failure: true
  force_null_sigma: true
  allow_null_sigma: true
  compile_flag_pos_emb: true
  add_label: false
  first_token_dropout: null
  force_shift_raw_image_batches: true
  txt_dropout: 0.1
  disable_ddp_optimizer: true
optim:
  lr: 0.0003
  weight_decay: 0.05
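# With a per-device batch of 64 and a desired_global_batch_size of 512, the
# global batch is presumably reached via data parallelism and/or gradient
# accumulation (e.g. 8 devices at batch 64, or fewer devices with
# accumulation).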
loader:
  batch_size: 64
  eval_batch_size: ${loader.batch_size}
  num_workers: 1
  desired_global_batch_size: 512
  persistent_workers: true
  pin_memory: true
  num_eval_workers: 1
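# ${model.length} is not defined in this file; it resolves against the
# composed /model: small base, so the number of sampling steps defaults to
# the model's full sequence length.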
sampling:
  steps: ${model.length}
  num_sample_batches: 2
  max_sampling_steps: ${model.length}
wandb:
  mode: online
lr_scheduler:
  num_warmup_steps: 5000
checkpointing:
  checkpoints_total_limit: 4