---
license: mit
language:
- en
library_name: diffusers
pipeline_tag: image-to-image
tags:
- medical
---

This repository contains a model for generating synthetic 2D CT scans, intended solely for research purposes.

The base model is Stable Diffusion 3 Medium, extended with ControlNet, a technique that gives more precise control over the image generation process. Here the control signal is a color-coded segmentation mask of abdominal organs.

For pretraining, we used the Atlas dataset from Johns Hopkins University, which provided a comprehensive range of medical imaging data for the initial training phase of the model.

Our aim with this project is to contribute to the medical imaging field by enabling more robust and versatile synthetic data generation.

```
Training Details
Image Size = (128, 128)
Batch_size = 8 x 28 x 12
Compute: 8 x Nvidia-A6000 48GB
```

Code for generation:

```python
import os

# Select the GPU before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusion3ControlNetPipeline
from transformers import T5Tokenizer

# RGB color used for each class in the conditioning mask.
class_dict_atlas = {
    0: (0, 0, 0),
    1: (255, 60, 0),
    2: (255, 60, 232),
    3: (134, 79, 117),
    4: (125, 0, 190),
    5: (117, 200, 191),
    6: (230, 91, 101),
    7: (255, 0, 155),
    8: (75, 205, 155),
    9: (100, 37, 200),
}

# Organ name for each class index.
name_class_dict = {
    0: "background",
    1: "aorta",
    2: "kidney_left",
    3: "liver",
    4: "postcava",
    5: "stomach",
    6: "gall_bladder",
    7: "kidney_right",
    8: "pancreas",
    9: "spleen",
}


def rgb_to_onehot(rgb_arr, color_dict=class_dict_atlas):
    """Convert an (H, W, 3) RGB mask into an (H, W, num_classes) one-hot array."""
    num_classes = len(color_dict)
    shape = rgb_arr.shape[:2] + (num_classes,)
    arr = np.zeros(shape, dtype=np.uint8)
    flat = rgb_arr.reshape((-1, 3))
    for cls, color in color_dict.items():
        # Mark every pixel whose RGB value exactly matches this class color.
        arr[:, :, cls] = np.all(flat == color, axis=1).reshape(shape[:2])
    return arr


pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "onkarsus13/Semantic-Control-Stable-diffusion-3-M-Mask2CT-Atlas",
    torch_dtype=torch.float16,
)
# Load the T5 tokenizer from the repository's tokenizer_3 subfolder.
pipe.tokenizer_3 = T5Tokenizer.from_pretrained(
    "onkarsus13/Semantic-Control-Stable-diffusion-3-M-Mask2CT-Atlas",
    subfolder="tokenizer_3",
)
# Offload idle sub-models to the CPU to save GPU memory. This manages device
# placement itself, so the pipeline is not moved with pipe.to("cuda") first.
pipe.enable_model_cpu_offload()

generator = torch.Generator(device="cuda").manual_seed(1)

# Set this to the path of an RGB segmentation mask drawn with class_dict_atlas colors.
mask = Image.open("").convert("RGB")
shape = mask.size

# Find which organs appear in the mask and list them in the prompt.
labels = rgb_to_onehot(np.asarray(mask)).argmax(-1)
unique_ids = np.unique(labels)
# Prompt template kept exactly as in the original example.
prompt = "CT image containg " + " ".join(name_class_dict[i] for i in unique_ids)
print(prompt)

image = pipe(
    prompt=prompt,
    control_image=mask,
    height=128,
    width=128,
    num_inference_steps=50,
    generator=generator,
    controlnet_conditioning_scale=1.0,
).images[0]

# Resize the 128x128 output back to the mask's original size and save it.
image.resize(shape).save("result.png")
```
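The generation script expects an RGB mask drawn with the palette in `class_dict_atlas`. If your segmentation is instead an integer label map (for example, a 2D slice of a label volume), the sketch below converts it to that palette and builds the same prompt string. It assumes class ids 0 to 9 follow `name_class_dict`; `labels_to_rgb` and `build_prompt` are hypothetical helper names, not part of this repository or of diffusers.

```python
import numpy as np

# Same palette and class names as the generation script.
class_dict_atlas = {
    0: (0, 0, 0), 1: (255, 60, 0), 2: (255, 60, 232), 3: (134, 79, 117),
    4: (125, 0, 190), 5: (117, 200, 191), 6: (230, 91, 101),
    7: (255, 0, 155), 8: (75, 205, 155), 9: (100, 37, 200),
}
name_class_dict = {
    0: "background", 1: "aorta", 2: "kidney_left", 3: "liver", 4: "postcava",
    5: "stomach", 6: "gall_bladder", 7: "kidney_right", 8: "pancreas", 9: "spleen",
}


def labels_to_rgb(label_map, color_dict=class_dict_atlas):
    """Hypothetical helper: map an (H, W) array of class ids to an (H, W, 3) uint8 mask."""
    # Lookup table indexed by class id; rows for ids without a color stay (0, 0, 0).
    lut = np.zeros((max(color_dict) + 1, 3), dtype=np.uint8)
    for cls, color in color_dict.items():
        lut[cls] = color
    # Fancy indexing replaces each id with its RGB row; ids above 9 raise IndexError.
    return lut[np.asarray(label_map, dtype=np.int64)]


def build_prompt(label_map, names=name_class_dict):
    """Hypothetical helper: build the prompt from the classes present in label_map."""
    ids = np.unique(label_map)  # sorted class ids present in the map
    # Same prompt template as the generation script.
    return "CT image containg " + " ".join(names[int(i)] for i in ids)


# Usage on a synthetic label map with made-up organ regions.
label_map = np.zeros((128, 128), dtype=np.uint8)
label_map[32:96, 32:96] = 3  # hypothetical liver region
label_map[60:70, 60:70] = 1  # hypothetical aorta region inside it
rgb = labels_to_rgb(label_map)
print(build_prompt(label_map))  # CT image containg background aorta liver
```

Saving `rgb` with `Image.fromarray(rgb).save("mask.png")` produces a file the generation script can open. Use a lossless format such as PNG: `rgb_to_onehot` matches colors exactly, so JPEG compression artifacts would turn most organ pixels into background.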