# xfetus-ddpm-v2
## Model Use
The following code loads the model and samples one image conditioned on a class label:
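```python
# !pip install diffusers
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and scheduler
model_id = "harveymannering/xfetus-ddpm-v2"
ddpm = DDPMPipeline.from_pretrained(model_id)  # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
ddpm.to(device)
# DDPM uses the full training schedule; DDIM/PNDM need set_timesteps() and can use far fewer steps (e.g. 50)
ddpm.scheduler.set_timesteps(ddpm.scheduler.config.num_train_timesteps)

x = torch.randn(1, 3, 128, 128).to(device)  # start from pure Gaussian noise

# Condition on the 'Fetal brain' class (index 1) because I am most familiar
# with what these images look like
class_label = torch.ones(1, dtype=torch.int64).to(device)

for t in tqdm(ddpm.scheduler.timesteps):
    model_input = ddpm.scheduler.scale_model_input(x, t)
    with torch.no_grad():
        noise_pred = ddpm.unet(model_input, t, class_label)["sample"]
    # One reverse-diffusion step: x_t -> x_{t-1}
    x = ddpm.scheduler.step(noise_pred, t, x).prev_sample

# Map from [-1, 1] to [0, 1] and from (C, H, W) to (H, W, C) for display
image = (x[0] / 2 + 0.5).clamp(0, 1)
plt.imshow(image.permute(1, 2, 0).cpu().numpy())
plt.show()
```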
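Each call to `ddpm.scheduler.step` in the loop above performs one reverse-diffusion update. The NumPy sketch below illustrates roughly what that update computes for a noise-predicting (ε) model; it is not the diffusers implementation. The linear β schedule (1e-4 to 0.02 over 1000 steps), the "fixed small" posterior variance and the hypothetical `ddpm_step` helper are assumptions taken from the original DDPM formulation, and this model's actual scheduler config may differ (diffusers can, for example, also clip the predicted clean image).

```python
import numpy as np

# Assumed linear beta schedule (the diffusers DDPMScheduler default);
# the schedule stored with this model may differ.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)


def ddpm_step(eps_pred, t, x_t, rng):
    """Hypothetical helper: one reverse step x_t -> x_{t-1}, epsilon prediction, no clipping."""
    beta_t = betas[t]
    abar_t = alphas_cumprod[t]
    # Posterior mean: (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
    mean = (x_t - beta_t / np.sqrt(1.0 - abar_t) * eps_pred) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # no noise is added on the final step
    # "Fixed small" posterior variance: (1 - abar_{t-1}) / (1 - abar_t) * beta_t
    var = (1.0 - alphas_cumprod[t - 1]) / (1.0 - abar_t) * beta_t
    return mean + np.sqrt(var) * rng.standard_normal(x_t.shape)


# Usage: a small stand-in for the 1x3x128x128 sampling loop
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8))
for t in reversed(range(T)):
    eps_pred = np.zeros_like(x)  # placeholder for ddpm.unet's noise prediction
    x = ddpm_step(eps_pred, t, x, rng)
```

With a real noise prediction in place of the zero placeholder, this follows the same structure as the sampling loop above.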
## Example Outputs
<img width="608" alt="image" src="https://cdn-uploads.huggingface.co/production/uploads/6349716695ab8cce385f450e/uxDp-0svPAp2dCmTK36rf.png">
## Training Loss
<img width="608" alt="image" src="https://cdn-uploads.huggingface.co/production/uploads/6349716695ab8cce385f450e/XEZb34rdFYaeFckDMyCYm.png">
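The card does not say which loss is plotted. Assuming it is the standard simplified DDPM objective used in the diffusers training examples, each training step noises a clean image to a random timestep and regresses the added noise with a mean-squared error. A minimal NumPy sketch of that objective, where `predict_noise` is a hypothetical stand-in for the class-conditional UNet and the linear β schedule is an assumed default:

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)


def ddpm_loss(x0, class_label, predict_noise, rng):
    """Simplified DDPM loss: MSE between the true and predicted noise at a random t."""
    batch = x0.shape[0]
    t = rng.integers(0, T, size=batch)  # one random timestep per image
    noise = rng.standard_normal(x0.shape)
    abar = alphas_cumprod[t].reshape(batch, 1, 1, 1)
    # Forward (noising) process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise
    x_t = np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * noise
    noise_pred = predict_noise(x_t, t, class_label)  # hypothetical model call
    return np.mean((noise_pred - noise) ** 2)


# Usage with fake images scaled to [-1, 1] and a hypothetical untrained model
rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(4, 3, 16, 16))
labels = np.ones(4, dtype=np.int64)  # class index 1 for every image
zero_model = lambda x_t, t, c: np.zeros_like(x_t)
print(ddpm_loss(x0, labels, zero_model, rng))  # roughly 1.0, since the noise has unit variance
```

A model that always predicts zero sits near a loss of 1.0, so a curve that falls well below that indicates the network is learning to recover the added noise.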