Image-to-Image
MedVAE
Ashwin Kumar committed
Commit 3e91af9 · 1 Parent(s): 6c9c3da

model_weights and readme

README.md CHANGED
@@ -1,3 +1,42 @@
+ ---
+ license: mit
+ ---
+
+ # MedVAE
+
+ MedVAE is a family of six large-scale, generalizable 2D and 3D variational autoencoders (VAEs) designed for medical imaging. It is trained on over one million medical images spanning multiple anatomical regions and modalities. MedVAE autoencoders encode medical images into downsized latent representations and decode those latent representations back into high-resolution images. Across diverse tasks drawn from 20 medical image datasets, we demonstrate that using MedVAE latent representations in place of high-resolution images when training downstream models yields efficiency benefits (up to a 70x improvement in throughput) while preserving clinically relevant features.
+
+ [💻 Github](https://github.com/StanfordMIMI/MedVAE)
+
+ ## Model Description
+ | Total Compression Factor | Channels | Dimensions | Modalities | Anatomies | Config File | Model File |
+ |----------|----------|----------|----------|----------|----------|----------|
+ | 16 | 1 | 2D | X-ray | Chest, Breast (FFDM) | [medvae_4x1.yaml](model_weights/medvae_4x1.yaml) | [vae_4x_1c_2D.ckpt](model_weights/vae_4x_1c_2D.ckpt) |
+ | 16 | 3 | 2D | X-ray | Chest, Breast (FFDM) | [medvae_4x3.yaml](model_weights/medvae_4x3.yaml) | [vae_4x_3c_2D.ckpt](model_weights/vae_4x_3c_2D.ckpt) |
+ | 64 | 1 | 2D | X-ray | Chest, Breast (FFDM) | [medvae_8x1.yaml](model_weights/medvae_8x1.yaml) | [vae_8x_1c_2D.ckpt](model_weights/vae_8x_1c_2D.ckpt) |
+ | 64 | 4 | 2D | X-ray | Chest, Breast (FFDM) | [medvae_8x4.yaml](model_weights/medvae_8x4.yaml) | [vae_8x_4c_2D.ckpt](model_weights/vae_8x_4c_2D.ckpt) |
+ | 64 | 1 | 3D | MRI, CT | Whole-Body | [medvae_4x1.yaml](model_weights/medvae_4x1.yaml) | [vae_4x_1c_3D.ckpt](model_weights/vae_4x_1c_3D.ckpt) |
+ | 512 | 1 | 3D | MRI, CT | Whole-Body | [medvae_8x1.yaml](model_weights/medvae_8x1.yaml) | [vae_8x_1c_3D.ckpt](model_weights/vae_8x_1c_3D.ckpt) |
+
+ Note: Model weights and checkpoints are located in the `model_weights` folder.
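The "Total Compression Factor" column follows directly from the per-axis downsampling factor in each model's name (4x or 8x): the factor is squared for 2D models and cubed for 3D models. A minimal sanity-check sketch of that arithmetic (the function name and the square-input example are illustrative, not part of the release):

```python
# Total compression factor of a MedVAE model: the per-axis downsampling
# factor f (4 or 8, from the model name) raised to the number of spatial
# dimensions (2 for 2D models, 3 for 3D models).
def total_compression(f: int, ndim: int) -> int:
    return f ** ndim

# 2D rows of the table: 4x -> 16, 8x -> 64
assert total_compression(4, 2) == 16
assert total_compression(8, 2) == 64

# 3D rows of the table: 4x -> 64, 8x -> 512
assert total_compression(4, 3) == 64
assert total_compression(8, 3) == 512

# Example latent grid: a hypothetical 512x512 X-ray through the
# 8x, 1-channel 2D VAE gives a 64x64 single-channel latent.
h = w = 512
f, channels = 8, 1
latent_shape = (channels, h // f, w // f)
assert latent_shape == (1, 64, 64)
```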
+
+ ## Usage Instructions
+
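The checkpoints in `model_weights` are stored with Git LFS, so cloning without LFS installed leaves small pointer files (like those shown in this commit) in place of the multi-hundred-megabyte weights. A minimal, illustrative check for that pitfall — the helper name is hypothetical, and the pointer text is copied from the `vae_4x_1c_2D.ckpt` entry below:

```python
from pathlib import Path

def is_lfs_pointer(path: str) -> bool:
    """Heuristic: Git LFS pointer files are tiny text files beginning with
    the LFS spec line; real checkpoints are hundreds of megabytes."""
    p = Path(path)
    if p.stat().st_size > 1024:  # pointers are well under 1 KB
        return False
    return p.read_bytes().startswith(b"version https://git-lfs.github.com/spec/v1")

# Example: write out a pointer exactly like the one in this commit.
pointer = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:0e4bd8931238a6c52acb3d826025f5bcfa284aa9f98cc5455505f131d690c1ba\n"
    "size 221345538\n"
)
Path("vae_4x_1c_2D.ckpt").write_text(pointer)
assert is_lfs_pointer("vae_4x_1c_2D.ckpt")  # fix by running `git lfs pull`
```

If the check fires, re-fetch the repository with `git lfs install` followed by `git lfs pull`.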
+ ## Citation
+ If you use MedVAE, please cite the original paper:
+
+ ```bibtex
+ @article{varma2025medvae,
+   title = {Med-VAE: --},
+   author = {Maya Varma and Ashwin Kumar and Rogier van der Sluijs and Sophie Ostmeier and Louis Blankemeier and Pierre Chambon and Christian Bluethgen and Jip Prince and Curtis Langlotz and Akshay Chaudhari},
+   year = {2025},
+   journal = {arXiv},
+   howpublished = {TODO}
+ }
+ ```
+
+ For questions, please open a GitHub issue.
+
model_weights/medvae_4x1.yaml ADDED
@@ -0,0 +1,13 @@
+ embed_dim: 1
+
+ ddconfig:
+   double_z: True
+   z_channels: 1
+   resolution: 512
+   in_channels: 1
+   out_ch: 1
+   ch: 128
+   ch_mult: [1,2,4]
+   num_res_blocks: 2
+   attn_resolutions: []
+   dropout: 0.0
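In these configs the per-axis downsampling factor is implied by `ch_mult`: the encoder halves the resolution once per level after the first, so the factor is `2 ** (len(ch_mult) - 1)` (the 4x3 config's own comment notes `num_down = len(ch_mult)-1`). A short sketch of that arithmetic using the `ch_mult` values from these files (the function name is illustrative):

```python
# Per-axis downsampling implied by ch_mult: one 2x downsample per
# resolution level after the first.
def downsample_factor(ch_mult: list[int]) -> int:
    return 2 ** (len(ch_mult) - 1)

assert downsample_factor([1, 2, 4]) == 4      # medvae_4x* configs
assert downsample_factor([1, 2, 4, 4]) == 8   # medvae_8x* configs

# e.g. a 512x512 input (the `resolution` above) through the 4x config
# yields a 128x128 latent grid
assert 512 // downsample_factor([1, 2, 4]) == 128
```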
model_weights/medvae_4x3.yaml ADDED
@@ -0,0 +1,54 @@
+ model:
+   base_learning_rate: 4.5e-6
+   target: ldm.models.autoencoder.AutoencoderKL
+   params:
+     monitor: "val/rec_loss"
+     embed_dim: 3
+     lossconfig:
+       target: ldm.modules.losses.LPIPSWithDiscriminator
+       params:
+         disc_start: 50001
+         kl_weight: 0.000001
+         disc_weight: 0.5
+
+     ddconfig:
+       double_z: True
+       z_channels: 3
+       resolution: 256
+       in_channels: 3
+       out_ch: 3
+       ch: 128
+       ch_mult: [ 1,2,4 ]  # num_down = len(ch_mult)-1
+       num_res_blocks: 2
+       attn_resolutions: [ ]
+       dropout: 0.0
+
+
+ data:
+   target: main.DataModuleFromConfig
+   params:
+     batch_size: 12
+     wrap: True
+     train:
+       target: ldm.data.imagenet.ImageNetSRTrain
+       params:
+         size: 256
+         degradation: pil_nearest
+     validation:
+       target: ldm.data.imagenet.ImageNetSRValidation
+       params:
+         size: 256
+         degradation: pil_nearest
+
+ lightning:
+   callbacks:
+     image_logger:
+       target: main.ImageLogger
+       params:
+         batch_frequency: 1000
+         max_images: 8
+         increase_log_steps: True
+
+   trainer:
+     benchmark: True
+     accumulate_grad_batches: 2
model_weights/medvae_8x1.yaml ADDED
@@ -0,0 +1,13 @@
+ embed_dim: 1
+
+ ddconfig:
+   double_z: True
+   z_channels: 1
+   resolution: 512
+   in_channels: 1
+   out_ch: 1
+   ch: 128
+   ch_mult: [1,2,4,4]
+   num_res_blocks: 2
+   attn_resolutions: []
+   dropout: 0.0
model_weights/medvae_8x4.yaml ADDED
@@ -0,0 +1,53 @@
+ model:
+   base_learning_rate: 4.5e-6
+   target: ldm.models.autoencoder.AutoencoderKL
+   params:
+     monitor: "val/rec_loss"
+     embed_dim: 4
+     lossconfig:
+       target: ldm.modules.losses.LPIPSWithDiscriminator
+       params:
+         disc_start: 50001
+         kl_weight: 0.000001
+         disc_weight: 0.5
+
+     ddconfig:
+       double_z: True
+       z_channels: 4
+       resolution: 256
+       in_channels: 3
+       out_ch: 3
+       ch: 128
+       ch_mult: [ 1,2,4,4 ]  # num_down = len(ch_mult)-1
+       num_res_blocks: 2
+       attn_resolutions: [ ]
+       dropout: 0.0
+
+ data:
+   target: main.DataModuleFromConfig
+   params:
+     batch_size: 12
+     wrap: True
+     train:
+       target: ldm.data.imagenet.ImageNetSRTrain
+       params:
+         size: 256
+         degradation: pil_nearest
+     validation:
+       target: ldm.data.imagenet.ImageNetSRValidation
+       params:
+         size: 256
+         degradation: pil_nearest
+
+ lightning:
+   callbacks:
+     image_logger:
+       target: main.ImageLogger
+       params:
+         batch_frequency: 1000
+         max_images: 8
+         increase_log_steps: True
+
+   trainer:
+     benchmark: True
+     accumulate_grad_batches: 2
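One detail worth noting in the two training configs above: with `batch_size: 12` and `accumulate_grad_batches: 2`, gradients are accumulated over two batches before each optimizer step, for an effective batch size of 24. In the CompVis latent-diffusion codebase these configs follow, `base_learning_rate` is conventionally scaled by the effective batch size at launch time; a sketch under that assumption (the scaling rule is the codebase's convention, not stated in these files, and the single-GPU count is hypothetical):

```python
# Effective batch size and scaled learning rate under gradient accumulation.
# Assumption: the latent-diffusion convention of scaling the base learning
# rate by (accumulate_grad_batches * n_gpus * batch_size).
base_learning_rate = 4.5e-6
batch_size = 12
accumulate_grad_batches = 2
n_gpus = 1  # hypothetical single-GPU run

effective_batch = batch_size * accumulate_grad_batches * n_gpus
lr = base_learning_rate * effective_batch

assert effective_batch == 24
assert abs(lr - 1.08e-4) < 1e-12
```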
model_weights/vae_4x_1c_2D.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e4bd8931238a6c52acb3d826025f5bcfa284aa9f98cc5455505f131d690c1ba
+ size 221345538
model_weights/vae_4x_1c_3D.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c780a850d5ed20303c37a84be6022c8ebca2236509d673e3059bfcec75ce383a
+ size 644085658
model_weights/vae_4x_3c_2D.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:651dfc0792fc61d004cb795b44b52dcb3c3523321776f0c84b6890561c2e5778
+ size 223784534
model_weights/vae_8x_1c_2D.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4d99374ce8a6fcaf4a5b0117dd4396b520e273e5c92660772a710421337ddc52
+ size 334673798
model_weights/vae_8x_1c_3D.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5416db01e6557316630a68c3367403c86100fc9726886fa2ab24595a89f0a98d
+ size 983906794
model_weights/vae_8x_4c_2D.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9457e018edd267cfa526878730f5c04699f9c6a9a646a8d51d89aa9d7ccb99f8
+ size 337998486