# Model Use

This model is a class-conditional denoising diffusion probabilistic model (DDPM) that generates ultrasound images. It is based on the diffusion model from the paper [Towards Realistic Ultrasound Fetal Brain Imaging Synthesis](https://arxiv.org/abs/2304.03941) and was trained on the [FETAL_PLANES_DB dataset](https://zenodo.org/record/3904280). The model can generate the following classes, each identified by an integer label: Fetal abdomen (0), Fetal brain (1), Fetal femur (2), Fetal thorax (3), Maternal cervix (4), and Other (5). To generate an image of a given class, pass that class's integer label to the UNet as its class-label argument.

The code below loads the model, generates an image, and displays it:

```python
# !pip install --upgrade diffusers transformers accelerate scipy ftfy safetensors
from diffusers import DDPMPipeline
import torch
import matplotlib.pyplot as plt
import numpy as np

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the model and scheduler
model_id = "harveymannering/xfetus-ddpm-v2"
ddpm = DDPMPipeline.from_pretrained(model_id)
ddpm.to(device)

# Condition on the 'Fetal brain' class (label 1)
class_label = torch.ones(1, dtype=torch.int64).to(device)

# Generate a single image, starting from pure noise
x = torch.randn(1, 3, 128, 128).to(device)
for t in ddpm.scheduler.timesteps:
    model_input = ddpm.scheduler.scale_model_input(x, t)
    with torch.no_grad():
        noise_pred = ddpm.unet(model_input, t, class_label)["sample"]
    x = ddpm.scheduler.step(noise_pred, t, x).prev_sample

# Display the image: move channels last, shift by 0.5 as in the original
# snippet, and clip to the [0, 1] range that imshow expects for floats
image = np.transpose(x[0].cpu().detach().numpy(), (1, 2, 0)) + 0.5
plt.imshow(np.clip(image, 0.0, 1.0))
plt.show()
```
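
To sample a class other than 'Fetal brain', the same denoising loop can be wrapped in a function that takes a class name. The following is a minimal sketch rather than part of the released code: `CLASS_LABELS`, `label_for`, and `generate` are hypothetical names introduced here, and the label values come from the class list above. `generate` expects a pipeline loaded with `DDPMPipeline.from_pretrained`, such as `ddpm` in the snippet above.

```python
# Label mapping taken from the class list in this model card.
CLASS_LABELS = {
    "Fetal abdomen": 0,
    "Fetal brain": 1,
    "Fetal femur": 2,
    "Fetal thorax": 3,
    "Maternal cervix": 4,
    "Other": 5,
}


def label_for(class_name):
    """Return the integer label for a class name (hypothetical helper)."""
    if class_name not in CLASS_LABELS:
        raise ValueError(
            f"Unknown class {class_name!r}; expected one of {sorted(CLASS_LABELS)}"
        )
    return CLASS_LABELS[class_name]


def generate(pipeline, class_name, batch_size=1, image_size=128, device="cpu"):
    """Sample `batch_size` images of one class (hypothetical helper)."""
    import torch  # imported lazily so the label helpers work without PyTorch

    # One label per image in the batch, all set to the requested class
    labels = torch.full(
        (batch_size,), label_for(class_name), dtype=torch.int64, device=device
    )
    x = torch.randn(batch_size, 3, image_size, image_size, device=device)
    for t in pipeline.scheduler.timesteps:
        model_input = pipeline.scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = pipeline.unet(model_input, t, labels)["sample"]
        x = pipeline.scheduler.step(noise_pred, t, x).prev_sample
    return x
```

For example, `generate(ddpm, "Fetal femur", batch_size=4, device=device)` would return a tensor of four samples, which can be displayed in the same way as the single image above.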

# Example Outputs

The figure below shows examples of both real and synthetic images. The following preprocessing and augmentation steps were applied to all training images:

1. Random Horizontal Flip
2. Random Rotation (±45°)
3. Resize to 128×128 using Bicubic Interpolation

<img width="608" alt="image" src="https://cdn-uploads.huggingface.co/production/uploads/6349716695ab8cce385f450e/uxDp-0svPAp2dCmTK36rf.png">
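
The preprocessing and augmentation steps listed above could be expressed with torchvision roughly as follows. This is a sketch under assumptions, not the training code itself: `build_train_transform` and the two constants are hypothetical names, the flip probability is left at torchvision's default of 0.5 because this card does not state it, and any tensor conversion or normalization used during training is not described here and is therefore omitted.

```python
IMAGE_SIZE = 128           # target resolution from step 3
MAX_ROTATION_DEGREES = 45  # rotation range from step 2


def build_train_transform():
    """Compose the three listed steps in the listed order (a sketch)."""
    # Imported lazily so the constants above are usable without torchvision
    from torchvision import transforms
    from torchvision.transforms import InterpolationMode

    return transforms.Compose([
        # Assumption: default flip probability p=0.5
        transforms.RandomHorizontalFlip(),
        # Samples a rotation angle from [-45, +45] degrees
        transforms.RandomRotation(MAX_ROTATION_DEGREES),
        transforms.Resize(
            (IMAGE_SIZE, IMAGE_SIZE),
            interpolation=InterpolationMode.BICUBIC,
        ),
    ])
```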

# Training Loss

The baseline model was trained exclusively on images from the 'Voluson E6' machine. Training and validation losses are shown below. Checkpoints were saved every 50 epochs, and the checkpoint with the lowest validation loss was the one from epoch 250. The model provided here is that epoch-250 checkpoint.

<img width="608" alt="image" src="https://cdn-uploads.huggingface.co/production/uploads/6349716695ab8cce385f450e/XEZb34rdFYaeFckDMyCYm.png">
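
Selecting a checkpoint by validation loss, as described above, amounts to taking the minimum over the saved checkpoints. The helper below is a hypothetical illustration of that selection rule, not the script used for this model; it assumes the validation losses are available as a mapping from epoch number to loss.

```python
def best_checkpoint(val_losses):
    """Return the (epoch, loss) pair with the lowest validation loss.

    `val_losses` maps a saved checkpoint's epoch number to its validation loss.
    """
    if not val_losses:
        raise ValueError("val_losses is empty")
    epoch = min(val_losses, key=val_losses.get)
    return epoch, val_losses[epoch]
```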