---
license: mit
language:
- en
library_name: diffusers
pipeline_tag: image-to-image
tags:
- medical
---
This repository contains a model for generating synthetic 2D CT slices from semantic segmentation masks, intended solely for research purposes. The base model is Stable Diffusion 3 Medium, extended with ControlNet, which conditions the generation process on an additional input image (here, an RGB organ mask) to give more precise control over the output.
For pretraining, we used the Atlas dataset from Johns Hopkins University, whose abdominal CT scans and organ segmentation masks provided the training data for the model. The ten classes the model is conditioned on (background plus nine organs: aorta, left and right kidneys, liver, postcava, stomach, gallbladder, pancreas, and spleen) are listed in `name_class_dict` in the generation code. Our aim with this project is to support medical imaging research by enabling more robust and versatile synthetic data generation.
```
Training Details
Image size = (128, 128)
Batch size = 8 x 28 x 12
Compute:
8 x Nvidia A6000 48GB
```
Code for generation:
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # select the GPU before CUDA is initialised

import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusion3ControlNetPipeline
from transformers import T5Tokenizer

MODEL_ID = "onkarsus13/Semantic-Control-Stable-diffusion-3-M-Mask2CT-Atlas"

# RGB colour used for each class in the semantic mask
class_dict_atlas = {
    0: (0, 0, 0),
    1: (255, 60, 0),
    2: (255, 60, 232),
    3: (134, 79, 117),
    4: (125, 0, 190),
    5: (117, 200, 191),
    6: (230, 91, 101),
    7: (255, 0, 155),
    8: (75, 205, 155),
    9: (100, 37, 200),
}

# Organ name for each class ID (used to build the text prompt)
name_class_dict = {
    0: "background",
    1: "aorta",
    2: "kidney_left",
    3: "liver",
    4: "postcava",
    5: "stomach",
    6: "gall_bladder",
    7: "kidney_right",
    8: "pancreas",
    9: "spleen",
}

def rgb_to_onehot(rgb_arr, color_dict=class_dict_atlas):
    """Convert an (H, W, 3) RGB mask into an (H, W, num_classes) one-hot array.

    Pixels whose colour is not in color_dict stay all-zero (argmax -> background).
    """
    num_classes = len(color_dict)
    shape = rgb_arr.shape[:2] + (num_classes,)
    arr = np.zeros(shape, dtype=np.int8)
    flat = rgb_arr.reshape((-1, 3))
    for class_id, color in color_dict.items():
        arr[:, :, class_id] = np.all(flat == color, axis=1).reshape(shape[:2])
    return arr

# SD3 pipelines have no safety checker, so no safety_checker argument is needed
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
# Load the T5 tokenizer from the repository's tokenizer_3 subfolder
pipe.tokenizer_3 = T5Tokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer_3")
pipe.to("cuda")
# On GPUs with less memory, call pipe.enable_model_cpu_offload() INSTEAD of pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(1)

# Semantic mask (RGB, coloured with class_dict_atlas) that guides generation
mask = Image.open("<Give mask image for semantic guidance>").convert("RGB")
orig_size = mask.size

# Recover the class IDs present in the mask and list them in the text prompt
label_map = rgb_to_onehot(np.asarray(mask)).argmax(-1)
unique_ids = np.unique(label_map)
# Prompt template kept verbatim from the original card (including the spelling "containg")
prompt = "CT image containg " + " ".join(name_class_dict[int(i)] for i in unique_ids)
print(prompt)

image = pipe(
    prompt=prompt,
    control_image=mask,
    height=128,
    width=128,
    num_inference_steps=50,
    generator=generator,
    controlnet_conditioning_scale=1.0,
).images[0]

# The model generates at 128x128; resize back to the mask's resolution
image.resize(orig_size).save("result.png")
```
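The generation script expects an RGB mask whose colours match `class_dict_atlas`. If your annotations are instead integer label maps (one class ID per pixel, using the IDs in `name_class_dict`), a minimal sketch like the one below, assuming exactly that ID convention, can colour them and build the matching prompt. The helper names `labels_to_rgb` and `build_prompt` and the toy label map are hypothetical illustrations, not part of this repository or of diffusers.

```python
import numpy as np
from PIL import Image

# Palette and names copied from class_dict_atlas / name_class_dict above
ATLAS_PALETTE = {
    0: (0, 0, 0), 1: (255, 60, 0), 2: (255, 60, 232), 3: (134, 79, 117),
    4: (125, 0, 190), 5: (117, 200, 191), 6: (230, 91, 101),
    7: (255, 0, 155), 8: (75, 205, 155), 9: (100, 37, 200),
}
ATLAS_NAMES = {
    0: "background", 1: "aorta", 2: "kidney_left", 3: "liver", 4: "postcava",
    5: "stomach", 6: "gall_bladder", 7: "kidney_right", 8: "pancreas", 9: "spleen",
}

def labels_to_rgb(label_map, palette=ATLAS_PALETTE):
    """Hypothetical helper: colour an (H, W) integer label map with the Atlas palette."""
    rgb = np.zeros(label_map.shape + (3,), dtype=np.uint8)
    for class_id, color in palette.items():
        rgb[label_map == class_id] = color  # broadcast the RGB triple over matching pixels
    return rgb

def build_prompt(label_map, names=ATLAS_NAMES):
    """Hypothetical helper: same prompt template as the generation script (sorted class IDs)."""
    ids = np.unique(label_map)
    return "CT image containg " + " ".join(names[int(i)] for i in ids)

# Toy 128x128 label map: a "liver" block and a "spleen" block on background
toy_labels = np.zeros((128, 128), dtype=np.uint8)
toy_labels[20:60, 20:70] = 3    # liver
toy_labels[70:100, 80:110] = 9  # spleen

mask_image = Image.fromarray(labels_to_rgb(toy_labels))  # (H, W, 3) uint8 -> RGB image
prompt = build_prompt(toy_labels)
print(prompt)  # CT image containg background liver spleen
```

Saving `mask_image` (for example with `mask_image.save("mask.png")`) and passing its path to `Image.open` in the generation script should yield the same prompt, because every palette colour decodes back to its class ID.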