---
library_name: diffusers
pipeline_tag: text-to-image
---
## Model Details
### Model Description
This model is fine-tuned from [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) on 110,000 image-text pairs from the MIMIC dataset using the SVDiff [1] PEFT method. Under this fine-tuning strategy, only the singular values of the U-Net's weight matrices are updated, while all other parameters remain frozen; a minimal sketch of this reparameterization follows the list below.
- **Developed by:** [Raman Dutt](https://twitter.com/RamanDutt4)
- **Shared by:** [Raman Dutt](https://twitter.com/RamanDutt4)
- **Model type:** Stable Diffusion fine-tuned with a parameter-efficient fine-tuning (PEFT) method
- **Finetuned from model:** [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
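For intuition, here is a minimal sketch of the SVDiff reparameterization (illustrative only, not the training code; the matrix shape is arbitrary, and the ReLU clamp on the shifted singular values follows the SVDiff paper [1]): the pretrained weight is factored once with an SVD, and only a small vector of shifts over its singular values is trained.

```python
import torch

# A frozen pretrained weight matrix (shape chosen arbitrarily for illustration).
W = torch.randn(320, 320)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)  # computed once, then frozen

# The spectral shifts are the ONLY parameters trained by SVDiff.
delta = torch.zeros_like(S, requires_grad=True)

# Adapted weight: shift the singular values, clamped to stay non-negative.
W_adapted = U @ torch.diag(torch.relu(S + delta)) @ Vh
```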
### Model Sources
- **Paper:** [Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity](https://arxiv.org/abs/2305.08252)
- **Demo:** [MIMIC-SD-PEFT-Demo](https://huggingface.co/spaces/raman07/MIMIC-SD-Demo-Memory-Optimized?logs=container)
## Direct Use
This model can be directly used to generate realistic medical images from text prompts.
## How to Get Started with the Model
The snippet below assumes the [svdiff-pytorch](https://github.com/mkshing/svdiff-pytorch) package is installed (it provides `UNet2DConditionModelForSVDiff`), along with `diffusers`, `accelerate`, and `safetensors`.

```python
import inspect
import os

import accelerate
import huggingface_hub
import torch
from accelerate.utils import set_module_tensor_to_device
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from safetensors import safe_open

# UNet2DConditionModelForSVDiff comes from the svdiff-pytorch package
# (https://github.com/mkshing/svdiff-pytorch).
from svdiff_pytorch import UNet2DConditionModelForSVDiff


def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs):
    """Build an SVDiff U-Net, load the base weights, then load the spectral shifts."""
    config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs)
    original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    state_dict = original_model.state_dict()
    with accelerate.init_empty_weights():
        model = UNet2DConditionModelForSVDiff.from_config(config)
    # Load the pre-trained weights; the spectral shifts ("delta" parameters) start at zero.
    param_device = "cpu"
    torch_dtype = kwargs.get("torch_dtype")
    spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
    state_dict.update(spectral_shifts_weights)
    # Move the params from the meta device to CPU.
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    if len(missing_keys) > 0:
        raise ValueError(
            f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
            " those weights or else make sure your checkpoint file is correct."
        )
    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
    for param_name, param in state_dict.items():
        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param)
    if spectral_shifts_ckpt:
        if os.path.isdir(spectral_shifts_ckpt):
            spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors")
        elif not os.path.exists(spectral_shifts_ckpt):
            # Download the checkpoint from the Hugging Face Hub.
            hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(
                spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs
            )
        assert os.path.exists(spectral_shifts_ckpt)
        with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
            for key in f.keys():
                if accepts_dtype:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
                else:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
        print(f"Resumed from {spectral_shifts_ckpt}")
    if "torch_dtype" in kwargs:
        model = model.to(kwargs["torch_dtype"])
    model.register_to_config(_name_or_path=pretrained_model_name_or_path)
    # Evaluation mode deactivates Dropout modules by default.
    model.eval()
    del original_model
    torch.cuda.empty_cache()
    return model


# Build the base pipeline, then swap in the SVDiff U-Net with the adapted weights.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.unet = load_unet_for_svdiff(
    "runwayml/stable-diffusion-v1-5",
    spectral_shifts_ckpt=os.path.join("unet", "spectral_shifts.safetensors"),
    subfolder="unet",
)
# Recompose each adapted weight from the base SVD and the loaded spectral shifts.
for module in pipe.unet.modules():
    if hasattr(module, "perform_svd"):
        module.perform_svd()
pipe.to("cuda:0")

# Generate an image from a text prompt.
TEXT_PROMPT = "No acute cardiopulmonary abnormality."
GUIDANCE_SCALE = 4
INFERENCE_STEPS = 75

result = pipe(
    prompt=TEXT_PROMPT,
    height=224,
    width=224,
    guidance_scale=GUIDANCE_SCALE,
    num_inference_steps=INFERENCE_STEPS,
)
result_pil_image = result.images[0]
```
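If GPU memory is tight (the demo linked above is the memory-optimized variant), generic `diffusers` options such as half precision and attention slicing usually help. This is a standard optimization sketch for recent `diffusers` versions, not something specific to this checkpoint:

```python
import torch

pipe = pipe.to("cuda:0", torch.float16)  # half precision roughly halves VRAM use
pipe.enable_attention_slicing()          # lowers peak memory at a small speed cost
```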
## Training Details
### Training Data
This model has been fine-tuned on 110K image-text pairs from the MIMIC dataset.
### Training Procedure
The training procedure has been described in detail in Section 4.3 of this [paper](https://arxiv.org/abs/2305.08252).
#### Metrics
This model has been evaluated using the Fréchet Inception Distance (FID, lower is better) on the MIMIC dataset.
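For reference, FID compares Inception-v3 feature statistics of real and generated images. Below is a minimal sketch using `torchmetrics` (an assumption for illustration; the paper's exact evaluation pipeline may differ, and `torchmetrics[image]` must be installed):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# Toy stand-ins for batches of real and generated images (uint8, NCHW).
real_images = torch.randint(0, 256, (16, 3, 224, 224), dtype=torch.uint8)
fake_images = torch.randint(0, 256, (16, 3, 224, 224), dtype=torch.uint8)

fid = FrechetInceptionDistance(feature=2048)
fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(f"FID: {fid.compute():.2f}")
```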
### Results
| Fine-Tuning Strategy | FID Score |
|------------------------|-----------|
| Full FT | 58.74 |
| Attention | 52.41 |
| Bias | 20.81 |
| Norm | 29.84 |
| Bias+Norm+Attention | 35.93 |
| LoRA | 439.65 |
| SV-Diff | 23.59 |
| DiffFit | 42.50 |
## Environmental Impact
Parameter-efficient fine-tuning potentially causes **less** harm to the environment, since only a small fraction of a model's parameters are updated. This substantially reduces compute and hardware requirements.
## Citation
**BibTeX:**

    @article{dutt2023parameter,
      title={Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity},
      author={Dutt, Raman and Ericsson, Linus and Sanchez, Pedro and Tsaftaris, Sotirios A and Hospedales, Timothy},
      journal={arXiv preprint arXiv:2305.08252},
      year={2023}
    }
**APA:**
Dutt, R., Ericsson, L., Sanchez, P., Tsaftaris, S. A., & Hospedales, T. (2023). Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity. arXiv preprint arXiv:2305.08252.
## Model Card Authors
Raman Dutt: [Twitter](https://twitter.com/RamanDutt4) | [LinkedIn](https://www.linkedin.com/in/raman-dutt/) | [Email](mailto:[email protected])
## References
1. Han, Ligong, et al. "SVDiff: Compact Parameter Space for Diffusion Fine-Tuning." arXiv preprint arXiv:2303.11305 (2023).