---
library_name: diffusers
pipeline_tag: text-to-image
---

## Model Details

### Model Description

This model is fine-tuned from [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) on 110,000 image-text pairs from the MIMIC dataset using SVDiff [1], a parameter-efficient fine-tuning (PEFT) method. SVDiff fine-tunes only the singular values of the weight matrices in the U-Net and keeps all other parameters frozen.
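
The idea of fine-tuning only singular values can be sketched in a few lines of NumPy. This is a minimal illustration, not the code used to train this model: it assumes the SVDiff formulation from [1], in which a frozen weight is decomposed once and a trainable shift `delta` is added to its singular values, with a ReLU keeping them non-negative. The names `apply_spectral_shift`, `weight`, and `delta` are hypothetical.

```python
import numpy as np


def apply_spectral_shift(weight, delta):
    """Rebuild a 2-D weight as U @ diag(relu(sigma + delta)) @ Vt."""
    u, sigma, vt = np.linalg.svd(weight, full_matrices=False)
    shifted = np.maximum(sigma + delta, 0.0)  # singular values stay non-negative
    return (u * shifted) @ vt  # scaling the columns of U equals U @ diag(shifted)


rng = np.random.default_rng(0)
weight = rng.standard_normal((4, 6))  # stands in for a frozen pretrained weight
delta = np.zeros(4)                   # trainable shifts, one per singular value

# With all shifts at zero, the pretrained weight is recovered.
print(np.allclose(apply_spectral_shift(weight, delta), weight))

# Only the shifts would be trained: 4 values instead of 24 weights.
print(weight.size, delta.size)
```

Initializing the shifts to zero means fine-tuning starts exactly at the pretrained model, which matches how the loading function in this card fills every `delta` parameter with zeros before loading the checkpoint.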

- **Developed by:** [Raman Dutt](https://twitter.com/RamanDutt4)
- **Shared by:** [Raman Dutt](https://twitter.com/RamanDutt4)
- **Model type:** Stable Diffusion fine-tuned using Parameter-Efficient Fine-Tuning
- **Finetuned from model:** [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)

### Model Sources

- **Paper:** [Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity](https://arxiv.org/abs/2305.08252)
- **Demo:** [MIMIC-SD-PEFT-Demo](https://huggingface.co/spaces/raman07/MIMIC-SD-Demo-Memory-Optimized?logs=container)

## Direct Use

This model can be used directly to generate realistic medical images from text prompts.

## How to Get Started with the Model

```python
import inspect
import os

import accelerate
import huggingface_hub
import torch
from accelerate.utils import set_module_tensor_to_device
from diffusers import UNet2DConditionModel
from diffusers.pipelines import StableDiffusionPipeline
from safetensors import safe_open
from safetensors.torch import load_file

# Assumption: the SVDiff U-Net class is provided by the `svdiff-pytorch` package.
from svdiff_pytorch import UNet2DConditionModelForSVDiff


#### Defining loading function

def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs):
    """Build an SVDiff U-Net from pretrained weights and optionally load spectral shifts."""
    print(pretrained_model_name_or_path)
    config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs)
    original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    state_dict = original_model.state_dict()
    with accelerate.init_empty_weights():
        model = UNet2DConditionModelForSVDiff.from_config(config)

    # Load the pretrained weights, with every spectral shift ("delta") set to zero
    param_device = "cpu"
    torch_dtype = kwargs.get("torch_dtype")
    spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
    state_dict.update(spectral_shifts_weights)

    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    if len(missing_keys) > 0:
        raise ValueError(
            f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
            " those weights or else make sure your checkpoint file is correct."
        )

    # Move the parameters from the meta device to the CPU
    accepts_dtype = "dtype" in inspect.signature(set_module_tensor_to_device).parameters
    for param_name, param in state_dict.items():
        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param)

    if spectral_shifts_ckpt:
        if os.path.isdir(spectral_shifts_ckpt):
            spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors")
        elif not os.path.exists(spectral_shifts_ckpt):
            # Not a local path: treat it as a Hub repository ID and download the file
            hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(
                spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs
            )
        assert os.path.exists(spectral_shifts_ckpt)

        # Overwrite the zero-initialized shifts with the fine-tuned ones
        with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
            for key in f.keys():
                if accepts_dtype:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
                else:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
        print(f"Resumed from {spectral_shifts_ckpt}")

    if "torch_dtype" in kwargs:
        model = model.to(kwargs["torch_dtype"])
    model.register_to_config(_name_or_path=pretrained_model_name_or_path)
    # Set model in evaluation mode to deactivate Dropout modules by default
    model.eval()
    del original_model
    torch.cuda.empty_cache()
    return model


# Load the base pipeline; its U-Net is replaced by the SVDiff U-Net below
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Local path to the spectral shifts from this repository (the file must already exist)
spectral_shifts_path = os.path.join("unet", "spectral_shifts.safetensors")

pipe.unet = load_unet_for_svdiff(
    "runwayml/stable-diffusion-v1-5",
    spectral_shifts_ckpt=spectral_shifts_path,
    subfolder="unet",
)
for module in pipe.unet.modules():
    if hasattr(module, "perform_svd"):
        module.perform_svd()

# Load the adapted U-Net
# Assumption: the adapted weights are the spectral shifts stored in the same file.
state_dict = load_file(spectral_shifts_path)
pipe.unet.load_state_dict(state_dict, strict=False)
pipe.to("cuda:0")

# Generate images with text prompts

TEXT_PROMPT = "No acute cardiopulmonary abnormality."
GUIDANCE_SCALE = 4
INFERENCE_STEPS = 75

result_image = pipe(
    prompt=TEXT_PROMPT,
    height=224,
    width=224,
    guidance_scale=GUIDANCE_SCALE,
    num_inference_steps=INFERENCE_STEPS,
)

result_pil_image = result_image["images"][0]
```

## Training Details

### Training Data

This model was fine-tuned on 110K image-text pairs from the MIMIC dataset.

### Training Procedure

The training procedure is described in detail in Section 4.3 of this [paper](https://arxiv.org/abs/2305.08252).

#### Metrics

This model was evaluated using the Fréchet Inception Distance (FID) on the MIMIC dataset. Lower FID is better.
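
For reference, the standard definition of FID compares Gaussian fits to Inception features of real and generated images. The symbols below are introduced here for illustration and are not taken from the paper: $\mu_r, \Sigma_r$ denote the feature mean and covariance of real images, and $\mu_g, \Sigma_g$ those of generated images.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

The score is zero when both feature distributions share the same mean and covariance, and it grows as they diverge.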

### Results

| Fine-Tuning Strategy   | FID Score |
|------------------------|-----------|
| Full FT                | 58.74     |
| Attention              | 52.41     |
| Bias                   | 20.81     |
| Norm                   | 29.84     |
| Bias+Norm+Attention    | 35.93     |
| LoRA                   | 439.65    |
| SV-Diff                | 23.59     |
| DiffFit                | 42.50     |

## Environmental Impact

Parameter-efficient fine-tuning potentially causes **less** harm to the environment because it updates far fewer parameters than full fine-tuning, which reduces compute and hardware requirements.
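
To give a rough sense of scale, the sketch below counts trainable values for a single layer under full fine-tuning versus singular-value fine-tuning. It is an illustration only: the layer shapes are hypothetical, and it assumes convolution kernels are flattened to a 2-D matrix of shape `(out_channels, in_channels * kh * kw)` before the decomposition, as described for SVDiff in [1].

```python
from math import prod


def full_param_count(shape):
    """Number of weights updated by full fine-tuning of one layer."""
    return prod(shape)


def singular_value_count(shape):
    """Number of singular values once the weight is flattened to (out, rest)."""
    out_features = shape[0]
    rest = prod(shape[1:])
    return min(out_features, rest)


# Hypothetical layer shapes, chosen only to illustrate the ratio.
layers = {
    "linear 320x320": (320, 320),
    "conv 320x320x3x3": (320, 320, 3, 3),
}

for name, shape in layers.items():
    full = full_param_count(shape)
    shifts = singular_value_count(shape)
    print(f"{name}: {full} weights vs {shifts} trainable shifts")
```

For both example layers the number of trainable shifts is 320, against 102,400 and 921,600 weights respectively.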

## Citation

**BibTeX:**

@article{dutt2023parameter,
title={Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity},
author={Dutt, Raman and Ericsson, Linus and Sanchez, Pedro and Tsaftaris, Sotirios A and Hospedales, Timothy},
journal={arXiv preprint arXiv:2305.08252},
year={2023}
}

**APA:**

Dutt, R., Ericsson, L., Sanchez, P., Tsaftaris, S. A., & Hospedales, T. (2023). Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity. arXiv preprint arXiv:2305.08252.

## Model Card Authors

Raman Dutt

- [Twitter](https://twitter.com/RamanDutt4)
- [LinkedIn](https://www.linkedin.com/in/raman-dutt/)
- [Email](mailto:[email protected])

## References

1. Han, Ligong, et al. "SVDiff: Compact Parameter Space for Diffusion Fine-Tuning." arXiv preprint arXiv:2303.11305 (2023).