# Assumes the input data is stored in NIfTI (.nii) format.
# Used for reading data from the coronal view.

import os
from data.base_dataset import BaseDataset, get_params, get_transform
from data.image_folder import make_dataset
from PIL import Image
import numpy as np
import torch
from .mask_extract import process_spine_data, process_spine_data_aug
import json
import nibabel as nib
import random
import torchvision.transforms as transforms
from scipy.ndimage import label, find_objects

def remove_small_connected_components(input_array, min_size):
    """Zero out connected components smaller than `min_size` pixels (in place)."""
    # Identify connected components using 8-connectivity.
    structure = np.ones((3, 3), dtype=np.int32)
    labeled, ncomponents = label(input_array, structure)

    # Remove every component whose pixel count is below the threshold.
    for i in range(1, ncomponents + 1):
        if np.sum(labeled == i) < min_size:
            input_array[labeled == i] = 0

    return input_array
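
# Example (hypothetical values, not part of the pipeline): applied to a binary
# 2D mask, components smaller than `min_size` pixels are zeroed in place while
# larger ones are kept, e.g.
#   m = np.array([[1, 0, 0],
#                 [0, 0, 1],
#                 [0, 0, 1]])
#   remove_small_connected_components(m, min_size=2)  # clears only the isolated pixel at (0, 0)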


class AlignedDataset(BaseDataset):
    """A dataset class for paired coronal vertebra slices.

    For every vertebra listed in vertebra_data.json it loads the matching CT and
    label NIfTI volumes from '<dataroot>/CT' and '<dataroot>/label', picks a
    coronal slice, and builds the paired A/B images (plus masks and the CAM) on the fly.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        
        # Read the JSON split file to select the training / validation / test vertebrae.
        with open('/home/zhangqi/Project/pytorch-CycleGAN-and-pix2pix-master/data/vertebra_data.json', 'r') as file:
            vertebra_set = json.load(file)

        # Lists of '<patient>_<vert>' ids, plus per-patient dictionaries of
        # normal and abnormal vertebrae.
        self.normal_vert_list = []
        self.abnormal_vert_list = []
        self.normal_vert_dict = {}
        self.abnormal_vert_dict = {}

        for patient_vert_id in vertebra_set[opt.phase].keys():
            # Split the key into patient id and vertebra id.
            patient_id, vert_id = patient_vert_id.rsplit('_', 1)

            # Grade <= 1 counts as normal, everything else as abnormal.
            if int(vertebra_set[opt.phase][patient_vert_id]) <= 1:
                self.normal_vert_list.append(patient_vert_id)
                self.normal_vert_dict.setdefault(patient_id, []).append(vert_id)
            else:
                self.abnormal_vert_list.append(patient_vert_id)
                self.abnormal_vert_dict.setdefault(patient_id, []).append(vert_id)

        # Choose which class of vertebrae this dataset iterates over.
        if opt.vert_class == "normal":
            self.vertebra_id = np.array(self.normal_vert_list)
        elif opt.vert_class == "abnormal":
            self.vertebra_id = np.array(self.abnormal_vert_list)
        else:
            print("No vert class is set.")
            self.vertebra_id = None
    
        #self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # get the image directory
        self.dir_AB = opt.dataroot  # expects '<dataroot>/CT' and '<dataroot>/label' subfolders
        #self.dir_mask = os.path.join(opt.dataroot,'mask',opt.phase) 
        #self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
        #self.mask_paths = sorted(make_dataset(self.dir_mask, opt.max_dataset_size)) 
        assert(self.opt.load_size >= self.opt.crop_size)   # crop_size must not exceed the size of the loaded image
        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
        self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
        
    def numpy_to_pil(self, img_np):
        """Convert a uint8 grayscale NumPy array (values 0-255) to a PIL image."""
        if img_np.dtype != np.uint8:
            raise ValueError("NumPy array should have uint8 data type.")
        img_pil = Image.fromarray(img_np)
        return img_pil
    


    # Select a slice with pyramid-shaped probabilities: the central slices carry
    # the most information, so they should be sampled more often than the edges.
    def get_weighted_random_slice(self, z0, z1):
        # Shrink the usable range to the central 4/5 of [z0, z1].
        range_length = z1 - z0 + 1
        new_range_length = int(range_length * 4 / 5)

        # Start and end indices of the shrunken range.
        new_z0 = z0 + (range_length - new_range_length) // 2
        new_z1 = new_z0 + new_range_length - 1

        # Centre index of the shrunken range.
        center_index = (new_z0 + new_z1) // 2

        # Triangular weights: largest at the centre, decreasing towards the edges.
        weights = [1 - abs(i - center_index) / (new_z1 - new_z0) for i in range(new_z0, new_z1 + 1)]

        # Normalise the weights so they sum to 1.
        total_weight = sum(weights)
        normalized_weights = [w / total_weight for w in weights]

        # Draw one slice index according to the weights.
        random_index = np.random.choice(range(new_z0, new_z1 + 1), p=normalized_weights)
        index_ratio = abs(random_index - center_index) / range_length * 2

        return random_index, index_ratio
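
    # Example (hypothetical numbers): for z0 = 0 and z1 = 9 the usable window
    # shrinks to indices 1..8 with centre 4; the triangular weights make the
    # central index roughly twice as likely to be drawn as the window edges.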
    
    def get_valid_slice(self, vert_label, z0, z1, maxheight):
        """Randomly pick a non-empty coronal slice whose vertebra height stays below `maxheight`."""
        max_attempts = 100  # cap the number of attempts to avoid an infinite loop
        attempts = 0
        while attempts < max_attempts:
            slice_index, index_ratio = self.get_weighted_random_slice(z0, z1)
            vert_label[:, slice_index, :] = remove_small_connected_components(vert_label[:, slice_index, :], 50)

            if np.sum(vert_label[:, slice_index, :]) > 50:  # the slice must contain enough foreground
                coords = np.argwhere(vert_label[:, slice_index, :])
                x1, x2 = min(coords[:, 0]), max(coords[:, 0])
                if x2 - x1 < maxheight:
                    return slice_index, index_ratio
            attempts += 1
        raise ValueError("Failed to find a non-empty slice after {} attempts.".format(max_attempts))


    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index - - a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor) - - an image in the input domain (the full coronal CT slice)
            B (tensor) - - its corresponding image in the target domain (the slice with the vertebra band blanked out)
            A_paths (str) - - image paths
            B_paths (str) - - image paths (same as A_paths)
        along with the vertebra mask, crop mask, normal-vertebra mask, CAM and slice metadata.
        """
        # Read one vertebra sample given a random integer index.
        # Load the CAM heat map for this vertebra (the '_0' file is preferred, '_1' is the fallback).
        CAM_folder = '/home/zhangqi/Project/VertebralFractureGrading/heatmap/straighten_coronal/binaryclass_1'
        CAM_path_0 = os.path.join(CAM_folder, self.vertebra_id[index]+'_0.nii.gz')
        CAM_path_1 = os.path.join(CAM_folder, self.vertebra_id[index]+'_1.nii.gz')
        if not os.path.exists(CAM_path_0):
            CAM_path = CAM_path_1
        else:
            CAM_path = CAM_path_0
        CAM_data = nib.load(CAM_path).get_fdata() * 255
        

        patient_id, vert_id = self.vertebra_id[index].rsplit('_', 1)
        vert_id = int(vert_id)
        # Patients without any normal vertebra get an empty list (handled below).
        normal_vert_list = self.normal_vert_dict.get(patient_id, [])


        ct_path = os.path.join(self.dir_AB,"CT",self.vertebra_id[index]+'.nii.gz')

        label_path = os.path.join(self.dir_AB,"label",self.vertebra_id[index]+'.nii.gz')

        ct_data = nib.load(ct_path).get_fdata()
        label_data = nib.load(label_path).get_fdata()

        # Binary mask of the target vertebra.
        vert_label = np.zeros_like(label_data)
        vert_label[label_data == vert_id] = 1

        # Mask (value 255) of all vertebrae graded as normal for this patient.
        normal_vert_label = label_data.copy()
        if normal_vert_list:
            for normal_vert in normal_vert_list:
                normal_vert_label[normal_vert_label == int(normal_vert)] = 255
            normal_vert_label[normal_vert_label != 255] = 0
        else:
            normal_vert_label = np.zeros_like(label_data)

        loc = np.where(vert_label)

        # Coronal slice selection: bounding range of the vertebra along the coronal axis.
        z0 = min(loc[1])
        z1 = max(loc[1])
        maxheight = 40

        try:
            slice, slice_ratio = self.get_valid_slice(vert_label, z0, z1, maxheight)
            #vert_label[:, :, slice] = remove_small_connected_components(vert_label[:, :, slice],50)
            coords = np.argwhere(vert_label[:, slice, :])
            x1, x2 = min(coords[:, 0]), max(coords[:, 0])
        except ValueError as e:
            print(e)
            raise
        width, length = vert_label[:, slice, :].shape

        # Centre a fixed-height window of h2 rows on the vertebra, clamped to the image borders.
        height = x2 - x1
        mask_x = (x1 + x2) // 2
        h2 = maxheight
        if height > h2:
            print(slice, ct_path)  # vertebra taller than the window; log it for inspection
        if mask_x <= h2 // 2:
            min_x = 0
            max_x = min_x + h2
        elif width - mask_x <= h2 // 2:
            max_x = width
            min_x = max_x - h2
        else:
            min_x = mask_x - h2 // 2
            max_x = min_x + h2

        
        # Create blank 256x256 canvases for every output image.
        target_A = np.zeros((256, 256))
        target_B = np.zeros((256, 256))
        target_A1 = np.zeros((256, 256))
        target_normal_vert_label = np.zeros((256, 256))
        target_mask = np.zeros((256, 256))
        target_CAM = np.zeros((256, 256))

        # Columns where the 64-pixel-wide coronal slice is placed (centred horizontally).
        start_col = (256 - 64) // 2
        end_col = start_col + 64

        # B: the CT context above x1 and below x2, repositioned so that a blank
        # band of h2 rows ([min_x, max_x)) replaces the target vertebra.
        target_B[:min_x, start_col:end_col] = ct_data[(x1-min_x):x1, slice, :]
        target_B[max_x:, start_col:end_col] = ct_data[x2:x2+(width-max_x), slice, :]

        # A: the full CT slice, taken directly from ct_data.
        target_A[:, start_col:end_col] = ct_data[:, slice, :]

        # A1: binary mask of the target vertebra (255 where label == vert_id, 0 elsewhere).
        A1 = np.zeros_like(label_data[:, slice, :])
        A1[label_data[:, slice, :] == vert_id] = 255
        target_A1[:, start_col:end_col] = A1

        # Normal-vertebra mask, shifted the same way as B.
        target_normal_vert_label[:min_x, start_col:end_col] = normal_vert_label[(x1-min_x):x1, slice, :]
        target_normal_vert_label[max_x:, start_col:end_col] = normal_vert_label[x2:x2+(width-max_x), slice, :]

        # Crop mask marking the blanked band, and the CAM placed the same way as B.
        target_mask[min_x:max_x, start_col:end_col] = 255
        target_CAM[:min_x, start_col:end_col] = CAM_data[(x1-min_x):x1, slice, :]
        target_CAM[max_x:, start_col:end_col] = CAM_data[x2:x2+(width-max_x), slice, :]
        
        target_A = target_A.astype(np.uint8)
        target_B = target_B.astype(np.uint8)
        target_A1 = target_A1.astype(np.uint8)
        target_normal_vert_label = target_normal_vert_label.astype(np.uint8)
        target_mask = target_mask.astype(np.uint8)
        target_CAM = target_CAM.astype(np.uint8)
        

        target_A = self.numpy_to_pil(target_A)
        target_B = self.numpy_to_pil(target_B)
        target_A1 = self.numpy_to_pil(target_A1)
        target_mask = self.numpy_to_pil(target_mask)
        target_normal_vert_label = self.numpy_to_pil(target_normal_vert_label)
        target_CAM = self.numpy_to_pil(target_CAM)
            
        # apply the same transform to both A and B
        A_transform = transforms.Compose([
            transforms.Grayscale(1),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        
        mask_transform = transforms.Compose([
            transforms.ToTensor()
        ])

        target_A = A_transform(target_A)
        target_B = A_transform(target_B)
        target_A1 = mask_transform(target_A1)
        target_mask = mask_transform(target_mask)
        target_normal_vert_label = mask_transform(target_normal_vert_label)
        target_CAM = mask_transform(target_CAM)

        
        
        return {'A': target_A, 'A_mask': target_A1, 'mask': target_mask, 'B': target_B, 'height': height, 'x1': x1, 'x2': x2,
                'h2': h2, 'slice_ratio': slice_ratio, 'normal_vert': target_normal_vert_label, 'CAM': target_CAM, 'A_paths': ct_path, 'B_paths': ct_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.vertebra_id)
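

# A minimal smoke-test sketch (an addition, not part of the original training code).
# It assumes the hard-coded JSON / NIfTI paths above exist on disk and only fakes the
# option fields this class actually reads; in the real pix2pix pipeline `opt` is built
# by BaseOptions/TrainOptions. Because of the relative import of mask_extract above,
# run it as a module (e.g. `python -m data.<this_module>`) rather than as a script.
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(
        dataroot='/path/to/straightened/coronal/data',  # hypothetical dataroot with CT/ and label/ subfolders
        phase='train',
        vert_class='abnormal',
        direction='AtoB',
        input_nc=1,
        output_nc=1,
        load_size=256,
        crop_size=256,
        max_dataset_size=float('inf'),
    )
    dataset = AlignedDataset(opt)
    print('number of vertebrae:', len(dataset))
    sample = dataset[0]
    print({k: tuple(v.shape) for k, v in sample.items() if torch.is_tensor(v)})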