harveymannering committed
Commit 120bc8b · Parent(s): b518a95
Update README.md

# Model Use

This model is a class-conditional DDPM that can generate ultrasound images. It is based on the diffusion model from the [Towards Realistic Ultrasound Fetal Brain Imaging Synthesis](https://arxiv.org/abs/2304.03941) paper. The model was trained on the [FETAL_PLANES_DB dataset](https://zenodo.org/record/3904280). The classes that can be generated, and their associated integer labels, are:

| Label | Class |
| -------- | ------- |
| 0 | Fetal abdomen |
| 1 | Fetal brain |
| 2 | Fetal femur |
| 3 | Fetal thorax |
| 4 | Maternal cervix |
| 5 | Other |

When generating an image, simply pass the integer label of your chosen class as an argument to the UNet, as in the short sketch below.
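
For example, a minimal sketch of building the label tensor for the 'Fetal femur' class (index 2); the batch size of 1 matches the single-image example below:

```python
import torch

# Hypothetical example: a batch of one label selecting 'Fetal femur' (index 2)
class_label = torch.full((1,), 2, dtype=torch.int64)
```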

The following code loads the model, generates a single image, and displays it.

```python
# !pip install --upgrade diffusers transformers accelerate scipy ftfy safetensors
from diffusers import DDPMPipeline
import torch
import matplotlib.pyplot as plt
import numpy as np

# Are we using a GPU or CPU?
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the model and scheduler
model_id = "harveymannering/xfetus-ddpm-v2"
ddpm = DDPMPipeline.from_pretrained(model_id)
ddpm.to(device)

# Generate a single image, starting from pure noise
x = torch.randn(1, 3, 128, 128).to(device)  # noise
for i, t in enumerate(ddpm.scheduler.timesteps):
    model_input = ddpm.scheduler.scale_model_input(x, t)
    with torch.no_grad():
        # Conditioning on the 'Fetal brain' class (with index 1)
        class_label = torch.ones(1, dtype=torch.int64)
        noise_pred = ddpm.unet(model_input, t, class_label.to(device))["sample"]
    x = ddpm.scheduler.step(noise_pred, t, x).prev_sample

# Display the image (shift values from roughly [-0.5, 0.5] into [0, 1])
plt.imshow(np.transpose(x[0].cpu().detach().numpy(), (1, 2, 0)) + 0.5)
plt.show()
```
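
To keep the result rather than only display it, one option (not part of the original example; the filename is arbitrary) is to write the same array to disk with matplotlib:

```python
# Hypothetical follow-up to the session above: save the generated image as a PNG.
# Assumes `x`, `np`, and `plt` are still in scope from the previous block.
img = np.transpose(x[0].cpu().detach().numpy(), (1, 2, 0)) + 0.5
plt.imsave("xfetus_sample.png", img.clip(0, 1))  # clip into [0, 1] for imsave
```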

# Example Outputs

The figure below shows images from the original dataset alongside synthetic images. The following preprocessing/augmentation steps were applied to every image (one way to reproduce them is sketched after the list):
1. Random Horizontal Flip
2. Random Rotation (±45°)
3. Resize with Bicubic Interpolation
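
A possible reconstruction of these steps with torchvision transforms; the 128×128 target size matches the sampling code above, while the exact parameters (such as the flip probability) are assumptions not stated in this card:

```python
import torchvision.transforms as T

# Assumed augmentation pipeline; exact parameters are not specified in this card
augment = T.Compose([
    T.RandomHorizontalFlip(p=0.5),  # 1. random horizontal flip
    T.RandomRotation(degrees=45),   # 2. random rotation in [-45°, +45°]
    T.Resize((128, 128), interpolation=T.InterpolationMode.BICUBIC),  # 3. bicubic resize
])
```
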
<img width="608" alt="image" src="https://cdn-uploads.huggingface.co/production/uploads/6349716695ab8cce385f450e/uxDp-0svPAp2dCmTK36rf.png">

# Training Loss

The baseline was trained on images from the "Voluson E6" machine only. Training and validation losses are shown below. Checkpoints were saved every 50 epochs; the checkpoint with the best validation loss was at epoch 250, and that is the checkpoint provided here.

<img width="608" alt="image" src="https://cdn-uploads.huggingface.co/production/uploads/6349716695ab8cce385f450e/XEZb34rdFYaeFckDMyCYm.png">