diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..3e277f2d63fab1903d0bd41b25cb5759b7296e03 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..34d288942f1afcf67ebe637b50eab137b3340ddf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,70 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +checkpoints/sid_1-500_m3lteacher.pth filter=lfs diff=lfs merge=lfs -text +checkpoints/sid_1-500_mtteacher.pth filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png filter=lfs diff=lfs 
merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png filter=lfs diff=lfs merge=lfs -text +examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png filter=lfs diff=lfs merge=lfs -text 
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text +examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text +classcolors.png filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index ff389ebd62ff95f1b9af84b1cc1d27344888b507..27650c5fd1224139329efe85d01ffc2a48c03c72 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ --- title: M3L -emoji: 🐒 -colorFrom: pink +emoji: πŸ“š +colorFrom: purple colorTo: gray sdk: gradio sdk_version: 3.23.0 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..6301a42fa426017e1907179b1fc7f4173abed4ef --- /dev/null +++ b/app.py @@ -0,0 +1,161 @@ +import torch +import torch.nn as nn +import gradio as gr +import numpy as np +import os +import random +import pickle as pkl +from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch +from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion +from datasets.preprocessors import RGBDValPre +from utils.constants import Constants as C + +class Arguments: + def __init__(self, ratio): + self.ratio = ratio + self.masking_ratio = 1.0 + +colors = pkl.load(open('./colors.pkl', 'rb')) +args = Arguments(ratio = 0.8) + +mtmodel = WeTrLinearFusion("mit_b2", args, num_classes=13, pretrained=False) +mtmodelpath = './checkpoints/sid_1-500_mtteacher.pth' +mtmodel.load_state_dict(torch.load(mtmodelpath, map_location=torch.device('cpu'))) +mtmodel.eval() + +m3lmodel = LinearFusionMaskedConsistencyMixBatch("mit_b2", args, num_classes=13, pretrained=False) +m3lmodelpath = './checkpoints/sid_1-500_m3lteacher.pth' +m3lmodel.load_state_dict(torch.load(m3lmodelpath, map_location=torch.device('cpu'))) +m3lmodel.eval() + + + +class MaskStudentTeacher(nn.Module): + + def __init__(self, student, teacher, ema_alpha, mode = 'train'): + super(MaskStudentTeacher, self).__init__() + self.student = student + self.teacher = teacher + self.teacher = self._detach_teacher(self.teacher) + self.ema_alpha = ema_alpha + self.mode = mode + def forward(self, data, student = True, teacher = True, mask = False, range_batches_to_mask = None, **kwargs): + ret = [] + if student: + if self.mode == 'train': + ret.append(self.student(data, mask = mask, range_batches_to_mask = range_batches_to_mask, **kwargs)) + elif self.mode == 'val': + 
ret.append(self.student(data, mask = False, **kwargs)) + else: + raise Exception('Mode not supported') + if teacher: + ret.append(self.teacher(data, mask = False, **kwargs)) #Not computing loss for teacher ever but passing the results as if loss was also returned + return ret + def _detach_teacher(self, model): + for param in model.parameters(): + param.detach_() + return model + def update_teacher_models(self, global_step): + alpha = min(1 - 1 / (global_step + 1), self.ema_alpha) + for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()): + ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + return + def copy_student_to_teacher(self): + for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()): + ema_param.data.mul_(0).add_(param.data) + return + def get_params(self): + student_params = self.student.get_params() + teacher_params = self.teacher.get_params() + return student_params + + +def preprocess_data(rgb, depth, dataset_settings): + #RGB: np.array, RGB + #Depth: np.array, minmax normalized, *255 + preprocess = RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings) + rgb, depth = preprocess(rgb, depth) + if rgb is not None: + rgb = torch.from_numpy(np.ascontiguousarray(rgb)).float() + if depth is not None: + depth = torch.from_numpy(np.ascontiguousarray(depth)).float() + return rgb, depth + + +def visualize(colors, pred, num_classes, dataset_settings): + pred = pred.transpose(1, 2, 0) + predvis = np.zeros((dataset_settings['orig_height'], dataset_settings['orig_width'], 3)) + for i in range(num_classes): + color = colors[i] + predvis = np.where(pred == i, color, predvis) + predvis /= 255.0 + predvis = predvis[:,:,::-1] + return predvis + +def predict(rgb, depth, check): + dataset_settings = {} + dataset_settings['image_height'], dataset_settings['image_width'] = 540, 540 + dataset_settings['orig_height'], dataset_settings['orig_width'] = 540,540 + + rgb, depth = preprocess_data(rgb, depth, dataset_settings) + if rgb is not None: + rgb = rgb.unsqueeze(dim = 0) + if depth is not None: + depth = depth.unsqueeze(dim = 0) + ret = [None, None, './classcolors.png'] + if "Mean Teacher" in check: + if rgb is None: + rgb = torch.zeros_like(depth) + if depth is None: + depth = torch.zeros_like(rgb) + scores = mtmodel([rgb, depth])[2] + scores = torch.nn.functional.interpolate(scores, size = (dataset_settings["orig_height"], dataset_settings["orig_width"]), mode = 'bilinear', align_corners = True) + prob = scores.detach() + _, pred = torch.max(prob, dim=1) + pred = pred.numpy() + predvis = visualize(colors, pred, num_classes=13, dataset_settings=dataset_settings) + ret[0] = predvis + if "M3L" in check: + mask = False + masking_branch = None + if rgb is None: + mask = True + masking_branch = 0 + if depth is None: + mask = True + masking_branch = 1 + scores = m3lmodel([rgb, depth], mask = mask, masking_branch = masking_branch)[2] + scores = torch.nn.functional.interpolate(scores, size = (dataset_settings["orig_height"], dataset_settings["orig_width"]), mode = 'bilinear', align_corners = True) + prob = scores.detach() + _, pred = torch.max(prob, dim=1) + pred = pred.numpy() + predvis = visualize(colors, pred, num_classes=13, dataset_settings=dataset_settings) + ret[1] = predvis + + return ret + +imgs = os.listdir('./examples/rgb') +random.shuffle(imgs) +examples = [] +for img in imgs: + examples.append([ + './examples/rgb/'+img, './examples/depth/'+img, ["M3L", "Mean Teacher"] + ]) + +with gr.Blocks() as demo: + with gr.Row(): + rgbinput = 
gr.Image(label="RGB Input").style(height=256, width=256) + depthinput = gr.Image(label="Depth Input").style(height=256, width=256) + with gr.Row(): + modelcheck = gr.CheckboxGroup(["Mean Teacher", "M3L"], label="Predictions from", info="Predict using model trained with:") + with gr.Row(): + submit_btn = gr.Button("Submit") + with gr.Row(): + mtoutput = gr.Image(label="Mean Teacher Output").style(height=384, width=384) + m3loutput = gr.Image(label="M3L Output").style(height=384, width=384) + classnameouptut = gr.Image(label="Classes").style(height=384, width=384) + with gr.Row(): + examplesRow = gr.Examples(examples=examples, examples_per_page=10, inputs=[rgbinput, depthinput, modelcheck]) + submit_btn.click(fn = predict, inputs = [rgbinput, depthinput, modelcheck], outputs = [mtoutput, m3loutput, classnameouptut]) + +demo.launch() diff --git a/arial.ttf b/arial.ttf new file mode 100644 index 0000000000000000000000000000000000000000..9511009574d03893e04b78b259c3a2fb15112543 Binary files /dev/null and b/arial.ttf differ diff --git a/checkpoints/sid_1-500_m3lteacher.pth b/checkpoints/sid_1-500_m3lteacher.pth new file mode 100644 index 0000000000000000000000000000000000000000..8a1f11ec6247d176e54ec199dce60cc9b71013ac --- /dev/null +++ b/checkpoints/sid_1-500_m3lteacher.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5a23d7e2697b44b18e368e01353c328b13055a05a1cb0946ffb95b692d6facd +size 99192724 diff --git a/checkpoints/sid_1-500_mtteacher.pth b/checkpoints/sid_1-500_mtteacher.pth new file mode 100644 index 0000000000000000000000000000000000000000..bd5030bd1b649c3257fc693392ccd1ab0d684204 --- /dev/null +++ b/checkpoints/sid_1-500_mtteacher.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7eb24d6275e15376c40ec526281d550e4842ffa71aaa7af58fea54cbf56c2eeb +size 99186911 diff --git a/classcolors.png b/classcolors.png new file mode 100644 index 0000000000000000000000000000000000000000..f595bc201817e96b3b1295d7ecbb4c487acb4501 --- /dev/null +++ b/classcolors.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54ab7bebebd4982336252535488ad47d42751c06652eb4c47fb0af47c7880aba +size 38369 diff --git a/colors.pkl b/colors.pkl new file mode 100644 index 0000000000000000000000000000000000000000..2b45b4be220857eb5ab32185efd27ce5b3919eb5 --- /dev/null +++ b/colors.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c03050cb753b0802781f0ce92893ac22129c15724dd4ece6f4b9b4a352db591 +size 2342 diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/datasets/__pycache__/__init__.cpython-36.pyc b/datasets/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65bbf0fa219a462c8ec82ca8b35bb714d1e97c8a Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-36.pyc differ diff --git a/datasets/__pycache__/__init__.cpython-38.pyc b/datasets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21447a1f9524302980bb51070a6b9be970607ba8 Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-38.pyc differ diff --git a/datasets/__pycache__/__init__.cpython-39.pyc b/datasets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c66035a6ee01dd1be1ef4138fd1a2be579f0ff2 Binary files /dev/null and 
b/datasets/__pycache__/__init__.cpython-39.pyc differ diff --git a/datasets/__pycache__/base_dataset.cpython-36.pyc b/datasets/__pycache__/base_dataset.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..547414610982110cdcdbda755f95944e6ea7b92e Binary files /dev/null and b/datasets/__pycache__/base_dataset.cpython-36.pyc differ diff --git a/datasets/__pycache__/base_dataset.cpython-38.pyc b/datasets/__pycache__/base_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c24e28b21e5b3d15d02a63b940aa13f48b654ed1 Binary files /dev/null and b/datasets/__pycache__/base_dataset.cpython-38.pyc differ diff --git a/datasets/__pycache__/base_dataset.cpython-39.pyc b/datasets/__pycache__/base_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9566bfc5f3bc4897d32e48d5114abe1e41ff9b22 Binary files /dev/null and b/datasets/__pycache__/base_dataset.cpython-39.pyc differ diff --git a/datasets/__pycache__/citysundepth.cpython-36.pyc b/datasets/__pycache__/citysundepth.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec68e1e02423a4b4cc6812a70ab965832c252646 Binary files /dev/null and b/datasets/__pycache__/citysundepth.cpython-36.pyc differ diff --git a/datasets/__pycache__/citysundepth.cpython-39.pyc b/datasets/__pycache__/citysundepth.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91b8738c8ccccd5a191755af4dd33a14e0484858 Binary files /dev/null and b/datasets/__pycache__/citysundepth.cpython-39.pyc differ diff --git a/datasets/__pycache__/citysunrgb.cpython-36.pyc b/datasets/__pycache__/citysunrgb.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d99ba66c9819a1f6c66b87d1d7cf7e6d5bdbf6b1 Binary files /dev/null and b/datasets/__pycache__/citysunrgb.cpython-36.pyc differ diff --git a/datasets/__pycache__/citysunrgb.cpython-38.pyc b/datasets/__pycache__/citysunrgb.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dea671a22dfd0460373e1f4a41e5fe9a9c24a814 Binary files /dev/null and b/datasets/__pycache__/citysunrgb.cpython-38.pyc differ diff --git a/datasets/__pycache__/citysunrgb.cpython-39.pyc b/datasets/__pycache__/citysunrgb.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd2fefcfc096f08c3d5e0afb1d5441516c91a35f Binary files /dev/null and b/datasets/__pycache__/citysunrgb.cpython-39.pyc differ diff --git a/datasets/__pycache__/citysunrgbd.cpython-36.pyc b/datasets/__pycache__/citysunrgbd.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..164ad122fd0fb7d58b81fd30382dfcd69ca502bb Binary files /dev/null and b/datasets/__pycache__/citysunrgbd.cpython-36.pyc differ diff --git a/datasets/__pycache__/citysunrgbd.cpython-38.pyc b/datasets/__pycache__/citysunrgbd.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c7cefa2eb96b687fe01695aed1622a9185819ae Binary files /dev/null and b/datasets/__pycache__/citysunrgbd.cpython-38.pyc differ diff --git a/datasets/__pycache__/get_dataset.cpython-36.pyc b/datasets/__pycache__/get_dataset.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e812c9bef0bec9851900d5c1314b7d9b14a22d21 Binary files /dev/null and b/datasets/__pycache__/get_dataset.cpython-36.pyc differ diff --git a/datasets/__pycache__/get_dataset.cpython-39.pyc b/datasets/__pycache__/get_dataset.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9de70f3ceb2a062c94353ab0b7779153e358b69b Binary files /dev/null and b/datasets/__pycache__/get_dataset.cpython-39.pyc differ diff --git a/datasets/__pycache__/preprocessors.cpython-36.pyc b/datasets/__pycache__/preprocessors.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbcf9d5b592f380700a6420bcd38983cb7415e0d Binary files /dev/null and b/datasets/__pycache__/preprocessors.cpython-36.pyc differ diff --git a/datasets/__pycache__/preprocessors.cpython-38.pyc b/datasets/__pycache__/preprocessors.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5bb7d04cbfa19005c3c98d7d060ec4faf46021f Binary files /dev/null and b/datasets/__pycache__/preprocessors.cpython-38.pyc differ diff --git a/datasets/__pycache__/tfnyu.cpython-36.pyc b/datasets/__pycache__/tfnyu.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..776f4f817cb2b73cbdae56397e5cd92451d1ecda Binary files /dev/null and b/datasets/__pycache__/tfnyu.cpython-36.pyc differ diff --git a/datasets/base_dataset.py b/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..436623437dfe7a925b56a617a50a130d62192b4e --- /dev/null +++ b/datasets/base_dataset.py @@ -0,0 +1,128 @@ +import torch +import torch.utils.data as data +import numpy as np +import cv2 +from PIL import Image +from utils.img_utils import pad_image_to_shape + +class BaseDataset(data.Dataset): + + def __init__(self, dataset_settings, mode, unsupervised): + self._mode = mode + self.unsupervised = unsupervised + self._rgb_path = dataset_settings['rgb_root'] + self._depth_path = dataset_settings['depth_root'] + self._gt_path = dataset_settings['gt_root'] + self._train_source = dataset_settings['train_source'] + self._eval_source = dataset_settings['eval_source'] + self.modalities = dataset_settings['modalities'] + # self._file_length = dataset_settings['max_samples'] + self._required_length = dataset_settings['required_length'] + self._file_names = self._get_file_names(mode) + self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width']) + + def __len__(self): + if self._required_length is not None: + return self._required_length + return len(self._file_names) # when model == "val" + + def _get_file_names(self, mode): + assert mode in ['train', 'val'] + source = self._train_source + if mode == "val": + source = self._eval_source + + file_names = [] + with open(source) as f: + files = f.readlines() + + for item in files: + names = self._process_item_names(item) + file_names.append(names) + + if mode == "val": + return file_names + elif self._required_length <= len(file_names): + return file_names[:self._required_length] + else: + return self._construct_new_file_names(file_names, self._required_length) + + def _construct_new_file_names(self, file_names, length): + assert isinstance(length, int) + files_len = len(file_names) + + new_file_names = file_names * (length // files_len) #length % files_len items remaining + + rand_indices = torch.randperm(files_len).tolist() + new_indices = rand_indices[:length % files_len] + + new_file_names += [file_names[i] for i in new_indices] + + return new_file_names + + def _process_item_names(self, item): + item = item.strip() + item = item.split('\t') + num_modalities = len(self.modalities) + num_items = len(item) + names = {} + if not self.unsupervised: + assert num_modalities + 1 == num_items, f"Number of modalities and number of items in file 
name don't match, len(modalities) = {num_modalities} and len(item) = {num_items}" + item[0] + for i, modality in enumerate(self.modalities): + names[modality] = item[i] + names['gt'] = item[-1] + else: + assert num_modalities == num_items, f"Number of modalities and number of items in file name don't match, len(modalities) = {num_modalities} and len(item) = {num_items}" + for i, modality in enumerate(self.modalities): + names[modality] = item[i] + names['gt'] = None + + return names + + def _open_rgb(self, rgb_path, dtype = None): + bgr = cv2.imread(rgb_path, cv2.IMREAD_COLOR) #cv2 reads in BGR format, HxWxC + rgb = np.array(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), dtype=dtype) #Pretrained PyTorch model accepts image in RGB + return rgb + + def _open_depth(self, depth_path, dtype = None): #returns in HxWx3 with the same image in all channels + img_arr = np.array(Image.open(depth_path)) + if len(img_arr.shape) == 2: # grayscale + img_arr = np.array(np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0), dtype = dtype) + img_arr = (img_arr - img_arr.min()) * 255.0 / (img_arr.max() - img_arr.min()) + return img_arr + + def _open_depth_tf_nyu(self, depth_path, dtype = None): #returns in HxWx3 with the same image in all channels + img_arr = np.array(Image.open(depth_path)) + if len(img_arr.shape) == 2: # grayscale + img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0) + return img_arr + + def _open_gt(self, gt_path, dtype = None): + return np.array(cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE), dtype=dtype) + + def slide_over_image(self, img, crop_size, stride_rate): + H, W, C = img.shape + long_size = H if H > W else W + output = [] + if long_size <= min(crop_size[0], crop_size[1]): + raise Exception("Crop size is greater than the image size itself. Not handled right now") + + else: + stride_0 = int(np.ceil(crop_size[0] * stride_rate)) + stride_1 = int(np.ceil(crop_size[1] * stride_rate)) + r_grid = int(np.ceil((H - crop_size[0]) / stride_0)) + 1 + c_grid = int(np.ceil((W - crop_size[1]) / stride_1)) + 1 + + for grid_yidx in range(r_grid): + for grid_xidx in range(c_grid): + s_x = grid_xidx * stride_1 + s_y = grid_yidx * stride_0 + e_x = min(s_x + crop_size[1], W) + e_y = min(s_y + crop_size[0], H) + s_x = e_x - crop_size[1] + s_y = e_y - crop_size[0] + img_sub = img[s_y:e_y, s_x: e_x, :] + img_sub, margin = pad_image_to_shape(img_sub, crop_size, cv2.BORDER_CONSTANT, value=0) + output.append((img_sub, np.array([s_y, e_y, s_x, e_x]), margin)) + + return output diff --git a/datasets/citysunrgbd.py b/datasets/citysunrgbd.py new file mode 100644 index 0000000000000000000000000000000000000000..473e02a841518e583d7a02b3b83639353eb5a4f1 --- /dev/null +++ b/datasets/citysunrgbd.py @@ -0,0 +1,67 @@ +import torch +import numpy as np + +from datasets.base_dataset import BaseDataset + + +class CityScapesSunRGBD(BaseDataset): + + def __init__(self, dataset_settings, mode, unsupervised, preprocess, sliding = False, stride_rate = None): + super(CityScapesSunRGBD, self).__init__(dataset_settings, mode, unsupervised) + self.preprocess = preprocess + self.sliding = sliding + self.stride_rate = stride_rate + if self.sliding and self._mode == 'train': + print("Ensure correct preprocessing is being done!") + + def __getitem__(self, index): + # if self._file_length is not None: + # names = self._construct_new_file_names(self._file_length)[index] + # else: + # names = self._file_names[index] + names = self._file_names[index] + rgb_path = self._rgb_path+names['rgb'] + depth_path = self._depth_path+names['depth'] + if not
self.unsupervised: + gt_path = self._gt_path+names['gt'] + item_name = names['rgb'].split("/")[-1].split(".")[0] + + rgb = self._open_rgb(rgb_path) + depth = self._open_depth(depth_path) + gt = None + if not self.unsupervised: + gt = self._open_gt(gt_path) + + if not self.sliding: + if self.preprocess is not None: + rgb, depth, gt = self.preprocess(rgb, depth, gt) + + if self._mode in ['train', 'val']: + rgb = torch.from_numpy(np.ascontiguousarray(rgb)).float() + depth = torch.from_numpy(np.ascontiguousarray(depth)).float() + if gt is not None: + gt = torch.from_numpy(np.ascontiguousarray(gt)).long() + else: + raise Exception(f"{self._mode} not supported in CityScapesSunRGBD") + + # output_dict = dict(rgb=rgb, fn=str(item_name), + # n=len(self._file_names)) + output_dict = dict(data=[rgb, depth], name = item_name) + if gt is not None: + output_dict['gt'] = gt + return output_dict + + else: + sliding_output = self.slide_over_image(rgb, self.model_input_shape, self.stride_rate) + output_dict = {} + if self._mode in ['train', 'val']: + if gt is not None: + gt = torch.from_numpy(np.ascontiguousarray(gt)).long() + output_dict['gt'] = gt + output_dict['sliding_output'] = [] + for img_sub, pos, margin in sliding_output: + if self.preprocess is not None: + img_sub, _ = self.preprocess(img_sub, None) + img_sub = torch.from_numpy(np.ascontiguousarray(img_sub)).float() + output_dict['sliding_output'].append(([img_sub], pos, margin)) + return output_dict diff --git a/datasets/get_dataset.py b/datasets/get_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6fd2cfdb4502669deeb8b089ed59a53023e481 --- /dev/null +++ b/datasets/get_dataset.py @@ -0,0 +1,146 @@ +import torch +from torch.utils.data import DataLoader +from datasets.citysundepth import CityScapesSunDepth +from datasets.citysunrgb import CityScapesSunRGB +from datasets.citysunrgbd import CityScapesSunRGBD +from datasets.preprocessors import DepthTrainPre, DepthValPre, NYURGBDTrainPre, NYURGBDValPre, RGBDTrainPre, RGBDValPre, RGBTrainPre, RGBValPre +from datasets.tfnyu import TFNYU +from utils.constants import Constants as C + +def get_dataset(args): + datasetClass = None + if args.data == "nyudv2": + return TFNYU + if args.data == "city" or args.data == "sunrgbd" or args.data == 'stanford_indoor': + if len(args.modalities) == 1 and args.modalities[0] == 'rgb': + datasetClass = CityScapesSunRGB + elif len(args.modalities) == 1 and args.modalities[0] == 'depth': + datasetClass = CityScapesSunDepth + elif len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth': + datasetClass = CityScapesSunRGBD + else: + raise Exception(f"{args.modalities} not configured in get_dataset function.") + else: + raise Exception(f"{args.data} not configured in get_dataset function.") + return datasetClass + +def get_preprocessors(args, dataset_settings, mode): + if args.data == "nyudv2" and len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth': + if mode == 'train': + return NYURGBDTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings) + elif mode == 'val': + return NYURGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings) + + if len(args.modalities) == 1 and args.modalities[0] == 'rgb': + if mode == 'train': + return RGBTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings) + elif mode == 'val': + return RGBValPre(C.pytorch_mean, C.pytorch_std, dataset_settings) + else: + raise Exception("%s mode not defined" % mode) + elif len(args.modalities) == 1 and
args.modalities[0] == 'depth': + if mode == 'train': + return DepthTrainPre(dataset_settings) + elif mode == 'val': + return DepthValPre(dataset_settings) + else: + raise Exception("%s mode not defined" % mode) + elif len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth': + if mode == 'train': + return RGBDTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings) + elif mode == 'val': + return RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings) + else: + raise Exception("%s mode not defined" % mode) + else: + raise Exception("%s not configured for preprocessing" % args.modalities) + +def get_train_loader(datasetClass, args, train_source, unsupervised = False): + dataset_settings = {'rgb_root': args.rgb_root, + 'gt_root': args.gt_root, + 'depth_root': args.depth_root, + 'train_source': train_source, + 'eval_source': args.eval_source, + 'required_length': args.total_train_imgs, #Every dataloader will have Total Train Images / batch size iterations to be consistent + # 'max_samples': args.max_samples, #Every dataloader will have Total Train Images / batch size iterations to be consistent + 'train_scale_array': args.train_scale_array, + 'image_height': args.image_height, + 'image_width': args.image_width, + 'modalities': args.modalities} + + preprocessing = get_preprocessors(args, dataset_settings, "train") + train_dataset = datasetClass(dataset_settings, "train", unsupervised, preprocessing) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas = args.world_size, rank = args.rank) + if unsupervised and "unsup_batch_size" in args: + batch_size = args.unsup_batch_size + else: + batch_size = args.batch_size + train_loader = DataLoader(train_dataset, + batch_size = batch_size // args.world_size, + num_workers = args.num_workers, + drop_last = True, + shuffle = False, + sampler = train_sampler) + return train_loader + +def get_val_loader(datasetClass, args): + dataset_settings = {'rgb_root': args.rgb_root, + 'gt_root': args.gt_root, + 'depth_root': args.depth_root, + 'train_source': None, + 'eval_source': args.eval_source, + 'required_length': None, + 'max_samples': None, + 'train_scale_array': args.train_scale_array, + 'image_height': args.image_height, + 'image_width': args.image_width, + 'modalities': args.modalities} + if args.data == 'sunrgbd': + eval_sources = [] + for shape in ['427_561', '441_591', '530_730', '531_681']: + eval_sources.append(dataset_settings['eval_source'].split('.')[0] + '_' + shape + '.txt') + else: + eval_sources = [args.eval_source] + + preprocessing = get_preprocessors(args, dataset_settings, "val") + if args.sliding_eval: + collate_fn = _sliding_collate_fn + else: + collate_fn = None + + val_loaders = [] + for eval_source in eval_sources: + dataset_settings['eval_source'] = eval_source + val_dataset = datasetClass(dataset_settings, "val", False, preprocessing, args.sliding_eval, args.stride_rate) + if args.rank is not None: #DDP Evaluation + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas = args.world_size, rank = args.rank) + batch_size = args.val_batch_size // args.world_size + else: #DP Evaluation + val_sampler = None + batch_size = args.val_batch_size + + val_loader = DataLoader(val_dataset, + batch_size = batch_size, + num_workers = 4, + drop_last = False, + shuffle = False, + collate_fn = collate_fn, + sampler = val_sampler) + val_loaders.append(val_loader) + return val_loaders + + +def _sliding_collate_fn(batch): + gt =
torch.stack([b['gt'] for b in batch]) + sliding_output = [] + num_modalities = len(batch[0]['sliding_output'][0][0]) + for i in range(len(batch[0]['sliding_output'])): #i iterates over positions + imgs = [torch.stack([b['sliding_output'][i][0][m] for b in batch]) for m in range(num_modalities)] + pos = batch[0]['sliding_output'][i][1] + pos_compare = [(b['sliding_output'][i][1] == pos).all() for b in batch] + assert all(pos_compare), f"Position not same for all points in the batch: {pos_compare}, {[b['sliding_output'][i][1] for b in batch]}" + margin = batch[0]['sliding_output'][i][2] + margin_compare = [(b['sliding_output'][i][2] == margin).all() for b in batch] + assert all(margin_compare), f"Margin not same for all points in the batch: {margin_compare}, {[b['sliding_output'][i][2] for b in batch]}" + sliding_output.append((imgs, pos, margin)) + return {"gt": gt, "sliding_output": sliding_output} \ No newline at end of file diff --git a/datasets/preprocessors.py b/datasets/preprocessors.py new file mode 100644 index 0000000000000000000000000000000000000000..014cb700d82f5548b45313a72b3641b8b1882e22 --- /dev/null +++ b/datasets/preprocessors.py @@ -0,0 +1,144 @@ +from utils.img_utils import normalizedepth, random_crop_pad_to_shape, random_mirror, random_scale, normalize, resizedepth, resizergb, tfnyu_normalizedepth + +class RGBTrainPre(object): + def __init__(self, pytorch_mean, pytorch_std, dataset_settings): + self.pytorch_mean = pytorch_mean + self.pytorch_std = pytorch_std + self.train_scale_array = dataset_settings['train_scale_array'] + self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width']) + + def __call__(self, rgb, gt): + transformed_dict = random_mirror({"rgb":rgb, "gt":gt}) + if self.train_scale_array is not None: + transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (rgb.shape[0], rgb.shape[1])) + + transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['rgb'].shape[:2], self.crop_size) #Makes gt HxWx1 + rgb = transformed_dict['rgb'] + gt = transformed_dict['gt'] + rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std) + + rgb = rgb.transpose(2, 0, 1) #Brings the channel dimension in the top. Final output = CxHxW + return rgb, gt + + +class RGBValPre(object): + def __init__(self, pytorch_mean, pytorch_std, dataset_settings): + self.pytorch_mean = pytorch_mean + self.pytorch_std = pytorch_std + self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width']) + + def __call__(self, rgb, gt): + rgb = resizergb(rgb, self.model_input_shape) + rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std) + rgb = rgb.transpose(2, 0, 1) #Brings the channel dimension in the top. 
Final output = CxHxW + return rgb, gt + + +class RGBDTrainPre(object): + def __init__(self, pytorch_mean, pytorch_std, dataset_settings): + self.pytorch_mean = pytorch_mean + self.pytorch_std = pytorch_std + self.train_scale_array = dataset_settings['train_scale_array'] + self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width']) + + def __call__(self, rgb, depth, gt): + transformed_dict = random_mirror({"rgb":rgb, "depth": depth, "gt":gt}) + if self.train_scale_array is not None: + transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (rgb.shape[0], rgb.shape[1])) + + transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['rgb'].shape[:2], self.crop_size) #Makes gt HxWx1 + rgb = transformed_dict['rgb'] + depth = transformed_dict['depth'] + gt = transformed_dict['gt'] + rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std) + depth = normalizedepth(depth) + rgb = rgb.transpose(2, 0, 1) #Brings the channel dimension in the top. Final output = CxHxW + depth = depth.transpose(2, 0, 1) #Brings the channel dimension in the top. Final output = CxHxW + return rgb, depth, gt + + +class RGBDValPre(object): + def __init__(self, pytorch_mean, pytorch_std, dataset_settings): + self.pytorch_mean = pytorch_mean + self.pytorch_std = pytorch_std + self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width']) + + def __call__(self, rgb, depth): + if rgb is not None: + rgb = resizergb(rgb, self.model_input_shape) + rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std) + rgb = rgb.transpose(2, 0, 1) #Brings the channel dimension in the top. Final output = CxHxW + if depth is not None: + depth = resizedepth(depth, self.model_input_shape) + depth = normalizedepth(depth) + depth = depth.transpose(2, 0, 1) #Brings the channel dimension in the top. Final output = CxHxW + + return rgb, depth + + +class NYURGBDTrainPre(object): + def __init__(self, pytorch_mean, pytorch_std, dataset_settings): + self.pytorch_mean = pytorch_mean + self.pytorch_std = pytorch_std + self.train_scale_array = dataset_settings['train_scale_array'] + self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width']) + + def __call__(self, rgb, depth, gt): + transformed_dict = random_mirror({"rgb":rgb, "depth": depth, "gt":gt}) + if self.train_scale_array is not None: + transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (rgb.shape[0], rgb.shape[1])) + + transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['rgb'].shape[:2], self.crop_size) #Makes gt HxWx1 + rgb = transformed_dict['rgb'] + depth = transformed_dict['depth'] + gt = transformed_dict['gt'] + rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std) + depth = tfnyu_normalizedepth(depth) + rgb = rgb.transpose(2, 0, 1) #Brings the channel dimension in the top. Final output = CxHxW + depth = depth.transpose(2, 0, 1) #Brings the channel dimension in the top. 
Final output = CxHxW + return rgb, depth, gt + + +class NYURGBDValPre(object): + def __init__(self, pytorch_mean, pytorch_std, dataset_settings): + self.pytorch_mean = pytorch_mean + self.pytorch_std = pytorch_std + self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width']) + + def __call__(self, rgb, depth, gt): + rgb = resizergb(rgb, self.model_input_shape) + depth = resizedepth(depth, self.model_input_shape) + rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std) + depth = tfnyu_normalizedepth(depth) + rgb = rgb.transpose(2, 0, 1) #Brings the channel dimension in the top. Final output = CxHxW + depth = depth.transpose(2, 0, 1) #Brings the channel dimension in the top. Final output = CxHxW + return rgb, depth, gt + + +class DepthTrainPre(object): + def __init__(self, dataset_settings): + self.train_scale_array = dataset_settings['train_scale_array'] + self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width']) + + def __call__(self, depth, gt): + transformed_dict = random_mirror({"depth": depth, "gt":gt}) + if self.train_scale_array is not None: + transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (depth.shape[0], depth.shape[1])) + + transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['depth'].shape[:2], self.crop_size) #Makes gt HxWx1 + depth = transformed_dict['depth'] + gt = transformed_dict['gt'] + depth = normalizedepth(depth) + depth = depth.transpose(2, 0, 1) #Brings the channel dimension in the top. Final output = CxHxW + return depth, gt + + +class DepthValPre(object): + def __init__(self, dataset_settings): + self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width']) + + def __call__(self, depth, gt): + depth = resizedepth(depth, self.model_input_shape) + depth = normalizedepth(depth) + depth = depth.transpose(2, 0, 1) #Brings the channel dimension in the top. 
Final output = CxHxW + return depth, gt \ No newline at end of file diff --git a/examples/.DS_Store b/examples/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b1dacde6afbd4af6abb1d165046df19ddc61fdf8 Binary files /dev/null and b/examples/.DS_Store differ diff --git a/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..a5f23e0890024be3cd744b3df622a4f73ba8ab46 --- /dev/null +++ b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55885dc912a9accfa7bc492e065b323b1a52f82a5117b41fc8efd4a2b12adaf9 +size 83603 diff --git a/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..95255721c1c5a85e62b38227960d43ea9ba89f09 --- /dev/null +++ b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bf577c404957a98fedd04fb458c4361e181b735c1eeef60939c643b0c3a60a3 +size 64570 diff --git a/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..469a5ff3630a8c0cd1f107e0f2820e586ce9b501 --- /dev/null +++ b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96cf8016c66212f327fbda0163b8b03a2daf3c7e6c28ae02faf83bd2e3f11cdb +size 232749 diff --git a/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..e46ef7b9cbdcd0f9ed2cf3e5195b6e6f92111431 --- /dev/null +++ b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d11dbca1fa89c97e1eaf8302dd37caf49cbe94f9fd6d8331c61c29d2f786c4b3 +size 77404 diff --git a/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..6596d2b5605bc433cfe81bed190b10d02f01109d --- /dev/null +++ b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a56e0ea6823606d15bf04c2a642be8df5485883f5be2823eb020d7ba6f2802f1 +size 276239 diff --git a/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..9ed9e9110049f15021f5349aed2189225585e065 --- /dev/null +++ b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98b9da600b500edd19c89cd0ca0b454a49aa964928d88a7836b3abcf04425cb6 
+size 193762 diff --git a/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..008b17f5e73793aafeb95ea104a572cc8ba9d5ac --- /dev/null +++ b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36db03a470d9826b6dfdd49a804c3d23434ab88d231841c2afb74ef33e69c4a2 +size 168746 diff --git a/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..d6e665a17d4d1d7ac77ae085d943b64347d815b4 --- /dev/null +++ b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:411bc3b1f22ddc361cf8c89cec596f12818e930a9cb3d9e39bfb509c1b9e46f4 +size 192627 diff --git a/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..64576d996d27825ac858ddcbcc1f4439dd70224b --- /dev/null +++ b/examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0004251bc5fc96e0df7eeb677444aa50b8d8d569175c408ce48b71792c32b64d +size 255860 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..68d25c9c86ee042bda89123e72575b81cadb92c3 --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9aabf2840f638a9c52ae8151cb3677c8786e2350574ecf21c59f3fd44fe9bf91 +size 170890 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..db5f5b863075bcb4db8775485dcc1f811258e5ab --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3fda21d17a465af9cb332f07fd8fc9bb255a03a48bd4b661a5ca789784dbcad7 +size 273310 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..9c65510abe1b1c20ce85f7fa3079494d75e88b1d --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ec6050af0b91ab3fd7784deb0e005c89eec40a40e51a9936ca499bbb9775fa3 +size 250824 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png new file mode 100644 index 
0000000000000000000000000000000000000000..b92741150e54ed9d4d0ac15267f888345bd87c7f --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3800cdc565340cba8097902bf374506b898ef2f59864697a89ea2817e714be7 +size 105998 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..c68a9f54a4b885546582db0beb6518b12e4cadd1 --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc28a5d5e9003fe651e5a4da7683512b8d4f9da0b447fa024dfd9ffd9f8ff6c2 +size 96347 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..358a1abb1bbca00e05ffe6794c86a58da7f7ca9f --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b25e48bd4afdb9e9d60e1637065caa4ec6a3f56bbc5bd3be6255d32d63d724b +size 150209 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..e8fe758164d06e15fc9bb77ed36a2218c304ceaa --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:abef7219c2a2aca39183172aff749593a6b10a7cc6f7d0c6f47b3a0eb2186832 +size 122678 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..b32334f7d5865180549d5862078c9049a0bdc292 --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7346e13826370bb01b2aee6ee4396d922dad43af5b714f9d0993faf6ed3a340a +size 82017 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..139635b3c7db53c7bf9a33a94f2ac46456bacc16 --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:428275fb70796c4d129d91243a587575e511313cbaf7146bfd857da3da5b4250 +size 205466 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..fb5fb9788f6e91d8641496babacab624a45b66e7 --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:b918a3cf38d185779ac732df994bf8977e1cbdc891a5f12593197a78b75fd9cf +size 86513 diff --git a/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..9374ee7f1ffb107439b971dba4bd2538eee8731a --- /dev/null +++ b/examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7968b71a48074f9edb876f650f7fc4ac801db80084a38ac79eef735a777cd6c5 +size 120437 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..c986ee5b4f9370cb681de28088cffeecfea48d3b --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8409cca46cc81c7fe73c89e5a1f6fe91bbc41fefea47f643bc89a281bd9ef779 +size 288734 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..ad1419eb14c7f806bd19af82869ebbd8698fadf9 --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddf4eebed4cedc7c6615ba31427a1757fc238159ca63080c33e39aabe0675499 +size 157764 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..a90c778a8867d987e2a40ecf7cff0c859920b109 --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:054ef4683a61d9f4518b0e851cb62303d41d468dd58c3570bbf6df6cb8dca5ef +size 195444 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..ff85a8e3beaa687c9ffbc16d01d8e8d3fbab951f --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95254a80c8d989521992276b1dfd114e431bda02699ed25ac38bb90663ccaced +size 225507 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..06b2dd959f8249db3f8d31314bf4c5c38b8c4c10 --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be7f567bdc78724e5aefe00110f0c96cf92edd103dad2b0e39ae84e547ef29e4 +size 202203 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png new 
file mode 100644 index 0000000000000000000000000000000000000000..c690ef33cf6ec42978c9987773f4380ad8541d85 --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:653c73488b153973cda4eea24cca004351b46dc381af7253e493789beb5c3e2e +size 193656 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..088ca2b13b8244ac9d0fbd0db73de6fcb9258ac0 --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b78acf59baef89a5547d9ef02d01c899c1f7b34f665deb00c7fc391fe2cdbfb9 +size 179399 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..632fc944ffd5fbf591334a4c8f405323cc086483 --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9c0b1c9dd08aab09687c56f10ca744cad04c457f95f0cffe431d6ae479f4b70 +size 229646 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..7704411b0a481333fa9624dc66d8ecb0c76a9a22 --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfa999c775e5df759986f247ce99bad266c1a190be4d8955d51d3ab1fdf02ad8 +size 209209 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..e7cedfd37c066f7aed957c67c9e31a0fc08db097 --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b396359b928f0f007f1971b358f42d65c4981ea4f004605d857b01f60d23c054 +size 153676 diff --git a/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..32000206468ea38dcf1bcbb5ff5fe2f78b40cca4 --- /dev/null +++ b/examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3971ebbc73ac692d5b930a82ec15ec1b2735a4b7ea83679e089708ec40ef387 +size 208322 diff --git a/examples/depth/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png b/examples/depth/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..6a8a2eddc405c3425db8332ac875e3f920ec1424 --- /dev/null +++ b/examples/depth/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 
+oid sha256:5a0695347d6cbb336520c30a496bd5c069dbe595d5a0dcd4bd9f2ec36b7e0624 +size 258412 diff --git a/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..7fb2ea6a461883d2d84983fc512bbe8ed48d734b --- /dev/null +++ b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9db6cb0329597bfe6850e784b2ed6811262ce7ddcf358262dd559aff4dbd5ca3 +size 1120127 diff --git a/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..98ccd91a3a19780ee3e336177fca5d1b4d7a5491 --- /dev/null +++ b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de0deabc20bfa5778f94e2b930c9eaaf0d4a88435c4e3f65e33456341cce4215 +size 753475 diff --git a/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..b930c35fc54c749529ffe2a3a8cfeecd8da51ff1 --- /dev/null +++ b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4499c1e908b72c1434591c559eaeb496ea3515b6432e5f980f0954aa4559b031 +size 1062747 diff --git a/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..7d6a10941ce37fea9244245fc1a85e7cfa083aea --- /dev/null +++ b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e25d4b5c2be1880ff078600ef94f199acedf5300347eb0cdead3cdc815173f08 +size 1030649 diff --git a/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..a2f5c7c609fa8cdcc9eee51b8ff528e7b7d7cda8 --- /dev/null +++ b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a438833c5a7c615429ffc3ac9bc1f292c11bf89ba4f2cb36f8fcbcca9e85abae +size 766846 diff --git a/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..6d8ef7574619c2643f41fdafc4dac460f40629f9 --- /dev/null +++ b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3937f9c7e14a06f33b489b22edb50f6ff4c9fc2b7226a338126c1f807a32042 +size 1029050 diff --git a/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png new file mode 100644 index 
0000000000000000000000000000000000000000..74f2d6fd7c08180be8f6ef5b42f85802bc718dc0 --- /dev/null +++ b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ded2072afb77751e4a6fb95dce3d25759d24c56b7a9a130c5f024bdf088ad374 +size 872183 diff --git a/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..0e0d55c5c47b2baa517a27d7f5435457e1eb5728 --- /dev/null +++ b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea53dd1dbd9c0b5f67ef3a74aba0693e42596887f31531ba16e4cec417dd872d +size 1209636 diff --git a/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..fc1849ebf4ac7091055b4e2f01a1da5d9b1fdce9 --- /dev/null +++ b/examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7de1667efcc1340decd49d6e22f19a072e838c49330aee0ed88cc6f8c1fa4cc3 +size 1335851 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..1114716e6582e347ae856388ea65abbaf271ebbe --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c54250e3ff5d49a62c2be37e459648235d68b63ea49ada869d73aef963047f2c +size 1464475 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..57e51ba3298e4ab84e217809e88fbe7a37c6d9d6 --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51de91a056c4191c655867a9f822572d26fcb71868dd677a5df97f3b5fdf75a4 +size 926188 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..90d7b9388929fe8c70a23b6d281ff5bf220d0834 --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28e6bc95ba0bfd8c787fd7d8f8b876aef35297192eca4ca69bb730b698f51e0e +size 1243502 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..10be62f2bda32b02799611e2a0e331f00f8e8ac5 --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:bca770e211a0506bed904a0eb4ae618fa849e085f6727def0209d02674d03c0c +size 1412844 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..f3966fed69bf86a3dd65e363412ca0cc99ff76f9 --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c35b21ef3ad757d55c4ece8d350727a4cb6e57f18f83d9f31f74e4f506d0354 +size 1365198 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..14b0c9d075d1034a0306eab29e48103a340e601f --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba83531396acb6a8a26e310562fc1f45ef4dbea19615092105c0623024380b48 +size 1038758 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..ade03de10c20dca167522f72a8f2360b80c4d92c --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b7df6459fd98a173a81b62270ed639516516c4981ba56f45dec05bb950ab527 +size 1373683 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..62e344f21822a749cab4e4216890f9676c8893a6 --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56a9e614cb6e49fe82b8e01e5c869b3550cc14276b1834e7496aadc5c7ee8794 +size 1231934 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..e8a2c66be42397b629b720d266238cb0c3906d93 --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a24b26ab679edc1edb85c70d8b24d5c287e094c493eb01de11dc210d144c0283 +size 1145474 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..a52a37eaaa7f71d4ec99c352666433e07669bba2 --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee880f4794fecdf4a354f5fc57cfc494a25f9978235fb22218e89c2fa28bfb7c +size 1283641 diff --git a/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png new file mode 100644 index 
0000000000000000000000000000000000000000..9c5008762f54856eb31ab2cb316ec1d0f6ff9113 --- /dev/null +++ b/examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e060087646ebb70b4606a3b16162aa2acd4f8e32807852706b2a0aaa49c4b29 +size 1198433 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..ea943613708d6ecbfb4ff636f6bd01461cb22877 --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71678a603b4d955e86e528ccd6b6e1c8bd0733452bc7101f0b8e06c5b9e866a2 +size 866264 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..f539ed82f9983f31d0127fc402ebd92445badec2 --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7dcd8f083078a949ab7f9ac623759bc4381c1794e6a4b477d31aa504ceecc4a +size 1103828 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..1aa0ccd2ac498a203a1a3dae6db8f942c7d774e5 --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4d38a48ac02dfa84931f01bcddb74d7aad818f6efb9737e3738509ddbbada38 +size 890980 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..eedb5050ccf3ceeb2d7bb63e42340be740cb601d --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c820c5895122f9c549e3d07dada3a9a9f75f797ca19e8cae3e27aef3635623f +size 1076917 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..9fb3b381b84d22611d6b1d2f1b9777d58f06b5d0 --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6574275a66b1ca8fdb5e7b803408e2ab6a305ad1ad578a86dcf706fd37b3310 +size 915655 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..842402ad875a6ab674f3c6d7a54481fcddabe017 --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:bc01596dfe49806e8069fee5f1e6da4096c63d809538c9260d57a09e5768feb0 +size 826766 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..3749c6d7493f1bddfe36741377e2f62f6be0f655 --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5cd1366833fa8de2610f722e606324f2cabef9b33adf4a00a2b8f6d38c8885a +size 955368 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..cd35a98df672b903381dce1f28c0b56a378237f2 --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a5f86694982f5e732ad701869105160e9abcaf74f9c5b61fef09f2eb72967b4 +size 790170 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..0142f76e5b6c588d0e3010518d16af795c5ce3f6 --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:867b8468c2d379599912a400189b474fcd71a3c2650211018976a50d8b0a7bf2 +size 981022 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..b59aa2653f8746a7099c1fe4c1fe6b9622802b79 --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72fab64daf6889cddfb7adc1bb0a208339cdb4d55c0e5ab5b926451e6ede8630 +size 1014963 diff --git a/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..de0f2acd428b7bcbdc6f62df52feef7763079d65 --- /dev/null +++ b/examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:489b072ada0948c0e78a70da65d0cf00b7bc5ae6d6732e76472b88a91ca4e3a5 +size 928942 diff --git a/examples/rgb/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png b/examples/rgb/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png new file mode 100644 index 0000000000000000000000000000000000000000..821d7655e12763c2d28156fe25876637fa8bad8f --- /dev/null +++ b/examples/rgb/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5d8c7b7a972d2ec492b81b8671da8520641134af7171028071a898fe94fa26f +size 1359435 diff --git a/models/__pycache__/get_model.cpython-36.pyc b/models/__pycache__/get_model.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63c9413ed94ca7c8aa27824f8737ac80954459fa Binary 
files /dev/null and b/models/__pycache__/get_model.cpython-36.pyc differ diff --git a/models/get_model.py b/models/get_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8f1ca05104b3e5eb6374f8d13fa2d9c4eefdd3ea --- /dev/null +++ b/models/get_model.py @@ -0,0 +1,95 @@ +from models.segmentation_models.cen import ChannelExchangingNetwork +from models.segmentation_models.deeplabv3p import DeepLabV3p_r101, DeepLabV3p_r18, DeepLabV3p_r50 +from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion +from models.segmentation_models.linearfusebothmask.segformer import LinearFusionBothMask +from models.segmentation_models.linearfusecons.segformer import LinearFusionConsistency +from models.segmentation_models.linearfusemaemaskedcons.segformer import LinearFusionMAEMaskedConsistency +from models.segmentation_models.linearfusemaskedcons.segformer import LinearFusionMaskedConsistency +from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch +from models.segmentation_models.linearfusesepdecodermaskedcons.segformer import LinearFusionSepDecoderMaskedConsistency +from models.segmentation_models.linearfusetokenmix.segformer import LinearFusionTokenMix +from models.segmentation_models.randomexchangecons.segformer import RandomExchangePredConsistency +from models.segmentation_models.randomfusion.segformer import WeTrRandomFusion +from models.segmentation_models.randomfusiondmlp.segformer import WeTrRandomFusionDMLP +from models.segmentation_models.refinenet import MyRefineNet +from models.segmentation_models.segformer.segformer import SegFormer +from models.segmentation_models.tokenfusion.segformer import WeTr +from models.segmentation_models.tokenfusionbothmask.segformer import TokenFusionBothMask +from models.segmentation_models.tokenfusionmaemaskedconsistency.segformer import TokenFusionMAEMaskedConsistency +from models.segmentation_models.tokenfusionmaskedconsistency.segformer import TokenFusionMaskedConsistency +from models.segmentation_models.tokenfusionmaskedconsistencymixbatch.segformer import TokenFusionMaskedConsistencyMixBatch +from models.segmentation_models.unifiedrepresentation.segformer import UnifiedRepresentationNetwork +from models.segmentation_models.unifiedrepresentationmoddrop.segformer import UnifiedRepresentationNetworkModDrop + +def get_model(args, **kwargs): + if args.seg_model == "dlv3p": + if args.base_model == "r18": + return DeepLabV3p_r18(args.num_classes, args) + elif args.base_model == "r50": + return DeepLabV3p_r50(args.num_classes, args) + elif args.base_model == "r101": + return DeepLabV3p_r101(args.num_classes, args) + else: + raise Exception(f"{args.base_model} not configured") + elif args.seg_model == 'refinenet': + if args.base_model == 'r18': + return MyRefineNet(num_layers = 18, num_classes = args.num_classes) + if args.base_model == 'r50': + return MyRefineNet(num_layers = 50, num_classes = args.num_classes) + if args.base_model == 'r101': + return MyRefineNet(num_layers = 101, num_classes = args.num_classes) + elif args.seg_model == 'cen': + if args.base_model == 'r18': + return ChannelExchangingNetwork(num_layers = 18, num_classes = args.num_classes, num_parallel = 2, l1_lambda = args.l1_lambda, bn_threshold = args.exchange_threshold) + if args.base_model == 'r50': + return ChannelExchangingNetwork(num_layers = 50, num_classes = args.num_classes, num_parallel = 2, l1_lambda = args.l1_lambda, bn_threshold = args.exchange_threshold) + if args.base_model == 'r101': + 
return ChannelExchangingNetwork(num_layers = 101, num_classes = args.num_classes, num_parallel = 2, l1_lambda = args.l1_lambda, bn_threshold = args.exchange_threshold) + elif args.seg_model == 'segformer': + return SegFormer(args.base_model, args, num_classes= args.num_classes) + elif args.seg_model == 'tokenfusion': + return WeTr(args.base_model, args, l1_lambda = args.l1_lambda, num_classes = args.num_classes) + elif args.seg_model == 'randomfusion': + return WeTrRandomFusion(args.base_model, args, num_classes = args.num_classes) + elif args.seg_model == 'randomfusiondmlp': + return WeTrRandomFusionDMLP(args.base_model, args, num_classes = args.num_classes) + elif args.seg_model == 'randomexchangepredconsistency': + return RandomExchangePredConsistency(args.base_model, args, cons_lambda = args.cons_lambda, num_classes = args.num_classes) + elif args.seg_model == 'linearfusion': + pretrained = True + if "pretrained_init" in args: + pretrained = args.pretrained_init + print("Using pretrained SegFormer? ", pretrained) + return WeTrLinearFusion(args.base_model, args, num_classes = args.num_classes, pretrained=pretrained) + elif args.seg_model == 'linearfusionconsistency': + return LinearFusionConsistency(args.base_model, args, cons_lambda = args.cons_lambda, num_classes = args.num_classes) + elif args.seg_model == 'linearfusionmaskedcons': + pretrained = True + if "pretrained_init" in args: + pretrained = args.pretrained_init + print("Using pretrained SegFormer? ", pretrained) + return LinearFusionMaskedConsistency(args.base_model, args, num_classes = args.num_classes, pretrained=pretrained) + elif args.seg_model == 'linearfusionmaskedconsmixbatch': + return LinearFusionMaskedConsistencyMixBatch(args.base_model, args, num_classes = args.num_classes) + elif args.seg_model == 'linearfusionsepdecodermaskedcons': + return LinearFusionSepDecoderMaskedConsistency(args.base_model, args, num_classes = args.num_classes) + elif args.seg_model == 'linearfusionmaemaskedcons': + return LinearFusionMAEMaskedConsistency(args.base_model, args, num_classes = args.num_classes) + elif args.seg_model == 'tokenfusionmaskedcons': + return TokenFusionMaskedConsistency(args.base_model, args, l1_lambda = args.l1_lambda, num_classes = args.num_classes) + elif args.seg_model == 'tokenfusionmaskedconsmixbatch': + return TokenFusionMaskedConsistencyMixBatch(args.base_model, args, l1_lambda = args.l1_lambda, num_classes = args.num_classes) + elif args.seg_model == 'tokenfusionbothmask': + return TokenFusionBothMask(args.base_model, args, l1_lambda = args.l1_lambda, num_classes = args.num_classes, **kwargs) + elif args.seg_model == "linearfusebothmask": + return LinearFusionBothMask(args.base_model, args, num_classes = args.num_classes) + elif args.seg_model == "linearfusiontokenmix": + return LinearFusionTokenMix(args.base_model, args, num_classes = args.num_classes, exchange_percent = args.exchange_percent) + elif args.seg_model == "tokenfusionmaemaskedcons": + return TokenFusionMAEMaskedConsistency(args.base_model, args, l1_lambda = args.l1_lambda, num_classes = args.num_classes) + elif args.seg_model == "unifiedrepresentationnetwork": + return UnifiedRepresentationNetwork(args.base_model, args, num_classes = args.num_classes) + elif args.seg_model == "unifiedrepresentationnetworkmoddrop": + return UnifiedRepresentationNetworkModDrop(args.base_model, args, num_classes = args.num_classes) + else: + raise Exception(f"{args.seg_model} not configured") \ No newline at end of file diff --git 
a/models/segmentation_models/__init__.py b/models/segmentation_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/segmentation_models/__pycache__/__init__.cpython-36.pyc b/models/segmentation_models/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f95b01c8be93b6907896fa20372106b396ff4a0 Binary files /dev/null and b/models/segmentation_models/__pycache__/__init__.cpython-36.pyc differ diff --git a/models/segmentation_models/__pycache__/__init__.cpython-38.pyc b/models/segmentation_models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf0e01075ed0a714dfcc52e0bdb542f8f76c3e2f Binary files /dev/null and b/models/segmentation_models/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/segmentation_models/__pycache__/cen.cpython-36.pyc b/models/segmentation_models/__pycache__/cen.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..232eecbf1cb6c473257b28e370621c7cc43756e9 Binary files /dev/null and b/models/segmentation_models/__pycache__/cen.cpython-36.pyc differ diff --git a/models/segmentation_models/__pycache__/deeplabv3p.cpython-36.pyc b/models/segmentation_models/__pycache__/deeplabv3p.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e932a5b0b8241eb5267a6716cfea81284fde7270 Binary files /dev/null and b/models/segmentation_models/__pycache__/deeplabv3p.cpython-36.pyc differ diff --git a/models/segmentation_models/__pycache__/refinenet.cpython-36.pyc b/models/segmentation_models/__pycache__/refinenet.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a77ff2748eeb235db3ef855bc2d9e33e9d125c59 Binary files /dev/null and b/models/segmentation_models/__pycache__/refinenet.cpython-36.pyc differ diff --git a/models/segmentation_models/cen.py b/models/segmentation_models/cen.py new file mode 100644 index 0000000000000000000000000000000000000000..73b112518c855f959deb7b0618238b64b312bc28 --- /dev/null +++ b/models/segmentation_models/cen.py @@ -0,0 +1,627 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import SyncBatchNorm as BatchNorm2d +import re +import os, sys +# from six import moves + +class Exchange(nn.Module): + def __init__(self): + super(Exchange, self).__init__() + + def forward(self, x, bn, bn_threshold): + bn1, bn2 = bn[0].weight.abs(), bn[1].weight.abs() + x1, x2 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x1[:, bn1 >= bn_threshold] = x[0][:, bn1 >= bn_threshold] + x1[:, bn1 < bn_threshold] = x[1][:, bn1 < bn_threshold] + x2[:, bn2 >= bn_threshold] = x[1][:, bn2 >= bn_threshold] + x2[:, bn2 < bn_threshold] = x[0][:, bn2 < bn_threshold] + return [x1, x2] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class BatchNorm2dParallel(nn.Module): + def __init__(self, num_features, num_parallel): + super(BatchNorm2dParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'bn_' + str(i), BatchNorm2d(num_features)) + + def forward(self, x_parallel): + return [getattr(self, 'bn_' + str(i))(x) for i, x in enumerate(x_parallel)] + +class ChannelExchangingNetwork(nn.Module): + def __init__(self, num_layers, num_classes, num_parallel, l1_lambda, 
bn_threshold): + super(ChannelExchangingNetwork, self).__init__() + self.model = refinenet(num_layers, num_classes, num_parallel, bn_threshold) + self.model = model_init(self.model, num_layers, num_parallel, imagenet=True) #Only initializes the encoder + self.l1_lambda = l1_lambda + + def get_slim_params(self): + slim_params = [] + for name, param in self.model.named_parameters(): + if param.requires_grad and name.endswith('weight') and 'bn2' in name: + if len(slim_params) % 2 == 0: + slim_params.append(param[:len(param) // 2]) + else: + slim_params.append(param[len(param) // 2:]) + return slim_params + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + b, c, h, w = data[0].shape #rgb is the 0th element + pred = self.model(data) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + if not self.training: + return pred + else: # training + if get_sup_loss: + l1_loss = self.l1_lambda * self.get_l1_loss(data[0].get_device()) + sup_loss = self.get_sup_loss(pred, gt, criterion) + return pred, sup_loss + l1_loss + else: + return pred + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + def get_params(self): + self.slim_params = self.get_slim_params() #Doing it here and not in __init__ because first the model should be put in appropriate device before accumulating slim_params + # enc_params, dec_params = [], [] + # for name, param in self.model.named_parameters(): + # if bool(re.match('.*conv1.*|.*bn1.*|.*layer.*', name)): + # enc_params.append(param) + # else: + # dec_params.append(param) + # return enc_params, dec_params + param_groups = [[], [], []] + for name, param in self.model.named_parameters(): + if "norm" in name: + param_groups[1].append(param) + elif bool(re.match('.*conv1.*|.*bn1.*|.*layer.*', name)): + param_groups[0].append(param) + else: + param_groups[2].append(param) + return param_groups + + def get_l1_loss(self, device): + L1_norm = sum([L1_penalty(m, device) for m in self.slim_params]) + if L1_norm > 0: + return L1_norm.to(device) + else: + return torch.tensor(0).to(device) + + +"""RefineNet-LightWeight + +RefineNet-LigthWeight PyTorch for non-commercial purposes + +Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au) +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +models_urls = { + '101_voc' : 'https://cloudstor.aarnet.edu.au/plus/s/Owmttk9bdPROwc6/download', + + '18_imagenet' : 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + '50_imagenet' : 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + '101_imagenet': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + '152_imagenet': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + +bottleneck_idx = 0 +save_idx = 0 + + +def conv3x3(in_planes, out_planes, stride=1, bias=False): + "3x3 convolution with padding" + return ModuleParallel(nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=stride, padding=1, bias=bias)) + + +def conv1x1(in_planes, out_planes, stride=1, bias=False): + "1x1 convolution" + return ModuleParallel(nn.Conv2d(in_planes, out_planes, kernel_size=1, + stride=stride, padding=0, bias=bias)) + + +class CRPBlock(nn.Module): + def __init__(self, in_planes, out_planes, num_stages, num_parallel): + super(CRPBlock, self).__init__() + for i in range(num_stages): + setattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'), + conv3x3(in_planes if (i == 0) else out_planes, out_planes)) + self.stride = 1 + self.num_stages = num_stages + self.num_parallel = num_parallel + self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=5, stride=1, padding=2)) + + def forward(self, x): + top = x + for i in range(self.num_stages): + top = self.maxpool(top) + top = getattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'))(top) + x = [x[l] + top[l] for l in range(self.num_parallel)] + return x + + +stages_suffixes = {0 : '_conv', 1 : '_conv_relu_varout_dimred'} + +class RCUBlock(nn.Module): + def __init__(self, in_planes, out_planes, num_blocks, num_stages, num_parallel): + super(RCUBlock, self).__init__() + for i in range(num_blocks): + for j in range(num_stages): + setattr(self, '{}{}'.format(i + 1, stages_suffixes[j]), + conv3x3(in_planes if (i == 0) and (j == 0) else out_planes, + out_planes, bias=(j == 0))) + self.stride = 1 + self.num_blocks = num_blocks + self.num_stages = num_stages + self.num_parallel = num_parallel + self.relu = ModuleParallel(nn.ReLU(inplace=True)) + + def forward(self, x): + for i in range(self.num_blocks): + residual = x + for j in range(self.num_stages): + x = self.relu(x) + x = getattr(self, '{}{}'.format(i + 1, stages_suffixes[j]))(x) + x = [x[l] + residual[l] for l in range(self.num_parallel)] + return x + + +class BasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, num_parallel, bn_threshold, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2dParallel(planes, num_parallel) + self.relu = ModuleParallel(nn.ReLU(inplace=True)) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2dParallel(planes, num_parallel) + self.num_parallel = num_parallel + self.downsample = downsample + self.stride = stride + + self.exchange = Exchange() + self.bn_threshold = bn_threshold + self.bn2_list = [] + for module 
in self.bn2.modules(): + if isinstance(module, BatchNorm2d): + self.bn2_list.append(module) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + if len(x) > 1: + out = self.exchange(out, self.bn2_list, self.bn_threshold) + + if self.downsample is not None: + residual = self.downsample(x) + + out = [out[l] + residual[l] for l in range(self.num_parallel)] + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + def __init__(self, inplanes, planes, num_parallel, bn_threshold, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = BatchNorm2dParallel(planes, num_parallel) + self.conv2 = conv3x3(planes, planes, stride=stride) + self.bn2 = BatchNorm2dParallel(planes, num_parallel) + self.conv3 = conv1x1(planes, planes * 4) + self.bn3 = BatchNorm2dParallel(planes * 4, num_parallel) + self.relu = ModuleParallel(nn.ReLU(inplace=True)) + self.num_parallel = num_parallel + self.downsample = downsample + self.stride = stride + + self.exchange = Exchange() + self.bn_threshold = bn_threshold + self.bn2_list = [] + for module in self.bn2.modules(): + if isinstance(module, BatchNorm2d): + self.bn2_list.append(module) + + def forward(self, x): + residual = x + out = x + + out = self.conv1(out) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + if len(x) > 1: + out = self.exchange(out, self.bn2_list, self.bn_threshold) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = [out[l] + residual[l] for l in range(self.num_parallel)] + out = self.relu(out) + + return out + + +class RefineNet(nn.Module): + def __init__(self, block, layers, num_parallel, num_classes=21, bn_threshold=2e-2): + self.inplanes = 64 + self.num_parallel = num_parallel + super(RefineNet, self).__init__() + self.dropout = ModuleParallel(nn.Dropout(p=0.5)) + self.conv1 = ModuleParallel(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)) + self.bn1 = BatchNorm2dParallel(64, num_parallel) + self.relu = ModuleParallel(nn.ReLU(inplace=True)) + self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + self.layer1 = self._make_layer(block, 64, layers[0], bn_threshold) + self.layer2 = self._make_layer(block, 128, layers[1], bn_threshold, stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], bn_threshold, stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], bn_threshold, stride=2) + + self.p_ims1d2_outl1_dimred = conv3x3(2048, 512) + self.adapt_stage1_b = self._make_rcu(512, 512, 2, 2) + self.mflow_conv_g1_pool = self._make_crp(512, 512, 4) + self.mflow_conv_g1_b = self._make_rcu(512, 512, 3, 2) + self.mflow_conv_g1_b3_joint_varout_dimred = conv3x3(512, 256) + + self.p_ims1d2_outl2_dimred = conv3x3(1024, 256) + self.adapt_stage2_b = self._make_rcu(256, 256, 2, 2) + self.adapt_stage2_b2_joint_varout_dimred = conv3x3(256, 256) + self.mflow_conv_g2_pool = self._make_crp(256, 256, 4) + self.mflow_conv_g2_b = self._make_rcu(256, 256, 3, 2) + self.mflow_conv_g2_b3_joint_varout_dimred = conv3x3(256, 256) + + self.p_ims1d2_outl3_dimred = conv3x3(512, 256) + self.adapt_stage3_b = self._make_rcu(256, 256, 2, 2) + self.adapt_stage3_b2_joint_varout_dimred = conv3x3(256, 256) + self.mflow_conv_g3_pool = self._make_crp(256, 256, 4) + self.mflow_conv_g3_b = 
self._make_rcu(256, 256, 3, 2) + self.mflow_conv_g3_b3_joint_varout_dimred = conv3x3(256, 256) + + self.p_ims1d2_outl4_dimred = conv3x3(256, 256) + self.adapt_stage4_b = self._make_rcu(256, 256, 2, 2) + self.adapt_stage4_b2_joint_varout_dimred = conv3x3(256, 256) + self.mflow_conv_g4_pool = self._make_crp(256, 256, 4) + self.mflow_conv_g4_b = self._make_rcu(256, 256, 3, 2) + + self.clf_conv = conv3x3(256, num_classes, bias=True) + self.alpha = nn.Parameter(torch.ones(num_parallel, requires_grad=True)) + # self.alpha = nn.Parameter(torch.ones([1, num_parallel, 157, 157], requires_grad=True)) + self.register_parameter('alpha', self.alpha) + + def _make_crp(self, in_planes, out_planes, num_stages): + layers = [CRPBlock(in_planes, out_planes, num_stages, self.num_parallel)] + return nn.Sequential(*layers) + + def _make_rcu(self, in_planes, out_planes, num_blocks, num_stages): + layers = [RCUBlock(in_planes, out_planes, num_blocks, num_stages, self.num_parallel)] + return nn.Sequential(*layers) + + def _make_layer(self, block, planes, num_blocks, bn_threshold, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride=stride), + BatchNorm2dParallel(planes * block.expansion, self.num_parallel) + ) + + layers = [] + layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + l1 = self.layer1(x) + l2 = self.layer2(l1) + l3 = self.layer3(l2) + l4 = self.layer4(l3) + + l4 = self.dropout(l4) + l3 = self.dropout(l3) + + x4 = self.p_ims1d2_outl1_dimred(l4) + x4 = self.adapt_stage1_b(x4) + x4 = self.relu(x4) + x4 = self.mflow_conv_g1_pool(x4) + x4 = self.mflow_conv_g1_b(x4) + x4 = self.mflow_conv_g1_b3_joint_varout_dimred(x4) + x4 = [nn.Upsample(size=l3[0].size()[2:], mode='bilinear', align_corners=True)(x4_) for x4_ in x4] + + x3 = self.p_ims1d2_outl2_dimred(l3) + x3 = self.adapt_stage2_b(x3) + x3 = self.adapt_stage2_b2_joint_varout_dimred(x3) + x3 = [x3[l] + x4[l] for l in range(self.num_parallel)] + x3 = self.relu(x3) + x3 = self.mflow_conv_g2_pool(x3) + x3 = self.mflow_conv_g2_b(x3) + x3 = self.mflow_conv_g2_b3_joint_varout_dimred(x3) + x3 = [nn.Upsample(size=l2[0].size()[2:], mode='bilinear', align_corners=True)(x3_) for x3_ in x3] + + x2 = self.p_ims1d2_outl3_dimred(l2) + x2 = self.adapt_stage3_b(x2) + x2 = self.adapt_stage3_b2_joint_varout_dimred(x2) + x2 = [x2[l] + x3[l] for l in range(self.num_parallel)] + x2 = self.relu(x2) + x2 = self.mflow_conv_g3_pool(x2) + x2 = self.mflow_conv_g3_b(x2) + x2 = self.mflow_conv_g3_b3_joint_varout_dimred(x2) + x2 = [nn.Upsample(size=l1[0].size()[2:], mode='bilinear', align_corners=True)(x2_) for x2_ in x2] + + x1 = self.p_ims1d2_outl4_dimred(l1) + x1 = self.adapt_stage4_b(x1) + x1 = self.adapt_stage4_b2_joint_varout_dimred(x1) + x1 = [x1[l] + x2[l] for l in range(self.num_parallel)] + x1 = self.relu(x1) + x1 = self.mflow_conv_g4_pool(x1) + x1 = self.mflow_conv_g4_b(x1) + x1 = self.dropout(x1) + + out = self.clf_conv(x1) + ens = 0 + alpha_soft = F.softmax(self.alpha, dim = 0) + for l in range(self.num_parallel): + ens += alpha_soft[l] * out[l].detach() + # alpha_soft = F.softmax(self.alpha, dim=1) + # for l in 
range(self.num_parallel): + # print(out[l].shape, l) + # ens += alpha_soft[:, l].unsqueeze(1) * out[l].detach() + out.append(ens) + # return out, alpha_soft + return out + + +class RefineNet_Resnet18(nn.Module): + def __init__(self, block, layers, num_parallel, num_classes=21, bn_threshold=2e-2): + self.inplanes = 64 + self.num_parallel = num_parallel + super(RefineNet_Resnet18, self).__init__() + self.dropout = ModuleParallel(nn.Dropout(p=0.5)) + self.conv1 = ModuleParallel(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)) + self.bn1 = BatchNorm2dParallel(64, num_parallel) + self.relu = ModuleParallel(nn.ReLU(inplace=True)) + self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + self.layer1 = self._make_layer(block, 64, layers[0], bn_threshold) + self.layer2 = self._make_layer(block, 128, layers[1], bn_threshold, stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], bn_threshold, stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], bn_threshold, stride=2) + + self.p_ims1d2_outl1_dimred = conv3x3(512, 256) + self.adapt_stage1_b = self._make_rcu(256, 256, 2, 2) + self.mflow_conv_g1_pool = self._make_crp(256, 256, 4) + self.mflow_conv_g1_b = self._make_rcu(256, 256, 3, 2) + self.mflow_conv_g1_b3_joint_varout_dimred = conv3x3(256, 64) + + self.p_ims1d2_outl2_dimred = conv3x3(256, 64) + self.adapt_stage2_b = self._make_rcu(64, 64, 2, 2) + self.adapt_stage2_b2_joint_varout_dimred = conv3x3(64, 64) + self.mflow_conv_g2_pool = self._make_crp(64, 64, 4) + self.mflow_conv_g2_b = self._make_rcu(64, 64, 3, 2) + self.mflow_conv_g2_b3_joint_varout_dimred = conv3x3(64, 64) + + self.p_ims1d2_outl3_dimred = conv3x3(128, 64) + self.adapt_stage3_b = self._make_rcu(64, 64, 2, 2) + self.adapt_stage3_b2_joint_varout_dimred = conv3x3(64, 64) + self.mflow_conv_g3_pool = self._make_crp(64, 64, 4) + self.mflow_conv_g3_b = self._make_rcu(64, 64, 3, 2) + self.mflow_conv_g3_b3_joint_varout_dimred = conv3x3(64, 64) + + self.p_ims1d2_outl4_dimred = conv3x3(64, 64) + self.adapt_stage4_b = self._make_rcu(64, 64, 2, 2) + self.adapt_stage4_b2_joint_varout_dimred = conv3x3(64, 64) + self.mflow_conv_g4_pool = self._make_crp(64, 64, 4) + self.mflow_conv_g4_b = self._make_rcu(64, 64, 3, 2) + + self.clf_conv = conv3x3(64, num_classes, bias=True) + self.alpha = nn.Parameter(torch.ones(num_parallel, requires_grad=True)) + # self.alpha = nn.Parameter(torch.ones([1, num_parallel, 157, 157], requires_grad=True)) + self.register_parameter('alpha', self.alpha) + + def _make_crp(self, in_planes, out_planes, num_stages): + layers = [CRPBlock(in_planes, out_planes, num_stages, self.num_parallel)] + return nn.Sequential(*layers) + + def _make_rcu(self, in_planes, out_planes, num_blocks, num_stages): + layers = [RCUBlock(in_planes, out_planes, num_blocks, num_stages, self.num_parallel)] + return nn.Sequential(*layers) + + def _make_layer(self, block, planes, num_blocks, bn_threshold, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride=stride), + BatchNorm2dParallel(planes * block.expansion, self.num_parallel) + ) + + layers = [] + layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = 
self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + l1 = self.layer1(x) + l2 = self.layer2(l1) + l3 = self.layer3(l2) + l4 = self.layer4(l3) + + l4 = self.dropout(l4) + l3 = self.dropout(l3) + + x4 = self.p_ims1d2_outl1_dimred(l4) + x4 = self.adapt_stage1_b(x4) + x4 = self.relu(x4) + x4 = self.mflow_conv_g1_pool(x4) + x4 = self.mflow_conv_g1_b(x4) + x4 = self.mflow_conv_g1_b3_joint_varout_dimred(x4) + x4 = [nn.Upsample(size=l3[0].size()[2:], mode='bilinear', align_corners=True)(x4_) for x4_ in x4] + + x3 = self.p_ims1d2_outl2_dimred(l3) + x3 = self.adapt_stage2_b(x3) + x3 = self.adapt_stage2_b2_joint_varout_dimred(x3) + x3 = [x3[l] + x4[l] for l in range(self.num_parallel)] + x3 = self.relu(x3) + x3 = self.mflow_conv_g2_pool(x3) + x3 = self.mflow_conv_g2_b(x3) + x3 = self.mflow_conv_g2_b3_joint_varout_dimred(x3) + x3 = [nn.Upsample(size=l2[0].size()[2:], mode='bilinear', align_corners=True)(x3_) for x3_ in x3] + + x2 = self.p_ims1d2_outl3_dimred(l2) + x2 = self.adapt_stage3_b(x2) + x2 = self.adapt_stage3_b2_joint_varout_dimred(x2) + x2 = [x2[l] + x3[l] for l in range(self.num_parallel)] + x2 = self.relu(x2) + x2 = self.mflow_conv_g3_pool(x2) + x2 = self.mflow_conv_g3_b(x2) + x2 = self.mflow_conv_g3_b3_joint_varout_dimred(x2) + x2 = [nn.Upsample(size=l1[0].size()[2:], mode='bilinear', align_corners=True)(x2_) for x2_ in x2] + + x1 = self.p_ims1d2_outl4_dimred(l1) + x1 = self.adapt_stage4_b(x1) + x1 = self.adapt_stage4_b2_joint_varout_dimred(x1) + x1 = [x1[l] + x2[l] for l in range(self.num_parallel)] + x1 = self.relu(x1) + x1 = self.mflow_conv_g4_pool(x1) + x1 = self.mflow_conv_g4_b(x1) + x1 = self.dropout(x1) + + out = self.clf_conv(x1) + ens = 0 + alpha_soft = F.softmax(self.alpha, dim = 0) + for l in range(self.num_parallel): + ens += alpha_soft[l] * out[l].detach() + # alpha_soft = F.softmax(self.alpha, dim=1) + # for l in range(self.num_parallel): + # print(out[l].shape, l) + # ens += alpha_soft[:, l].unsqueeze(1) * out[l].detach() + out.append(ens) + return out + + +def refinenet(num_layers, num_classes, num_parallel, bn_threshold): + refinenetClass = RefineNet + if int(num_layers) == 18: + layers = [2, 2, 2, 2] + block = BasicBlock + refinenetClass = RefineNet_Resnet18 + elif int(num_layers) == 50: + layers = [3, 4, 6, 3] + block = Bottleneck + elif int(num_layers) == 101: + layers = [3, 4, 23, 3] + block = Bottleneck + elif int(num_layers) == 152: + layers = [3, 8, 36, 3] + block = Bottleneck + else: + raise ValueError('invalid num_layers') + + model = refinenetClass(block, layers, num_parallel, num_classes, bn_threshold) + return model + +def maybe_download(model_name, model_url, model_dir=None, map_location=None): + if model_dir is None: + torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) + model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) + if not os.path.exists(model_dir): + os.makedirs(model_dir) + filename = '{}.pth.tar'.format(model_name) + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + # url = model_url + # sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + # moves.urllib.request.urlretrieve(url, cached_file) + raise Exception('cached file not found, maybe_download failed') + return torch.load(cached_file, map_location=map_location) + +def model_init(model, num_layers, num_parallel, imagenet=False): + if imagenet: + key = str(num_layers) + '_imagenet' + url = models_urls[key] + state_dict = maybe_download(key, url) + model_dict = 
expand_model_dict(model.state_dict(), state_dict, num_parallel) + model.load_state_dict(model_dict, strict=True) + return model + +def expand_model_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + bn = '.bn_%d' % i + replace = True if bn in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(bn, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict + +def L1_penalty(var, device): + return torch.abs(var).sum().to(device) \ No newline at end of file diff --git a/models/segmentation_models/deeplabv3p.py b/models/segmentation_models/deeplabv3p.py new file mode 100644 index 0000000000000000000000000000000000000000..6d825d775fcf20d7a0c3c8df623c79d414ff906c --- /dev/null +++ b/models/segmentation_models/deeplabv3p.py @@ -0,0 +1,413 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +# from torch.nn import BatchNorm2d +from torch.nn import SyncBatchNorm as BatchNorm2d +from functools import partial +import re +from models.base_models.resnet import resnet101, resnet18, resnet50 +from utils.seg_opr.conv_2_5d import Conv2_5D_depth, Conv2_5D_disp + +class DeepLabV3p_r18(nn.Module): + def __init__(self, num_classes, config): + super(DeepLabV3p_r18, self).__init__() + self.norm_layer = BatchNorm2d + self.backbone = resnet18(config.pretrained_model_r18, norm_layer=self.norm_layer, + bn_eps=config.bn_eps, + bn_momentum=config.bn_momentum, + deep_stem=False, stem_width=64) + self.dilate = 2 + for m in self.backbone.layer4.children(): + m.apply(partial(self._nostride_dilate, dilate=self.dilate)) + self.dilate *= 2 + + self.head = Head('r18', num_classes, self.norm_layer, config.bn_momentum) + self.business_layer = [] + self.business_layer.append(self.head) + + self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True) + self.business_layer.append(self.classifier) + init_weight(self.business_layer, nn.init.kaiming_normal_, + BatchNorm2d, config.bn_eps, config.bn_momentum, + mode='fan_in', nonlinearity='relu') + init_weight(self.classifier, nn.init.kaiming_normal_, + BatchNorm2d, config.bn_eps, config.bn_momentum, + mode='fan_in', nonlinearity='relu') + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + data = data[0] #rgb is the first element in the list + blocks = self.backbone(data) + v3plus_feature = self.head(blocks) #(b, c, h, w) + b, c, h, w = v3plus_feature.shape + + pred = self.classifier(v3plus_feature) + + b, c, h, w = data.shape + pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True) + if not self.training: #return pred for evaluation + return pred + else: + if get_sup_loss: + return pred, self.get_sup_loss(pred, gt, criterion) + else: + return pred + + def get_sup_loss(self, pred, gt, criterion): + pred = pred[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
+ return criterion(pred, gt) + + # @staticmethod + def _nostride_dilate(self, m, dilate): + if isinstance(m, nn.Conv2d): + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def get_params(self): + param_groups = [[], [], []] + enc, enc_no_decay = group_weight(self.backbone, self.norm_layer) + param_groups[0].extend(enc) + param_groups[1].extend(enc_no_decay) + dec, dec_no_decay = group_weight(self.head, self.norm_layer) + param_groups[2].extend(dec) + param_groups[1].extend(dec_no_decay) + classifier, classifier_no_decay = group_weight(self.classifier, self.norm_layer) + param_groups[2].extend(classifier) + param_groups[1].extend(classifier_no_decay) + return param_groups + +class DeepLabV3p_r50(nn.Module): + def __init__(self, num_classes, config): + super(DeepLabV3p_r50, self).__init__() + self.norm_layer = BatchNorm2d + self.backbone = resnet50(config.pretrained_model_r50, norm_layer=self.norm_layer, + bn_eps=config.bn_eps, + bn_momentum=config.bn_momentum, + deep_stem=True, stem_width=64) + self.dilate = 2 + for m in self.backbone.layer4.children(): + m.apply(partial(self._nostride_dilate, dilate=self.dilate)) + self.dilate *= 2 + + self.head = Head('r50', num_classes, self.norm_layer, config.bn_momentum) + self.business_layer = [] + self.business_layer.append(self.head) + + self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True) + self.business_layer.append(self.classifier) + init_weight(self.business_layer, nn.init.kaiming_normal_, + BatchNorm2d, config.bn_eps, config.bn_momentum, + mode='fan_in', nonlinearity='relu') + init_weight(self.classifier, nn.init.kaiming_normal_, + BatchNorm2d, config.bn_eps, config.bn_momentum, + mode='fan_in', nonlinearity='relu') + + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + data = data[0] #rgb is the first element in the list + blocks = self.backbone(data) + v3plus_feature = self.head(blocks) #(b, c, h, w) + b, c, h, w = v3plus_feature.shape + + pred = self.classifier(v3plus_feature) + + b, c, h, w = data.shape + pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True) + if not self.training: #return pred for evaluation + return pred + else: + if get_sup_loss: + return pred, self.get_sup_loss(pred, gt, criterion) + else: + return pred + + def get_sup_loss(self, pred, gt, criterion): + pred = pred[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
+ return criterion(pred, gt) + + def get_params(self): + param_groups = [[], [], []] + enc, enc_no_decay = group_weight(self.backbone, self.norm_layer) + param_groups[0].extend(enc) + param_groups[1].extend(enc_no_decay) + dec, dec_no_decay = group_weight(self.head, self.norm_layer) + param_groups[2].extend(dec) + param_groups[1].extend(dec_no_decay) + classifier, classifier_no_decay = group_weight(self.classifier, self.norm_layer) + param_groups[2].extend(classifier) + param_groups[1].extend(classifier_no_decay) + return param_groups + + # @staticmethod + def _nostride_dilate(self, m, dilate): + if isinstance(m, nn.Conv2d): + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + +class DeepLabV3p_r101(nn.Module): + def __init__(self, num_classes, config): + super(DeepLabV3p_r101, self).__init__() + self.norm_layer = BatchNorm2d + self.backbone = resnet101(config.pretrained_model_r101, norm_layer=self.norm_layer, + bn_eps=config.bn_eps, + bn_momentum=config.bn_momentum, + deep_stem=True, stem_width=64) + self.dilate = 2 + for m in self.backbone.layer4.children(): + m.apply(partial(self._nostride_dilate, dilate=self.dilate)) + self.dilate *= 2 + + self.head = Head('r50', num_classes, self.norm_layer, config.bn_momentum) + self.business_layer = [] + self.business_layer.append(self.head) + + self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True) + self.business_layer.append(self.classifier) + init_weight(self.business_layer, nn.init.kaiming_normal_, + BatchNorm2d, config.bn_eps, config.bn_momentum, + mode='fan_in', nonlinearity='relu') + init_weight(self.classifier, nn.init.kaiming_normal_, + BatchNorm2d, config.bn_eps, config.bn_momentum, + mode='fan_in', nonlinearity='relu') + + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + data = data[0] #rgb is the first element in the list + blocks = self.backbone(data) + v3plus_feature = self.head(blocks) #(b, c, h, w) + b, c, h, w = v3plus_feature.shape + + pred = self.classifier(v3plus_feature) + + b, c, h, w = data.shape + pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True) + if not self.training: #return pred for evaluation + return pred + else: + if get_sup_loss: + return pred, self.get_sup_loss(pred, gt, criterion) + else: + return pred + + def get_sup_loss(self, pred, gt, criterion): + pred = pred[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
+ return criterion(pred, gt) + + def get_params(self): + param_groups = [[], [], []] + enc, enc_no_decay = group_weight(self.backbone, self.norm_layer) + param_groups[0].extend(enc) + param_groups[1].extend(enc_no_decay) + dec, dec_no_decay = group_weight(self.head, self.norm_layer) + param_groups[2].extend(dec) + param_groups[1].extend(dec_no_decay) + classifier, classifier_no_decay = group_weight(self.classifier, self.norm_layer) + param_groups[2].extend(classifier) + param_groups[1].extend(classifier_no_decay) + return param_groups + + # @staticmethod + def _nostride_dilate(self, m, dilate): + if isinstance(m, nn.Conv2d): + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + +class ASPP(nn.Module): + def __init__(self, + in_channels, + out_channels, + dilation_rates=(12, 24, 36), + hidden_channels=256, + norm_act=nn.BatchNorm2d, + pooling_size=None): + super(ASPP, self).__init__() + self.pooling_size = pooling_size + + self.map_convs = nn.ModuleList([ + nn.Conv2d(in_channels, hidden_channels, 1, bias=False), + nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[0], + padding=dilation_rates[0]), + nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[1], + padding=dilation_rates[1]), + nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[2], + padding=dilation_rates[2]) + ]) + self.map_bn = norm_act(hidden_channels * 4) + + self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False) + self.global_pooling_bn = norm_act(hidden_channels) + + self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False) + self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False) + self.red_bn = norm_act(out_channels) + + self.leak_relu = nn.LeakyReLU() + + def forward(self, x): + # Map convolutions + out = torch.cat([m(x) for m in self.map_convs], dim=1) + out = self.map_bn(out) + out = self.leak_relu(out) # add activation layer + out = self.red_conv(out) + + # Global pooling + pool = self._global_pooling(x) + pool = self.global_pooling_conv(pool) + pool = self.global_pooling_bn(pool) + + pool = self.leak_relu(pool) # add activation layer + + pool = self.pool_red_conv(pool) + if self.training or self.pooling_size is None: + pool = pool.repeat(1, 1, x.size(2), x.size(3)) + + out += pool + out = self.red_bn(out) + out = self.leak_relu(out) # add activation layer + return out + + def _global_pooling(self, x): + pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1) + pool = pool.view(x.size(0), x.size(1), 1, 1) + return pool + + +class Head(nn.Module): + def __init__(self, base_model, classify_classes, norm_act=nn.BatchNorm2d, bn_momentum=0.0003): + super(Head, self).__init__() + + self.classify_classes = classify_classes + if base_model == 'r18': + self.aspp = ASPP(512, 256, [6, 12, 18], norm_act=norm_act) + + self.reduce = nn.Sequential( + nn.Conv2d(64, 48, 1, bias=False), + norm_act(48, momentum=bn_momentum), + nn.ReLU(), + ) + elif base_model == 'r50': + self.aspp = ASPP(2048, 256, [6, 12, 18], norm_act=norm_act) + self.reduce = nn.Sequential( + nn.Conv2d(256, 48, 1, bias=False), + norm_act(48, momentum=bn_momentum), + nn.ReLU(), + ) + else: + raise Exception(f"Head not implemented for {base_model}") + + + + self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, 
padding=1, bias=False), + norm_act(256, momentum=bn_momentum), + nn.ReLU(), + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), + norm_act(256, momentum=bn_momentum), + nn.ReLU(), + ) + + def forward(self, f_list): + f = f_list[-1] + f = self.aspp(f) + + low_level_features = f_list[0] + low_h, low_w = low_level_features.size(2), low_level_features.size(3) + low_level_features = self.reduce(low_level_features) + + f = F.interpolate(f, size=(low_h, low_w), mode='bilinear', align_corners=True) + f = torch.cat((f, low_level_features), dim=1) + f = self.last_conv(f) + + return f + + +def group_weight(module, norm_layer): + group_decay = [] + group_no_decay = [] + for m in module.modules(): + if isinstance(m, nn.Linear): + group_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + group_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, Conv2_5D_depth): + group_decay.append(m.weight_0) + group_decay.append(m.weight_1) + group_decay.append(m.weight_2) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, Conv2_5D_disp): + group_decay.append(m.weight_0) + group_decay.append(m.weight_1) + group_decay.append(m.weight_2) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, norm_layer) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) \ + or isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm): + if m.weight is not None: + group_no_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, nn.Parameter): + group_decay.append(m) + elif isinstance(m, nn.Embedding): + group_decay.append(m) + assert len(list(module.parameters())) == len(group_decay) + len( + group_no_decay) + return group_decay, group_no_decay + +def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, + **kwargs): + for name, m in feature.named_modules(): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + conv_init(m.weight, **kwargs) + elif isinstance(m, Conv2_5D_depth): + conv_init(m.weight_0, **kwargs) + conv_init(m.weight_1, **kwargs) + conv_init(m.weight_2, **kwargs) + elif isinstance(m, Conv2_5D_disp): + conv_init(m.weight_0, **kwargs) + conv_init(m.weight_1, **kwargs) + conv_init(m.weight_2, **kwargs) + elif isinstance(m, norm_layer): + m.eps = bn_eps + m.momentum = bn_momentum + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, + **kwargs): + if isinstance(module_list, list): + for feature in module_list: + __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, + **kwargs) + else: + __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, + **kwargs) \ No newline at end of file diff --git a/models/segmentation_models/linearfuse/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/linearfuse/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03f5fb2927e65eb3e75d7d829875e05a21331be0 Binary files /dev/null and b/models/segmentation_models/linearfuse/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfuse/__pycache__/mix_transformer.cpython-38.pyc b/models/segmentation_models/linearfuse/__pycache__/mix_transformer.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..bc61ac3166df06942a51c1851a1a32e5603e049e Binary files /dev/null and b/models/segmentation_models/linearfuse/__pycache__/mix_transformer.cpython-38.pyc differ diff --git a/models/segmentation_models/linearfuse/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfuse/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/linearfuse/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfuse/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/linearfuse/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..836aca7e47cfad2023ab443882ea995d0b2c1da7 Binary files /dev/null and b/models/segmentation_models/linearfuse/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfuse/__pycache__/modules.cpython-38.pyc b/models/segmentation_models/linearfuse/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53b3e13fccbf62784a17a55c4c7981733099dbbd Binary files /dev/null and b/models/segmentation_models/linearfuse/__pycache__/modules.cpython-38.pyc differ diff --git a/models/segmentation_models/linearfuse/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/linearfuse/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12af665e827a19033db3737ec6cbd25abb9841e6 Binary files /dev/null and b/models/segmentation_models/linearfuse/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfuse/__pycache__/segformer.cpython-38.pyc b/models/segmentation_models/linearfuse/__pycache__/segformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..500230fd7903e3e1dbb65350caf856bc542b4dcf Binary files /dev/null and b/models/segmentation_models/linearfuse/__pycache__/segformer.cpython-38.pyc differ diff --git a/models/segmentation_models/linearfuse/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfuse/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/linearfuse/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfuse/mix_transformer.py b/models/segmentation_models/linearfuse/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6913d931b1b4f286a9d5d61a2013144503e6a81d --- /dev/null +++ b/models/segmentation_models/linearfuse/mix_transformer.py @@ -0,0 +1,474 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
+# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, LinearFuse + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, ratio, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
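+        # Both modality streams (num_parallel = 2; RGB first, plus a second modality) flow
+        # through this module as a Python list of (B, N, C) token tensors. ModuleParallel
+        # applies the wrapped layer to each list element, e.g. (illustrative names only):
+        #   x = [tokens_rgb, tokens_mod2]           # two (B, N, C) tensors
+        #   q = ModuleParallel(nn.Linear(C, C))(x)  # -> list of two (B, N, C) tensors
+        # so a single set of weights is shared across the two streams for each wrapped layer.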
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = LinearFuse(ratio) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + # x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x) + + return x + + +# class PredictorLG(nn.Module): +# """ Image to Patch Embedding from DydamicVit +# """ +# def __init__(self, embed_dim=384): +# super().__init__() +# self.score_nets = nn.ModuleList([nn.Sequential( +# nn.LayerNorm(embed_dim), +# nn.Linear(embed_dim, embed_dim), +# nn.GELU(), +# nn.Linear(embed_dim, embed_dim // 2), +# nn.GELU(), +# nn.Linear(embed_dim // 2, embed_dim // 4), +# nn.GELU(), +# nn.Linear(embed_dim // 4, 2), +# nn.LogSoftmax(dim=-1) +# ) for _ in range(num_parallel)]) + +# def forward(self, x): +# x = [self.score_nets[i](x[i]) for i in range(num_parallel)] +# return x + + +class Block(nn.Module): + + def __init__(self, dim, ratio, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, ratio, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, 
stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + # self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], ratio = ratio, num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], ratio = ratio, num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], ratio = ratio, num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], ratio = ratio, num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 
'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x[0].shape[0] + outs0, outs1 = [], [] + # masks = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + # score = self.score_predictor[0](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + # score = self.score_predictor[1](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + # score = self.score_predictor[2](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + # score = self.score_predictor[3](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + return [outs0, outs1] + + def forward(self, x): + x = self.forward_features(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, ratio, **kwargs): + super(mit_b0, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, ratio, **kwargs): + super(mit_b1, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, ratio, **kwargs): + super(mit_b2, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def 
__init__(self, ratio, **kwargs): + super(mit_b3, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, ratio, **kwargs): + super(mit_b4, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, ratio, **kwargs): + super(mit_b5, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) diff --git a/models/segmentation_models/linearfuse/modules.py b/models/segmentation_models/linearfuse/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..548222552ce02a9e31e4d347d96ec425e5012f6f --- /dev/null +++ b/models/segmentation_models/linearfuse/modules.py @@ -0,0 +1,44 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. +# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class LinearFuse(nn.Module): + def __init__(self, ratio): + super(LinearFuse, self).__init__() + self.ratio = ratio + + def forward(self, x): + # x: [B, N, C], mask: [B, N] + # mask_threshold = torch.rand(1)[0] * (high - low) + low + # mask = [torch.rand(x[0].shape[0], x[0].shape[1]), torch.rand(x[0].shape[0], x[0].shape[1])] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0 = self.ratio*x[0] + (1-self.ratio)*x[1] + x1 = (1-self.ratio)*x[0] + self.ratio*x[1] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/linearfuse/segformer.py b/models/segmentation_models/linearfuse/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..956889666cc16d6656756c971c0770c5007734e0 --- /dev/null +++ b/models/segmentation_models/linearfuse/segformer.py @@ -0,0 +1,170 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class WeTrLinearFusion(nn.Module): + def __init__(self, backbone, config, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + self.encoder = getattr(mix_transformer, backbone)(ratio = config.ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.alpha = 
nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + self.ratio = config.ratio + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + b, c, h, w = data[0].shape #rgb is the 0th element + x = self.encoder(data) + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + + if not self.training: + return pred + else: # training + if get_sup_loss: + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss + else: + return pred + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. + # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict diff --git a/models/segmentation_models/linearfusebothmask/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/linearfusebothmask/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..555159be70c4358d1b054e761904159a30854c61 Binary files /dev/null and b/models/segmentation_models/linearfusebothmask/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusebothmask/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusebothmask/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/linearfusebothmask/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusebothmask/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/linearfusebothmask/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4097dcb50a61fc343725343d4f431be1c485eba Binary files /dev/null and 
b/models/segmentation_models/linearfusebothmask/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusebothmask/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/linearfusebothmask/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb37b855082b31cac02f2022409ffc364806ff38 Binary files /dev/null and b/models/segmentation_models/linearfusebothmask/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusebothmask/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusebothmask/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/linearfusebothmask/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusebothmask/mix_transformer.py b/models/segmentation_models/linearfusebothmask/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0354427ffd2c14b1dd3ae1e1cc8b2eb25ac60f42 --- /dev/null +++ b/models/segmentation_models/linearfusebothmask/mix_transformer.py @@ -0,0 +1,537 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, LinearFuse + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, ratio, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
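+        # After projection this block calls self.exchange, a LinearFuse(ratio) that blends the
+        # two streams with a fixed scalar r (the ratio passed down from the model config), per
+        # modules.LinearFuse:
+        #   out0 = r * x0 + (1 - r) * x1
+        #   out1 = (1 - r) * x0 + r * x1
+        # so r = 1.0 disables cross-modal mixing and r = 0.5 makes both outputs the average.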
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = LinearFuse(ratio) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + # x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x) + + return x + + +# class PredictorLG(nn.Module): +# """ Image to Patch Embedding from DydamicVit +# """ +# def __init__(self, embed_dim=384): +# super().__init__() +# self.score_nets = nn.ModuleList([nn.Sequential( +# nn.LayerNorm(embed_dim), +# nn.Linear(embed_dim, embed_dim), +# nn.GELU(), +# nn.Linear(embed_dim, embed_dim // 2), +# nn.GELU(), +# nn.Linear(embed_dim // 2, embed_dim // 4), +# nn.GELU(), +# nn.Linear(embed_dim // 4, 2), +# nn.LogSoftmax(dim=-1) +# ) for _ in range(num_parallel)]) + +# def forward(self, x): +# x = [self.score_nets[i](x[i]) for i in range(num_parallel)] +# return x + + +class Block(nn.Module): + + def __init__(self, dim, ratio, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, ratio, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbedAndMask(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, masking_ratio = 0.25, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + self.masking_ratio = masking_ratio + self.embed_dim = embed_dim + self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim), requires_grad = True)#None #When training in the SupOnly loop, unused params raise error in DDP. 
Hence instantiating mask_token only when masked training begins + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def mask_with_mean(self, x): + # print(x.shape) + avg = torch.mean(x, dim = 1) + avg = avg.clone().detach().requires_grad_(False) + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - self.masking_ratio)) + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + + # keep the first subset + ids_mask = ids_shuffle[:, len_keep:] + # for i in range(N): + # x[i][ids_mask[i]] = avg[i] + # return x + avg = avg.unsqueeze(dim = 1) + avg = avg.repeat(1, L, 1) + mask = ids_mask.unsqueeze(dim = 2) + mask = mask.repeat(1, 1, D) + masked = torch.scatter(x, dim = 1, index = mask, src = avg) # self[i] [index[i][j][k]] [k] = src[i][j][k] # if dim == 1 + # self.printcheck(x[0], masked[0], avg[0]) + return masked + + def mask_with_learnt_mask(self, x): + # if self.mask_token is None: #When training in the SupOnly loop, unused params raise error in DDP. Hence instantiating mask_token only when masked training begins + # self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim, device=x.device), requires_grad = True) + # print(self.mask_token[:10], x.device, "token") + N, L, D = x.shape # batch, length, dim + indicies = torch.FloatTensor(N, L).uniform_() <= self.masking_ratio + x[indicies] = self.mask_token + return x + + + def printcheck(self, x, masked, avg): + L, D = x.shape + same = 0 + avgsame = 0 + for i in range(L): + if (x[i] == masked[i]).all(): + same += 1 + # else: + # print(i, x[i]) + if (masked[i].data == avg[i].data).all(): + avgsame += 1 + print(same, avgsame) + return + + def forward(self, x, mask = False, range_batches_to_mask = None): + sum_mask = torch.sum(self.mask_token) + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + if mask: + assert range_batches_to_mask is not None, "expected the range of batches to mask to not mask the labeled images" + # x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_mean(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + x[0][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(x[0][range_batches_to_mask[0]:range_batches_to_mask[1]]) + x[1][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(x[1][range_batches_to_mask[0]:range_batches_to_mask[1]]) + # masking_branch = 1 + # x[masking_branch] = self.mask_with_mean(x[masking_branch]) + # x[masking_branch] = self.mask_with_learnt_mask(x[masking_branch]) + else: + x[0] = x[0] + 0*sum_mask #So that when training with SupOnly (and not using any masking), DDP doesn't raise an error that you have unused parameters. 
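+        # Masking summary (as implemented above): with mask=True, mask_with_learnt_mask draws an
+        # independent Bernoulli(masking_ratio) decision per token and overwrites the selected
+        # (B, N, C) entries of BOTH streams with the shared learnable mask_token, but only for
+        # the batch slice range_batches_to_mask, so labeled images stay unmasked. As used by
+        # forward_features below, the later patch embeds receive mask=False, so masking happens
+        # once, at the stage-1 embedding; e.g. masking_ratio = 0.25 replaces roughly a quarter
+        # of the stage-1 tokens of each unlabeled image before block1.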
+ return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, ratio, masking_ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + # self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], ratio = ratio, num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], ratio = ratio, num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], ratio = ratio, num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], ratio = ratio, num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, 
math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, mask, range_batches_to_mask): + B = x[0].shape[0] + outs0, outs1 = [], [] + # masks = [] + # stage 1 + x, H, W = self.patch_embed1(x, mask = mask, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block1): + # score = self.score_predictor[0](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x, mask = False) + # x, H, W = self.patch_embed2(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block2): + # score = self.score_predictor[1](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x, mask = False) + # x, H, W = self.patch_embed3(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block3): + # score = self.score_predictor[2](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x, mask = False) + # x, H, W = self.patch_embed4(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block4): + # score = self.score_predictor[3](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + return [outs0, outs1] + + def forward(self, x, 
mask, range_batches_to_mask): + x = self.forward_features(x, mask, range_batches_to_mask) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b0, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b1, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b2, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b3, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b4, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b5, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) \ No newline at end of file diff --git a/models/segmentation_models/linearfusebothmask/modules.py b/models/segmentation_models/linearfusebothmask/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..548222552ce02a9e31e4d347d96ec425e5012f6f --- /dev/null +++ b/models/segmentation_models/linearfusebothmask/modules.py @@ -0,0 +1,44 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. 
+# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class LinearFuse(nn.Module): + def __init__(self, ratio): + super(LinearFuse, self).__init__() + self.ratio = ratio + + def forward(self, x): + # x: [B, N, C], mask: [B, N] + # mask_threshold = torch.rand(1)[0] * (high - low) + low + # mask = [torch.rand(x[0].shape[0], x[0].shape[1]), torch.rand(x[0].shape[0], x[0].shape[1])] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0 = self.ratio*x[0] + (1-self.ratio)*x[1] + x1 = (1-self.ratio)*x[0] + self.ratio*x[1] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/linearfusebothmask/segformer.py b/models/segmentation_models/linearfusebothmask/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..be59ec5afe1b1cdc862edea579b172e1b6f77888 --- /dev/null +++ b/models/segmentation_models/linearfusebothmask/segformer.py @@ -0,0 +1,172 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class LinearFusionBothMask(nn.Module): + def __init__(self, backbone, config, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)(ratio = config.ratio, masking_ratio = config.masking_ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + print("Load pretrained weights from " + config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + else: + print("Train from scratch") + self.decoder = 
SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + self.ratio = config.ratio + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None, mask = False, range_batches_to_mask = None): + b, c, h, w = data[0].shape #rgb is the 0th element + x = self.encoder(data, mask = mask, range_batches_to_mask = range_batches_to_mask) + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + if not self.training: + return pred + else: # training + if get_sup_loss: + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss + else: + return pred + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
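+ # Illustrative example of the slicing above: with a batch of 8 images where only
+ # the first 2 carry labels (gt.shape[0] == 2), p[:2] keeps just those two
+ # prediction maps, so criterion never sees the unlabeled tail of the batch.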
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict diff --git a/models/segmentation_models/linearfusecons/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/linearfusecons/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adea76e80d3769d6efcc8721c0cc0299eddc991b Binary files /dev/null and b/models/segmentation_models/linearfusecons/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusecons/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusecons/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/linearfusecons/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusecons/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/linearfusecons/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02d52bb81d6369c2287d5af7c2df17c172aef8c2 Binary files /dev/null and b/models/segmentation_models/linearfusecons/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusecons/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/linearfusecons/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db19460cfc7c4985d4d41b2febee2d3222a20616 Binary files /dev/null and b/models/segmentation_models/linearfusecons/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusecons/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusecons/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/linearfusecons/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusecons/mix_transformer.py b/models/segmentation_models/linearfusecons/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6913d931b1b4f286a9d5d61a2013144503e6a81d --- /dev/null +++ b/models/segmentation_models/linearfusecons/mix_transformer.py @@ -0,0 +1,474 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
+# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, LinearFuse + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, ratio, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
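+ # Both branches (e.g. RGB plus a second modality) share weights here: q, kv and
+ # proj below are wrapped in ModuleParallel, which applies the same layer to every
+ # element of the input list, and the LinearFuse exchange at the end of forward()
+ # blends the two branches with a fixed mixing ratio.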
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = LinearFuse(ratio) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + # x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x) + + return x + + +# class PredictorLG(nn.Module): +# """ Image to Patch Embedding from DydamicVit +# """ +# def __init__(self, embed_dim=384): +# super().__init__() +# self.score_nets = nn.ModuleList([nn.Sequential( +# nn.LayerNorm(embed_dim), +# nn.Linear(embed_dim, embed_dim), +# nn.GELU(), +# nn.Linear(embed_dim, embed_dim // 2), +# nn.GELU(), +# nn.Linear(embed_dim // 2, embed_dim // 4), +# nn.GELU(), +# nn.Linear(embed_dim // 4, 2), +# nn.LogSoftmax(dim=-1) +# ) for _ in range(num_parallel)]) + +# def forward(self, x): +# x = [self.score_nets[i](x[i]) for i in range(num_parallel)] +# return x + + +class Block(nn.Module): + + def __init__(self, dim, ratio, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, ratio, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, 
stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + # self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], ratio = ratio, num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], ratio = ratio, num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], ratio = ratio, num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], ratio = ratio, num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 
'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x[0].shape[0] + outs0, outs1 = [], [] + # masks = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + # score = self.score_predictor[0](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + # score = self.score_predictor[1](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + # score = self.score_predictor[2](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + # score = self.score_predictor[3](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + return [outs0, outs1] + + def forward(self, x): + x = self.forward_features(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, ratio, **kwargs): + super(mit_b0, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, ratio, **kwargs): + super(mit_b1, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, ratio, **kwargs): + super(mit_b2, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def 
__init__(self, ratio, **kwargs): + super(mit_b3, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, ratio, **kwargs): + super(mit_b4, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, ratio, **kwargs): + super(mit_b5, self).__init__(ratio = ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) diff --git a/models/segmentation_models/linearfusecons/modules.py b/models/segmentation_models/linearfusecons/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..548222552ce02a9e31e4d347d96ec425e5012f6f --- /dev/null +++ b/models/segmentation_models/linearfusecons/modules.py @@ -0,0 +1,44 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. +# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class LinearFuse(nn.Module): + def __init__(self, ratio): + super(LinearFuse, self).__init__() + self.ratio = ratio + + def forward(self, x): + # x: [B, N, C], mask: [B, N] + # mask_threshold = torch.rand(1)[0] * (high - low) + low + # mask = [torch.rand(x[0].shape[0], x[0].shape[1]), torch.rand(x[0].shape[0], x[0].shape[1])] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0 = self.ratio*x[0] + (1-self.ratio)*x[1] + x1 = (1-self.ratio)*x[0] + self.ratio*x[1] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/linearfusecons/segformer.py b/models/segmentation_models/linearfusecons/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fdc34f8861437a5c8523691cbb160bcfb5003bb3 --- /dev/null +++ b/models/segmentation_models/linearfusecons/segformer.py @@ -0,0 +1,187 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from utils.seg_opr.loss_func import JSD +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class LinearFusionConsistency(nn.Module): + def __init__(self, backbone, config, cons_lambda, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + self.cons_lambda = cons_lambda + self.cons_loss = JSD() + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)(ratio = config.ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + 
embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + self.ratio = config.ratio + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + b, c, h, w = data[0].shape #rgb is the 0th element + x = self.encoder(data) + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + largepred = [] + for i in range(len(pred)): + largepred.append(F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True)) + + cons_loss = self.cons_lambda * self.get_cons_loss(pred[0], pred[1]) #Not taking consistency with ensemble + if not self.training: + return pred + else: # training + if get_sup_loss: + # l1 = self.get_l1_loss(masks, data[0].get_device()) / b + # l1_loss = self.l1_lambda * l1 + cons_loss = self.cons_lambda * self.get_cons_loss(pred[0], pred[1]) #Not taking consistency with ensemble + sup_loss = self.get_sup_loss(largepred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return largepred, [sup_loss, cons_loss] + else: + return largepred + + def get_cons_loss(self, b1, b2): + #b1 and b2 are [batchsize x num_classes x p x p] where p depends on encoder + assert b1.shape[1] == self.num_classes + b1 = b1.reshape(-1, self.num_classes) + b2 = b2.reshape(-1, self.num_classes) #JSD loss expects batch_size x SoftMaxDimension + return self.cons_loss(b1, b2) + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
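+ # pred here holds three maps: one per branch plus the alpha-weighted ensemble
+ # appended in forward(), so the value returned below is the criterion averaged
+ # over all three heads, computed on the labeled samples only.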
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict diff --git a/models/segmentation_models/linearfusemaemaskedcons/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/linearfusemaemaskedcons/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..322884852eb34df56f49954a982ddf457b6a6027 Binary files /dev/null and b/models/segmentation_models/linearfusemaemaskedcons/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaemaskedcons/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusemaemaskedcons/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/linearfusemaemaskedcons/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaemaskedcons/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/linearfusemaemaskedcons/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c7c729cf9fae7f3a46e25b22ddff23e3c6276fa Binary files /dev/null and b/models/segmentation_models/linearfusemaemaskedcons/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaemaskedcons/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/linearfusemaemaskedcons/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..def58d01649a2edae243ef767e88add083910990 Binary files /dev/null and b/models/segmentation_models/linearfusemaemaskedcons/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaemaskedcons/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusemaemaskedcons/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/linearfusemaemaskedcons/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaemaskedcons/mix_transformer.py b/models/segmentation_models/linearfusemaemaskedcons/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a21cdddb3e946f4a598f4906f99604dfdb599592 --- /dev/null +++ b/models/segmentation_models/linearfusemaemaskedcons/mix_transformer.py @@ -0,0 +1,537 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
+# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, LinearFuse + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, ratio, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
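+ # When sr_ratio > 1, keys and values are computed from a feature map that
+ # self.sr (a conv with kernel = stride = sr_ratio) downsamples by sr_ratio in
+ # each spatial dimension, shrinking the attention matrix from N x N to roughly
+ # N x (N / sr_ratio**2) on the large early-stage token grids.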
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = LinearFuse(ratio) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + # x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x) + + return x + + +# class PredictorLG(nn.Module): +# """ Image to Patch Embedding from DydamicVit +# """ +# def __init__(self, embed_dim=384): +# super().__init__() +# self.score_nets = nn.ModuleList([nn.Sequential( +# nn.LayerNorm(embed_dim), +# nn.Linear(embed_dim, embed_dim), +# nn.GELU(), +# nn.Linear(embed_dim, embed_dim // 2), +# nn.GELU(), +# nn.Linear(embed_dim // 2, embed_dim // 4), +# nn.GELU(), +# nn.Linear(embed_dim // 4, 2), +# nn.LogSoftmax(dim=-1) +# ) for _ in range(num_parallel)]) + +# def forward(self, x): +# x = [self.score_nets[i](x[i]) for i in range(num_parallel)] +# return x + + +class Block(nn.Module): + + def __init__(self, dim, ratio, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, ratio, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbedAndMask(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, masking_ratio = 0.25, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + self.masking_ratio = masking_ratio + self.embed_dim = embed_dim + self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim), requires_grad = True)#None #When training in the SupOnly loop, unused params raise error in DDP. 
Hence instantiating mask_token only when masked training begins + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def mask_with_mean(self, x): + # print(x.shape) + avg = torch.mean(x, dim = 1) + avg = avg.clone().detach().requires_grad_(False) + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - self.masking_ratio)) + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + + # keep the first subset + ids_mask = ids_shuffle[:, len_keep:] + # for i in range(N): + # x[i][ids_mask[i]] = avg[i] + # return x + avg = avg.unsqueeze(dim = 1) + avg = avg.repeat(1, L, 1) + mask = ids_mask.unsqueeze(dim = 2) + mask = mask.repeat(1, 1, D) + masked = torch.scatter(x, dim = 1, index = mask, src = avg) # self[i] [index[i][j][k]] [k] = src[i][j][k] # if dim == 1 + # self.printcheck(x[0], masked[0], avg[0]) + return masked + + def mask_with_learnt_mask(self, x): + # if self.mask_token is None: #When training in the SupOnly loop, unused params raise error in DDP. Hence instantiating mask_token only when masked training begins + # self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim, device=x.device), requires_grad = True) + # print(self.mask_token[:10], x.device, "token") + N, L, D = x.shape # batch, length, dim + indicies = torch.FloatTensor(N, L).uniform_() <= self.masking_ratio + x[indicies] = self.mask_token + return x + + + def printcheck(self, x, masked, avg): + L, D = x.shape + same = 0 + avgsame = 0 + for i in range(L): + if (x[i] == masked[i]).all(): + same += 1 + # else: + # print(i, x[i]) + if (masked[i].data == avg[i].data).all(): + avgsame += 1 + print(same, avgsame) + return + + def forward(self, x, masking_branch = -1, range_batches_to_mask = None): + assert masking_branch < num_parallel and masking_branch >= -1 + sum_mask = torch.sum(self.mask_token) + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + if not masking_branch == -1: + assert range_batches_to_mask is not None, "expected the range of batches to mask to not mask the labeled images" + # x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_mean(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + # masking_branch = 1 + # x[masking_branch] = self.mask_with_mean(x[masking_branch]) + # x[masking_branch] = self.mask_with_learnt_mask(x[masking_branch]) + else: + x[0] = x[0] + 0*sum_mask #So that when training with SupOnly (and not using any masking), DDP doesn't raise an error that you have unused parameters. 
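+ # Either way, each element of x is a [B, H*W, embed_dim] token sequence (one per
+ # branch), with H and W the grid size produced by the overlapping convolution.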
+ return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, ratio, masking_ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + # self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], ratio = ratio, num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], ratio = ratio, num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], ratio = ratio, num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], ratio = ratio, num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, 
math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, masking_branch, range_batches_to_mask): + B = x[0].shape[0] + outs0, outs1 = [], [] + # masks = [] + # stage 1 + x, H, W = self.patch_embed1(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block1): + # score = self.score_predictor[0](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x, masking_branch = -1) + # x, H, W = self.patch_embed2(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block2): + # score = self.score_predictor[1](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x, masking_branch = -1) + # x, H, W = self.patch_embed3(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block3): + # score = self.score_predictor[2](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x, masking_branch = -1) + # x, H, W = self.patch_embed4(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block4): + # score = self.score_predictor[3](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + 
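+        # outs0 and outs1 each hold the four pyramid feature maps (strides 4/8/16/32) of the RGB and depth streams respectively; both lists feed the SegFormer decode heads.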
+ return [outs0, outs1] + + def forward(self, x, masking_branch, range_batches_to_mask): + x = self.forward_features(x, masking_branch, range_batches_to_mask) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b0, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b1, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b2, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b3, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b4, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b5, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) \ No newline at end of file diff --git a/models/segmentation_models/linearfusemaemaskedcons/modules.py b/models/segmentation_models/linearfusemaemaskedcons/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..548222552ce02a9e31e4d347d96ec425e5012f6f --- /dev/null +++ b/models/segmentation_models/linearfusemaemaskedcons/modules.py @@ -0,0 +1,44 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. 
+# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class LinearFuse(nn.Module): + def __init__(self, ratio): + super(LinearFuse, self).__init__() + self.ratio = ratio + + def forward(self, x): + # x: [B, N, C], mask: [B, N] + # mask_threshold = torch.rand(1)[0] * (high - low) + low + # mask = [torch.rand(x[0].shape[0], x[0].shape[1]), torch.rand(x[0].shape[0], x[0].shape[1])] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0 = self.ratio*x[0] + (1-self.ratio)*x[1] + x1 = (1-self.ratio)*x[0] + self.ratio*x[1] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/linearfusemaemaskedcons/segformer.py b/models/segmentation_models/linearfusemaemaskedcons/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..05dacafb5099474f1820b6162fb034ee8a9ff855 --- /dev/null +++ b/models/segmentation_models/linearfusemaemaskedcons/segformer.py @@ -0,0 +1,263 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class SegFormerReconstructionHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, **kwargs): + super(SegFormerReconstructionHead, self).__init__() + self.in_channels = in_channels + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + 
in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, 3, kernel_size=1) + + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + return x + +class LinearFusionMAEMaskedConsistency(nn.Module): + def __init__(self, backbone, config, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)(ratio = config.ratio, masking_ratio = config.masking_ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + print("Load pretrained weights from " + config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + else: + print("Train from scratch") + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.decoder_reconstruct_rgb = SegFormerReconstructionHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.decoder_reconstruct_depth = SegFormerReconstructionHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + self.ratio = config.ratio + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None, mask = False, range_batches_to_mask = None): + b, c, h, w = data[0].shape #rgb is 
the 0th element + if not mask: + masking_branch = -1 + else: + masking_branch = int((torch.rand(1)<0.5)*1) + + x = self.encoder(data, masking_branch, range_batches_to_mask) + if self.training: + #Reconstruction branch + reconstruction_criterion = nn.MSELoss() + + encoder_output_0 = [x[0][i][range_batches_to_mask[0]:range_batches_to_mask[1]] for i in range(len(x[0]))] + pred_reconstruct_0 = self.decoder_reconstruct_rgb(encoder_output_0) + pred_reconstruct_0 = F.interpolate(pred_reconstruct_0, size=(h, w), mode='bilinear', align_corners=True) + encoder_output_1 = [x[1][i][range_batches_to_mask[0]:range_batches_to_mask[1]] for i in range(len(x[1]))] + pred_reconstruct_1 = self.decoder_reconstruct_depth(encoder_output_1) + pred_reconstruct_1 = F.interpolate(pred_reconstruct_1, size=(h, w), mode='bilinear', align_corners=True) + reconstruction_loss = (1 - masking_branch) * reconstruction_criterion(pred_reconstruct_0, data[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + reconstruction_loss += masking_branch * reconstruction_criterion(pred_reconstruct_1, data[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + # if masking_branch == 0: + # masked_encoder_output = [x[0][i][range_batches_to_mask[0]:range_batches_to_mask[1]] for i in range(len(x[0]))] + # pred_reconstruct = self.decoder_reconstruct_rgb(masked_encoder_output) + # pred_reconstruct = F.interpolate(pred_reconstruct, size=(h, w), mode='bilinear', align_corners=True) + # else: + # masked_encoder_output = [x[1][i][range_batches_to_mask[0]:range_batches_to_mask[1]] for i in range(len(x[1]))] + # pred_reconstruct = self.decoder_reconstruct_depth(masked_encoder_output) + # pred_reconstruct = F.interpolate(pred_reconstruct, size=(h, w), mode='bilinear', align_corners=True) + + # reconstruction_criterion = nn.MSELoss() + # reconstruction_loss = reconstruction_criterion(pred_reconstruct, data[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + + #Segmentation branch + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + if not self.training: + return pred + else: # training + if get_sup_loss: + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss, reconstruction_loss, masking_branch + else: + return pred, reconstruction_loss, masking_branch + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
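+            # Assumes the labeled examples occupy the first gt.shape[0] positions of the batch; predictions for the remaining (unlabeled) examples are dropped from the supervised loss.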
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict diff --git a/models/segmentation_models/linearfusemaskedcons/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/linearfusemaskedcons/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32f6656d4bcb516310f54e9ed1542ce3df219562 Binary files /dev/null and b/models/segmentation_models/linearfusemaskedcons/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaskedcons/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusemaskedcons/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/linearfusemaskedcons/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaskedcons/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/linearfusemaskedcons/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38cf30e273f58b7a928d62cd312659288125e021 Binary files /dev/null and b/models/segmentation_models/linearfusemaskedcons/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaskedcons/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/linearfusemaskedcons/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e920f8915dee9894d28f72c224c82a3f998c6158 Binary files /dev/null and b/models/segmentation_models/linearfusemaskedcons/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaskedcons/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusemaskedcons/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/linearfusemaskedcons/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaskedcons/mix_transformer.py b/models/segmentation_models/linearfusemaskedcons/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a21cdddb3e946f4a598f4906f99604dfdb599592 --- /dev/null +++ b/models/segmentation_models/linearfusemaskedcons/mix_transformer.py @@ -0,0 +1,537 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
+# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, LinearFuse + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, ratio, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
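+        # Efficient self-attention as in SegFormer: when sr_ratio > 1 the keys/values are computed on spatially reduced tokens, and each linear layer is wrapped in ModuleParallel so both modality streams share the same weights.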
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = LinearFuse(ratio) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + # x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x) + + return x + + +# class PredictorLG(nn.Module): +# """ Image to Patch Embedding from DydamicVit +# """ +# def __init__(self, embed_dim=384): +# super().__init__() +# self.score_nets = nn.ModuleList([nn.Sequential( +# nn.LayerNorm(embed_dim), +# nn.Linear(embed_dim, embed_dim), +# nn.GELU(), +# nn.Linear(embed_dim, embed_dim // 2), +# nn.GELU(), +# nn.Linear(embed_dim // 2, embed_dim // 4), +# nn.GELU(), +# nn.Linear(embed_dim // 4, 2), +# nn.LogSoftmax(dim=-1) +# ) for _ in range(num_parallel)]) + +# def forward(self, x): +# x = [self.score_nets[i](x[i]) for i in range(num_parallel)] +# return x + + +class Block(nn.Module): + + def __init__(self, dim, ratio, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, ratio, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbedAndMask(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, masking_ratio = 0.25, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + self.masking_ratio = masking_ratio + self.embed_dim = embed_dim + self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim), requires_grad = True)#None #When training in the SupOnly loop, unused params raise error in DDP. 
Hence instantiating mask_token only when masked training begins + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def mask_with_mean(self, x): + # print(x.shape) + avg = torch.mean(x, dim = 1) + avg = avg.clone().detach().requires_grad_(False) + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - self.masking_ratio)) + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + + # keep the first subset + ids_mask = ids_shuffle[:, len_keep:] + # for i in range(N): + # x[i][ids_mask[i]] = avg[i] + # return x + avg = avg.unsqueeze(dim = 1) + avg = avg.repeat(1, L, 1) + mask = ids_mask.unsqueeze(dim = 2) + mask = mask.repeat(1, 1, D) + masked = torch.scatter(x, dim = 1, index = mask, src = avg) # self[i] [index[i][j][k]] [k] = src[i][j][k] # if dim == 1 + # self.printcheck(x[0], masked[0], avg[0]) + return masked + + def mask_with_learnt_mask(self, x): + # if self.mask_token is None: #When training in the SupOnly loop, unused params raise error in DDP. Hence instantiating mask_token only when masked training begins + # self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim, device=x.device), requires_grad = True) + # print(self.mask_token[:10], x.device, "token") + N, L, D = x.shape # batch, length, dim + indicies = torch.FloatTensor(N, L).uniform_() <= self.masking_ratio + x[indicies] = self.mask_token + return x + + + def printcheck(self, x, masked, avg): + L, D = x.shape + same = 0 + avgsame = 0 + for i in range(L): + if (x[i] == masked[i]).all(): + same += 1 + # else: + # print(i, x[i]) + if (masked[i].data == avg[i].data).all(): + avgsame += 1 + print(same, avgsame) + return + + def forward(self, x, masking_branch = -1, range_batches_to_mask = None): + assert masking_branch < num_parallel and masking_branch >= -1 + sum_mask = torch.sum(self.mask_token) + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + if not masking_branch == -1: + assert range_batches_to_mask is not None, "expected the range of batches to mask to not mask the labeled images" + # x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_mean(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + # masking_branch = 1 + # x[masking_branch] = self.mask_with_mean(x[masking_branch]) + # x[masking_branch] = self.mask_with_learnt_mask(x[masking_branch]) + else: + x[0] = x[0] + 0*sum_mask #So that when training with SupOnly (and not using any masking), DDP doesn't raise an error that you have unused parameters. 
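+        # H and W are the token-grid dimensions produced by the strided conv projection; callers need them to fold the flattened tokens back into feature maps.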
+ return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, ratio, masking_ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + # self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], ratio = ratio, num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], ratio = ratio, num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], ratio = ratio, num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], ratio = ratio, num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, 
math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, masking_branch, range_batches_to_mask): + B = x[0].shape[0] + outs0, outs1 = [], [] + # masks = [] + # stage 1 + x, H, W = self.patch_embed1(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block1): + # score = self.score_predictor[0](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x, masking_branch = -1) + # x, H, W = self.patch_embed2(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block2): + # score = self.score_predictor[1](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x, masking_branch = -1) + # x, H, W = self.patch_embed3(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block3): + # score = self.score_predictor[2](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x, masking_branch = -1) + # x, H, W = self.patch_embed4(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block4): + # score = self.score_predictor[3](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + 
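+        # In this configuration masking is applied only in patch_embed1 (stages 2-4 are called with masking_branch = -1), so the later stages operate on features derived from the already-masked stage-1 tokens.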
+ return [outs0, outs1] + + def forward(self, x, masking_branch, range_batches_to_mask): + x = self.forward_features(x, masking_branch, range_batches_to_mask) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b0, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b1, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b2, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b3, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b4, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b5, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) \ No newline at end of file diff --git a/models/segmentation_models/linearfusemaskedcons/modules.py b/models/segmentation_models/linearfusemaskedcons/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..548222552ce02a9e31e4d347d96ec425e5012f6f --- /dev/null +++ b/models/segmentation_models/linearfusemaskedcons/modules.py @@ -0,0 +1,44 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. 
+# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class LinearFuse(nn.Module): + def __init__(self, ratio): + super(LinearFuse, self).__init__() + self.ratio = ratio + + def forward(self, x): + # x: [B, N, C], mask: [B, N] + # mask_threshold = torch.rand(1)[0] * (high - low) + low + # mask = [torch.rand(x[0].shape[0], x[0].shape[1]), torch.rand(x[0].shape[0], x[0].shape[1])] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0 = self.ratio*x[0] + (1-self.ratio)*x[1] + x1 = (1-self.ratio)*x[0] + self.ratio*x[1] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/linearfusemaskedcons/segformer.py b/models/segmentation_models/linearfusemaskedcons/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7d2d89aa1a0abc82ffb82dac8dd49f85ce8ce996 --- /dev/null +++ b/models/segmentation_models/linearfusemaskedcons/segformer.py @@ -0,0 +1,177 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class LinearFusionMaskedConsistency(nn.Module): + def __init__(self, backbone, config, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)(ratio = config.ratio, masking_ratio = config.masking_ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + print("Load pretrained weights from " + config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + else: + print("Train from scratch") + 
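+        # A single SegFormerHead is shared by both modality streams; in forward their (detached) logits are additionally combined into an ensemble prediction weighted by softmax(self.alpha).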
self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + self.ratio = config.ratio + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None, mask = False, range_batches_to_mask = None): + b, c, h, w = data[0].shape #rgb is the 0th element + if not mask: + masking_branch = -1 + else: + masking_branch = int((torch.rand(1)<0.5)*1) + # masking_branch = 1 + x = self.encoder(data, masking_branch, range_batches_to_mask) + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + if not self.training: + return pred + else: # training + if get_sup_loss: + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss, masking_branch + else: + return pred, masking_branch + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
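+            # The criterion is evaluated for each of the three predictions (RGB, depth, ensemble) and averaged below.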
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict diff --git a/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cc539fd926da2aa7e26c1ce97ff60edf26c7111 Binary files /dev/null and b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/mix_transformer.cpython-38.pyc b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/mix_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d194c15b3821fb37033ccd6ec8f8321daa41070 Binary files /dev/null and b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/mix_transformer.cpython-38.pyc differ diff --git a/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15a944a5bd04bdc366fa595c17107690fb7ea4d6 Binary files /dev/null and b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/modules.cpython-38.pyc b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d506978c00b18bd5b46e1b467ac06de717a2047 Binary files /dev/null and b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/modules.cpython-38.pyc differ diff --git a/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7123869e8d091b22f20a706317966a49105fb5ff Binary files /dev/null and b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/segformer.cpython-36.pyc differ diff --git 
a/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/segformer.cpython-38.pyc b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/segformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a43a9c13f24fe4fa6cb2b8b4754c98867c7b193 Binary files /dev/null and b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/segformer.cpython-38.pyc differ diff --git a/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/linearfusemaskedconsmixbatch/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusemaskedconsmixbatch/mix_transformer.py b/models/segmentation_models/linearfusemaskedconsmixbatch/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..291856416cecd9c53df6378bab4ba36c19f751ca --- /dev/null +++ b/models/segmentation_models/linearfusemaskedconsmixbatch/mix_transformer.py @@ -0,0 +1,542 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, LinearFuse + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, ratio, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
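+        # qk_scale overrides the default head_dim ** -0.5 attention scaling when provided.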
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = LinearFuse(ratio) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + # x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x) + + return x + + +# class PredictorLG(nn.Module): +# """ Image to Patch Embedding from DydamicVit +# """ +# def __init__(self, embed_dim=384): +# super().__init__() +# self.score_nets = nn.ModuleList([nn.Sequential( +# nn.LayerNorm(embed_dim), +# nn.Linear(embed_dim, embed_dim), +# nn.GELU(), +# nn.Linear(embed_dim, embed_dim // 2), +# nn.GELU(), +# nn.Linear(embed_dim // 2, embed_dim // 4), +# nn.GELU(), +# nn.Linear(embed_dim // 4, 2), +# nn.LogSoftmax(dim=-1) +# ) for _ in range(num_parallel)]) + +# def forward(self, x): +# x = [self.score_nets[i](x[i]) for i in range(num_parallel)] +# return x + + +class Block(nn.Module): + + def __init__(self, dim, ratio, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, ratio, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbedAndMask(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, masking_ratio = 0.25, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + self.masking_ratio = masking_ratio + self.embed_dim = embed_dim + self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim), requires_grad = True)#None #When training in the SupOnly loop, unused params raise error in DDP. 
Hence instantiating mask_token only when masked training begins + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def mask_with_mean(self, x): + # print(x.shape) + avg = torch.mean(x, dim = 1) + avg = avg.clone().detach().requires_grad_(False) + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - self.masking_ratio)) + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + + # keep the first subset + ids_mask = ids_shuffle[:, len_keep:] + # for i in range(N): + # x[i][ids_mask[i]] = avg[i] + # return x + avg = avg.unsqueeze(dim = 1) + avg = avg.repeat(1, L, 1) + mask = ids_mask.unsqueeze(dim = 2) + mask = mask.repeat(1, 1, D) + masked = torch.scatter(x, dim = 1, index = mask, src = avg) # self[i] [index[i][j][k]] [k] = src[i][j][k] # if dim == 1 + # self.printcheck(x[0], masked[0], avg[0]) + return masked + + def mask_with_learnt_mask(self, x, masking_branch): + # if self.mask_token is None: #When training in the SupOnly loop, unused params raise error in DDP. Hence instantiating mask_token only when masked training begins + # self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim, device=x.device), requires_grad = True) + # print(self.mask_token[:10], x.device, "token") + _, N, L, D = x.shape # modality, batch, length, dim + N = torch.sum( torch.tensor(masking_branch) != -1) + indicies = torch.FloatTensor(N, L).uniform_() <= self.masking_ratio + # x[indicies] = self.mask_token + masking_branch = torch.tensor(masking_branch).to(x.device) + index = torch.stack([torch.tensor(masking_branch == 0), torch.tensor(masking_branch == 1)]).to(x.device) + xtemp = x[index] + xtemp[indicies] = self.mask_token + x[index] = xtemp + return x + + + def printcheck(self, x, masked, avg): + L, D = x.shape + same = 0 + avgsame = 0 + for i in range(L): + if (x[i] == masked[i]).all(): + same += 1 + # else: + # print(i, x[i]) + if (masked[i].data == avg[i].data).all(): + avgsame += 1 + print(same, avgsame) + return + + def forward(self, x, mask, masking_branch = None, range_batches_to_mask = None): + sum_mask = torch.sum(self.mask_token) + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + if mask: + assert masking_branch is not None and range_batches_to_mask is not None, "expected the range of batches to mask to not mask the labeled images" + # x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + xstacked = torch.stack(x) + xstacked[:, range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(xstacked[:, range_batches_to_mask[0]:range_batches_to_mask[1]], masking_branch) + x = [xstacked[i] for i in range(xstacked.shape[0])] + else: + x[0] = x[0] + 0*sum_mask #So that when training with SupOnly (and not using any masking), DDP doesn't raise an error that you have unused parameters. 
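+        # Masking path: when mask=True the two token grids are stacked and, for the
+        # batch slice given by range_batches_to_mask, a random masking_ratio
+        # fraction of tokens is replaced by the learnt mask_token; masking_branch
+        # picks, per sample, which modality stream is masked (-1 = leave unmasked).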
+ return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, ratio, masking_ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + # self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], ratio = ratio, num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], ratio = ratio, num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], ratio = ratio, num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], ratio = ratio, num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, 
math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, mask, masking_branch, range_batches_to_mask): + B = x[0].shape[0] + outs0, outs1 = [], [] + # masks = [] + # stage 1 + + x, H, W = self.patch_embed1(x, mask = mask, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block1): + # score = self.score_predictor[0](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x, mask = False) + # x, H, W = self.patch_embed2(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block2): + # score = self.score_predictor[1](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x, mask = False) + # x, H, W = self.patch_embed3(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block3): + # score = self.score_predictor[2](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x, mask = False) + # x, H, W = self.patch_embed4(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block4): + # score = self.score_predictor[3](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + 
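+        # outs0 / outs1 collect the four multi-scale feature maps of the two
+        # streams. Note that masking is only applied inside patch_embed1; the
+        # later patch embeddings are always called with mask=False.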
+ return [outs0, outs1] + + def forward(self, x, mask, masking_branch = None, range_batches_to_mask = None): + x = self.forward_features(x, mask, masking_branch, range_batches_to_mask) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b0, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b1, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b2, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b3, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b4, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b5, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) \ No newline at end of file diff --git a/models/segmentation_models/linearfusemaskedconsmixbatch/modules.py b/models/segmentation_models/linearfusemaskedconsmixbatch/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..548222552ce02a9e31e4d347d96ec425e5012f6f --- /dev/null +++ b/models/segmentation_models/linearfusemaskedconsmixbatch/modules.py @@ -0,0 +1,44 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. 
+# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class LinearFuse(nn.Module): + def __init__(self, ratio): + super(LinearFuse, self).__init__() + self.ratio = ratio + + def forward(self, x): + # x: [B, N, C], mask: [B, N] + # mask_threshold = torch.rand(1)[0] * (high - low) + low + # mask = [torch.rand(x[0].shape[0], x[0].shape[1]), torch.rand(x[0].shape[0], x[0].shape[1])] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0 = self.ratio*x[0] + (1-self.ratio)*x[1] + x1 = (1-self.ratio)*x[0] + self.ratio*x[1] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/linearfusemaskedconsmixbatch/segformer.py b/models/segmentation_models/linearfusemaskedconsmixbatch/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6545ed5e50ea36a1d19de9491064877fc53edccf --- /dev/null +++ b/models/segmentation_models/linearfusemaskedconsmixbatch/segformer.py @@ -0,0 +1,184 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class LinearFusionMaskedConsistencyMixBatch(nn.Module): + def __init__(self, backbone, config, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)(ratio = config.ratio, masking_ratio = config.masking_ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, 
num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + self.ratio = config.ratio + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None, mask = False, masking_branch = None): + b, c, h, w = data[0].shape #rgb is the 0th element + if mask: + range_batches_to_mask = [0, b] + n = range_batches_to_mask[1] - range_batches_to_mask[0] + masking_branch = [masking_branch for _ in range(n)] + else: + masking_branch = None + range_batches_to_mask = None + # mask = True + # masking_branch = [1 for _ in range(data[0].shape[0])] + # range_batches_to_mask = [0, data[0].shape[0]] + + x = self.encoder(data, mask, masking_branch, range_batches_to_mask) + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + if not self.training: + return pred + else: # training + if get_sup_loss: + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + # return pred, sup_loss, masking_branch + return pred, sup_loss + else: + # return pred, masking_branch + return pred + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
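+            # pred holds the two stream predictions plus their alpha-weighted
+            # ensemble; the ensemble is built from detached logits, so its loss
+            # term only updates alpha. The summed loss is averaged over all three.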
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict diff --git a/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a542fa50d7c95dfb7361e915228fd3930a00b72 Binary files /dev/null and b/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b63ecab4ebda70d1a769e2860d44791e669f272 Binary files /dev/null and b/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47a4473e7dc0294ebacf34f3dc41583ea6a1e2aa Binary files /dev/null and b/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/linearfusesepdecodermaskedcons/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/linearfusesepdecodermaskedcons/mix_transformer.py b/models/segmentation_models/linearfusesepdecodermaskedcons/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..65fa4ceb31cebebe8b0f120f8f18defc8841f5b5 --- /dev/null +++ b/models/segmentation_models/linearfusesepdecodermaskedcons/mix_transformer.py @@ -0,0 +1,538 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. 
+# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, LinearFuse + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, ratio, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
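+        # Same two-stream attention as in the MixBatch variant: per-stream scaled
+        # dot-product attention, with keys/values computed on a spatially reduced
+        # copy of the tokens (strided sr_ratio x sr_ratio conv) when sr_ratio > 1.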
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = LinearFuse(ratio) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + # x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x) + + return x + + +# class PredictorLG(nn.Module): +# """ Image to Patch Embedding from DydamicVit +# """ +# def __init__(self, embed_dim=384): +# super().__init__() +# self.score_nets = nn.ModuleList([nn.Sequential( +# nn.LayerNorm(embed_dim), +# nn.Linear(embed_dim, embed_dim), +# nn.GELU(), +# nn.Linear(embed_dim, embed_dim // 2), +# nn.GELU(), +# nn.Linear(embed_dim // 2, embed_dim // 4), +# nn.GELU(), +# nn.Linear(embed_dim // 4, 2), +# nn.LogSoftmax(dim=-1) +# ) for _ in range(num_parallel)]) + +# def forward(self, x): +# x = [self.score_nets[i](x[i]) for i in range(num_parallel)] +# return x + + +class Block(nn.Module): + + def __init__(self, dim, ratio, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, ratio, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbedAndMask(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, masking_ratio = 0.25, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + self.masking_ratio = masking_ratio + self.embed_dim = embed_dim + self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim), requires_grad = True)#None #When training in the SupOnly loop, unused params raise error in DDP. 
Hence instantiating mask_token only when masked training begins + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def mask_with_mean(self, x): + # print(x.shape) + avg = torch.mean(x, dim = 1) + avg = avg.clone().detach().requires_grad_(False) + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - self.masking_ratio)) + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + + # keep the first subset + ids_mask = ids_shuffle[:, len_keep:] + # for i in range(N): + # x[i][ids_mask[i]] = avg[i] + # return x + avg = avg.unsqueeze(dim = 1) + avg = avg.repeat(1, L, 1) + mask = ids_mask.unsqueeze(dim = 2) + mask = mask.repeat(1, 1, D) + masked = torch.scatter(x, dim = 1, index = mask, src = avg) # self[i] [index[i][j][k]] [k] = src[i][j][k] # if dim == 1 + # self.printcheck(x[0], masked[0], avg[0]) + return masked + + def mask_with_learnt_mask(self, x): + # if self.mask_token is None: #When training in the SupOnly loop, unused params raise error in DDP. Hence instantiating mask_token only when masked training begins + # self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim, device=x.device), requires_grad = True) + # print(self.mask_token[:10], x.device, "token") + N, L, D = x.shape # batch, length, dim + indicies = torch.FloatTensor(N, L).uniform_() <= self.masking_ratio + x[indicies] = self.mask_token + return x + + + def printcheck(self, x, masked, avg): + L, D = x.shape + same = 0 + avgsame = 0 + for i in range(L): + if (x[i] == masked[i]).all(): + same += 1 + # else: + # print(i, x[i]) + if (masked[i].data == avg[i].data).all(): + avgsame += 1 + print(same, avgsame) + return + + def forward(self, x, masking_branch = -1, range_batches_to_mask = None): + assert masking_branch < num_parallel and masking_branch >= -1 + sum_mask = torch.sum(self.mask_token) + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + if not masking_branch == -1: + assert range_batches_to_mask is not None, "expected the range of batches to mask to not mask the labeled images" + # x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_mean(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + # masking_branch = 1 + # x[masking_branch] = self.mask_with_mean(x[masking_branch]) + # x[masking_branch] = self.mask_with_learnt_mask(x[masking_branch]) + else: + x[0] = x[0] + 0*sum_mask #So that when training with SupOnly (and not using any masking), DDP doesn't raise an error that you have unused parameters. 
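+        # Illustrative use from the encoder (the argument values below are
+        # hypothetical): masking_branch=1 with range_batches_to_mask=[2, 4] masks
+        # roughly masking_ratio of stream 1's tokens for batch items 2-3, while
+        # masking_branch=-1 (the default) leaves both streams untouched:
+        #   x, H, W = self.patch_embed1(x, masking_branch=1, range_batches_to_mask=[2, 4])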
+ return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, ratio, masking_ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + # self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], ratio = ratio, num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], ratio = ratio, num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], ratio = ratio, num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], ratio = ratio, num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, 
math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, masking_branch, range_batches_to_mask): + B = x[0].shape[0] + outs0, outs1 = [], [] + # masks = [] + + # stage 1 + x, H, W = self.patch_embed1(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block1): + # score = self.score_predictor[0](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + # x, H, W = self.patch_embed2(x, masking_branch = -1) + x, H, W = self.patch_embed2(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block2): + # score = self.score_predictor[1](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + # x, H, W = self.patch_embed3(x, masking_branch = -1) + x, H, W = self.patch_embed3(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block3): + # score = self.score_predictor[2](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + # x, H, W = self.patch_embed4(x, masking_branch = -1) + x, H, W = self.patch_embed4(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block4): + # score = self.score_predictor[3](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) 
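+        # Unlike the MixBatch variant above, this encoder also passes
+        # masking_branch to patch_embed2-4, so the selected stream is re-masked at
+        # every stage rather than only at the first patch embedding.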
+ + return [outs0, outs1] + + def forward(self, x, masking_branch, range_batches_to_mask): + x = self.forward_features(x, masking_branch, range_batches_to_mask) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b0, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b1, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b2, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b3, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b4, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, ratio, masking_ratio, **kwargs): + super(mit_b5, self).__init__(ratio = ratio, masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) \ No newline at end of file diff --git a/models/segmentation_models/linearfusesepdecodermaskedcons/modules.py b/models/segmentation_models/linearfusesepdecodermaskedcons/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..548222552ce02a9e31e4d347d96ec425e5012f6f --- /dev/null +++ b/models/segmentation_models/linearfusesepdecodermaskedcons/modules.py @@ -0,0 +1,44 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. 
+# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class LinearFuse(nn.Module): + def __init__(self, ratio): + super(LinearFuse, self).__init__() + self.ratio = ratio + + def forward(self, x): + # x: [B, N, C], mask: [B, N] + # mask_threshold = torch.rand(1)[0] * (high - low) + low + # mask = [torch.rand(x[0].shape[0], x[0].shape[1]), torch.rand(x[0].shape[0], x[0].shape[1])] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0 = self.ratio*x[0] + (1-self.ratio)*x[1] + x1 = (1-self.ratio)*x[0] + self.ratio*x[1] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/linearfusesepdecodermaskedcons/segformer.py b/models/segmentation_models/linearfusesepdecodermaskedcons/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f64a95a209ea0fbbd1348e84bf3af12cc40fa5cc --- /dev/null +++ b/models/segmentation_models/linearfusesepdecodermaskedcons/segformer.py @@ -0,0 +1,178 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class LinearFusionSepDecoderMaskedConsistency(nn.Module): + def __init__(self, backbone, config, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)(ratio = config.ratio, masking_ratio = config.masking_ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder1 = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, 
num_classes=self.num_classes) + self.decoder2 = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + self.ratio = config.ratio + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder1.parameters()): + param_groups[2].append(param) + for param in list(self.decoder2.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss=False, gt=None, criterion=None, mask=False, range_batches_to_mask=None): + b, c, h, w = data[0].shape # RGB is the 0th element + if not mask: + masking_branch = -1 + else: + masking_branch = int(torch.rand(1).item() < 0.5) # randomly choose which of the two branches to mask + x = self.encoder(data, masking_branch, range_batches_to_mask) + pred = [self.decoder1(x[0]), self.decoder2(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha, dim=0) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + if not self.training: + return pred + else: # training + if get_sup_loss: + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss, masking_branch + else: + return pred, masking_branch + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] # Compute the supervised loss only for the batch examples that have ground truth; unlabeled data contributes no supervised loss.
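+            # Note: `pred` holds both branch predictions plus the detached alpha-weighted ensemble,
+            # so the supervised loss below is averaged over all of them. The slicing above relies on
+            # labeled examples being packed at the front of the batch.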
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict diff --git a/models/segmentation_models/refinenet.py b/models/segmentation_models/refinenet.py new file mode 100644 index 0000000000000000000000000000000000000000..67c17007188aaa9116ac8674b20783b7020b9481 --- /dev/null +++ b/models/segmentation_models/refinenet.py @@ -0,0 +1,579 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import SyncBatchNorm as BatchNorm2d +import re +import os, sys +from six import moves + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class BatchNorm2dParallel(nn.Module): + def __init__(self, num_features, num_parallel): + super(BatchNorm2dParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'bn_' + str(i), BatchNorm2d(num_features)) + + def forward(self, x_parallel): + return [getattr(self, 'bn_' + str(i))(x) for i, x in enumerate(x_parallel)] + +class MyRefineNet(nn.Module): + def __init__(self, num_layers, num_classes): + super(MyRefineNet, self).__init__() + self.model = refinenet(num_layers, num_classes, 1, None) #Passing num_parallel = 1 and bn_threshold as None + self.model = model_init(self.model, num_layers, 1, imagenet=True) #Only initializes the encoder + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + b, c, h, w = data[0].shape #rgb is the 0th element + pred = self.model(data) + pred = F.interpolate(pred[0], size=(h, w), mode='bilinear', align_corners=True) + if not self.training: # return just predictions for evaluation + return pred + else: + if get_sup_loss: + return pred, self.get_sup_loss(pred, gt, criterion) + else: + return pred + + def get_params(self): + enc_params, dec_params= [], [] + for name, param in self.model.named_parameters(): + if bool(re.match('.*conv1.*|.*bn1.*|.*layer.*', name)): + enc_params.append(param) + else: + dec_params.append(param) + return enc_params, dec_params + + def get_sup_loss(self, pred, gt, criterion): + pred = pred[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. + return criterion(pred, gt) + +"""RefineNet-LightWeight + +RefineNet-LigthWeight PyTorch for non-commercial purposes + +Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au) +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. 
+ +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +models_urls = { + '101_voc' : 'https://cloudstor.aarnet.edu.au/plus/s/Owmttk9bdPROwc6/download', + + '18_imagenet' : 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + '50_imagenet' : 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + '101_imagenet': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + '152_imagenet': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + +bottleneck_idx = 0 +save_idx = 0 + + +def conv3x3(in_planes, out_planes, stride=1, bias=False): + "3x3 convolution with padding" + return ModuleParallel(nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=stride, padding=1, bias=bias)) + + +def conv1x1(in_planes, out_planes, stride=1, bias=False): + "1x1 convolution" + return ModuleParallel(nn.Conv2d(in_planes, out_planes, kernel_size=1, + stride=stride, padding=0, bias=bias)) + + +class CRPBlock(nn.Module): + def __init__(self, in_planes, out_planes, num_stages, num_parallel): + super(CRPBlock, self).__init__() + for i in range(num_stages): + setattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'), + conv3x3(in_planes if (i == 0) else out_planes, out_planes)) + self.stride = 1 + self.num_stages = num_stages + self.num_parallel = num_parallel + self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=5, stride=1, padding=2)) + + def forward(self, x): + top = x + for i in range(self.num_stages): + top = self.maxpool(top) + top = getattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'))(top) + x = [x[l] + top[l] for l in range(self.num_parallel)] + return x + + +stages_suffixes = {0 : '_conv', 1 : '_conv_relu_varout_dimred'} + +class RCUBlock(nn.Module): + def __init__(self, in_planes, out_planes, num_blocks, num_stages, num_parallel): + super(RCUBlock, self).__init__() + for i in range(num_blocks): + for j in range(num_stages): + setattr(self, '{}{}'.format(i + 1, stages_suffixes[j]), + conv3x3(in_planes if (i == 0) and (j == 0) else out_planes, + out_planes, bias=(j == 0))) + self.stride = 1 + self.num_blocks = num_blocks + self.num_stages = num_stages + self.num_parallel = num_parallel + self.relu = ModuleParallel(nn.ReLU(inplace=True)) + + def forward(self, x): + for i in range(self.num_blocks): + residual = x + for j in range(self.num_stages): + x = self.relu(x) + x = getattr(self, '{}{}'.format(i + 1, stages_suffixes[j]))(x) + x = [x[l] + residual[l] for l in range(self.num_parallel)] + return x + + +class BasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, num_parallel, bn_threshold, stride=1, downsample=None): + 
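+        # Residual BasicBlock operating on a list of parallel per-modality feature maps
+        # (ModuleParallel / BatchNorm2dParallel). bn_threshold and bn2_list are stored but never
+        # used in forward(); they appear intended for slimming-style pruning of the bn2 scale
+        # factors (see get_params / L1_penalty later in this file).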
super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2dParallel(planes, num_parallel) + self.relu = ModuleParallel(nn.ReLU(inplace=True)) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2dParallel(planes, num_parallel) + self.num_parallel = num_parallel + self.downsample = downsample + self.stride = stride + + self.bn_threshold = bn_threshold + self.bn2_list = [] + for module in self.bn2.modules(): + if isinstance(module, BatchNorm2d): + self.bn2_list.append(module) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = [out[l] + residual[l] for l in range(self.num_parallel)] + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + def __init__(self, inplanes, planes, num_parallel, bn_threshold, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = BatchNorm2dParallel(planes, num_parallel) + self.conv2 = conv3x3(planes, planes, stride=stride) + self.bn2 = BatchNorm2dParallel(planes, num_parallel) + self.conv3 = conv1x1(planes, planes * 4) + self.bn3 = BatchNorm2dParallel(planes * 4, num_parallel) + self.relu = ModuleParallel(nn.ReLU(inplace=True)) + self.num_parallel = num_parallel + self.downsample = downsample + self.stride = stride + + self.bn_threshold = bn_threshold + self.bn2_list = [] + for module in self.bn2.modules(): + if isinstance(module, BatchNorm2d): + self.bn2_list.append(module) + + def forward(self, x): + residual = x + out = x + + out = self.conv1(out) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = [out[l] + residual[l] for l in range(self.num_parallel)] + out = self.relu(out) + + return out + + +class RefineNet(nn.Module): + def __init__(self, block, layers, num_parallel, num_classes=21, bn_threshold=2e-2): + self.inplanes = 64 + self.num_parallel = num_parallel + super(RefineNet, self).__init__() + self.dropout = ModuleParallel(nn.Dropout(p=0.5)) + self.conv1 = ModuleParallel(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)) + self.bn1 = BatchNorm2dParallel(64, num_parallel) + self.relu = ModuleParallel(nn.ReLU(inplace=True)) + self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + self.layer1 = self._make_layer(block, 64, layers[0], bn_threshold) + self.layer2 = self._make_layer(block, 128, layers[1], bn_threshold, stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], bn_threshold, stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], bn_threshold, stride=2) + + self.p_ims1d2_outl1_dimred = conv3x3(2048, 512) + self.adapt_stage1_b = self._make_rcu(512, 512, 2, 2) + self.mflow_conv_g1_pool = self._make_crp(512, 512, 4) + self.mflow_conv_g1_b = self._make_rcu(512, 512, 3, 2) + self.mflow_conv_g1_b3_joint_varout_dimred = conv3x3(512, 256) + + self.p_ims1d2_outl2_dimred = conv3x3(1024, 256) + self.adapt_stage2_b = self._make_rcu(256, 256, 2, 2) + self.adapt_stage2_b2_joint_varout_dimred = conv3x3(256, 256) + self.mflow_conv_g2_pool = self._make_crp(256, 256, 4) + self.mflow_conv_g2_b = self._make_rcu(256, 256, 3, 2) + self.mflow_conv_g2_b3_joint_varout_dimred = conv3x3(256, 
256) + + self.p_ims1d2_outl3_dimred = conv3x3(512, 256) + self.adapt_stage3_b = self._make_rcu(256, 256, 2, 2) + self.adapt_stage3_b2_joint_varout_dimred = conv3x3(256, 256) + self.mflow_conv_g3_pool = self._make_crp(256, 256, 4) + self.mflow_conv_g3_b = self._make_rcu(256, 256, 3, 2) + self.mflow_conv_g3_b3_joint_varout_dimred = conv3x3(256, 256) + + self.p_ims1d2_outl4_dimred = conv3x3(256, 256) + self.adapt_stage4_b = self._make_rcu(256, 256, 2, 2) + self.adapt_stage4_b2_joint_varout_dimred = conv3x3(256, 256) + self.mflow_conv_g4_pool = self._make_crp(256, 256, 4) + self.mflow_conv_g4_b = self._make_rcu(256, 256, 3, 2) + + self.clf_conv = conv3x3(256, num_classes, bias=True) + + def _make_crp(self, in_planes, out_planes, num_stages): + layers = [CRPBlock(in_planes, out_planes, num_stages, self.num_parallel)] + return nn.Sequential(*layers) + + def _make_rcu(self, in_planes, out_planes, num_blocks, num_stages): + layers = [RCUBlock(in_planes, out_planes, num_blocks, num_stages, self.num_parallel)] + return nn.Sequential(*layers) + + def _make_layer(self, block, planes, num_blocks, bn_threshold, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride=stride), + BatchNorm2dParallel(planes * block.expansion, self.num_parallel) + ) + + layers = [] + layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + l1 = self.layer1(x) + l2 = self.layer2(l1) + l3 = self.layer3(l2) + l4 = self.layer4(l3) + + l4 = self.dropout(l4) + l3 = self.dropout(l3) + + x4 = self.p_ims1d2_outl1_dimred(l4) + x4 = self.adapt_stage1_b(x4) + x4 = self.relu(x4) + x4 = self.mflow_conv_g1_pool(x4) + x4 = self.mflow_conv_g1_b(x4) + x4 = self.mflow_conv_g1_b3_joint_varout_dimred(x4) + x4 = [nn.Upsample(size=l3[0].size()[2:], mode='bilinear', align_corners=True)(x4_) for x4_ in x4] + + x3 = self.p_ims1d2_outl2_dimred(l3) + x3 = self.adapt_stage2_b(x3) + x3 = self.adapt_stage2_b2_joint_varout_dimred(x3) + x3 = [x3[l] + x4[l] for l in range(self.num_parallel)] + x3 = self.relu(x3) + x3 = self.mflow_conv_g2_pool(x3) + x3 = self.mflow_conv_g2_b(x3) + x3 = self.mflow_conv_g2_b3_joint_varout_dimred(x3) + x3 = [nn.Upsample(size=l2[0].size()[2:], mode='bilinear', align_corners=True)(x3_) for x3_ in x3] + + x2 = self.p_ims1d2_outl3_dimred(l2) + x2 = self.adapt_stage3_b(x2) + x2 = self.adapt_stage3_b2_joint_varout_dimred(x2) + x2 = [x2[l] + x3[l] for l in range(self.num_parallel)] + x2 = self.relu(x2) + x2 = self.mflow_conv_g3_pool(x2) + x2 = self.mflow_conv_g3_b(x2) + x2 = self.mflow_conv_g3_b3_joint_varout_dimred(x2) + x2 = [nn.Upsample(size=l1[0].size()[2:], mode='bilinear', align_corners=True)(x2_) for x2_ in x2] + + x1 = self.p_ims1d2_outl4_dimred(l1) + x1 = self.adapt_stage4_b(x1) + x1 = self.adapt_stage4_b2_joint_varout_dimred(x1) + x1 = [x1[l] + x2[l] for l in range(self.num_parallel)] + x1 = self.relu(x1) + x1 = self.mflow_conv_g4_pool(x1) + x1 = self.mflow_conv_g4_b(x1) + x1 = self.dropout(x1) + + out = self.clf_conv(x1) + return out + + +class RefineNet_Resnet18(nn.Module): + def __init__(self, block, layers, num_parallel, num_classes=21, bn_threshold=2e-2): + 
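+        # Same light-weight RefineNet decoder as the class above, with channel widths scaled down
+        # (512/256 -> 256/64): a BasicBlock-based ResNet-18 ends at 512 channels rather than the
+        # 2048 produced by the Bottleneck backbones.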
self.inplanes = 64 + self.num_parallel = num_parallel + super(RefineNet_Resnet18, self).__init__() + self.dropout = ModuleParallel(nn.Dropout(p=0.5)) + self.conv1 = ModuleParallel(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)) + self.bn1 = BatchNorm2dParallel(64, num_parallel) + self.relu = ModuleParallel(nn.ReLU(inplace=True)) + self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + self.layer1 = self._make_layer(block, 64, layers[0], bn_threshold) + self.layer2 = self._make_layer(block, 128, layers[1], bn_threshold, stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], bn_threshold, stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], bn_threshold, stride=2) + + self.p_ims1d2_outl1_dimred = conv3x3(512, 256) + self.adapt_stage1_b = self._make_rcu(256, 256, 2, 2) + self.mflow_conv_g1_pool = self._make_crp(256, 256, 4) + self.mflow_conv_g1_b = self._make_rcu(256, 256, 3, 2) + self.mflow_conv_g1_b3_joint_varout_dimred = conv3x3(256, 64) + + self.p_ims1d2_outl2_dimred = conv3x3(256, 64) + self.adapt_stage2_b = self._make_rcu(64, 64, 2, 2) + self.adapt_stage2_b2_joint_varout_dimred = conv3x3(64, 64) + self.mflow_conv_g2_pool = self._make_crp(64, 64, 4) + self.mflow_conv_g2_b = self._make_rcu(64, 64, 3, 2) + self.mflow_conv_g2_b3_joint_varout_dimred = conv3x3(64, 64) + + self.p_ims1d2_outl3_dimred = conv3x3(128, 64) + self.adapt_stage3_b = self._make_rcu(64, 64, 2, 2) + self.adapt_stage3_b2_joint_varout_dimred = conv3x3(64, 64) + self.mflow_conv_g3_pool = self._make_crp(64, 64, 4) + self.mflow_conv_g3_b = self._make_rcu(64, 64, 3, 2) + self.mflow_conv_g3_b3_joint_varout_dimred = conv3x3(64, 64) + + self.p_ims1d2_outl4_dimred = conv3x3(64, 64) + self.adapt_stage4_b = self._make_rcu(64, 64, 2, 2) + self.adapt_stage4_b2_joint_varout_dimred = conv3x3(64, 64) + self.mflow_conv_g4_pool = self._make_crp(64, 64, 4) + self.mflow_conv_g4_b = self._make_rcu(64, 64, 3, 2) + + self.clf_conv = conv3x3(64, num_classes, bias=True) + + def _make_crp(self, in_planes, out_planes, num_stages): + layers = [CRPBlock(in_planes, out_planes, num_stages, self.num_parallel)] + return nn.Sequential(*layers) + + def _make_rcu(self, in_planes, out_planes, num_blocks, num_stages): + layers = [RCUBlock(in_planes, out_planes, num_blocks, num_stages, self.num_parallel)] + return nn.Sequential(*layers) + + def _make_layer(self, block, planes, num_blocks, bn_threshold, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride=stride), + BatchNorm2dParallel(planes * block.expansion, self.num_parallel) + ) + + layers = [] + layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + l1 = self.layer1(x) + l2 = self.layer2(l1) + l3 = self.layer3(l2) + l4 = self.layer4(l3) + + l4 = self.dropout(l4) + l3 = self.dropout(l3) + + x4 = self.p_ims1d2_outl1_dimred(l4) + x4 = self.adapt_stage1_b(x4) + x4 = self.relu(x4) + x4 = self.mflow_conv_g1_pool(x4) + x4 = self.mflow_conv_g1_b(x4) + x4 = self.mflow_conv_g1_b3_joint_varout_dimred(x4) + x4 = [nn.Upsample(size=l3[0].size()[2:], mode='bilinear', align_corners=True)(x4_) 
for x4_ in x4] + + x3 = self.p_ims1d2_outl2_dimred(l3) + x3 = self.adapt_stage2_b(x3) + x3 = self.adapt_stage2_b2_joint_varout_dimred(x3) + x3 = [x3[l] + x4[l] for l in range(self.num_parallel)] + x3 = self.relu(x3) + x3 = self.mflow_conv_g2_pool(x3) + x3 = self.mflow_conv_g2_b(x3) + x3 = self.mflow_conv_g2_b3_joint_varout_dimred(x3) + x3 = [nn.Upsample(size=l2[0].size()[2:], mode='bilinear', align_corners=True)(x3_) for x3_ in x3] + + x2 = self.p_ims1d2_outl3_dimred(l2) + x2 = self.adapt_stage3_b(x2) + x2 = self.adapt_stage3_b2_joint_varout_dimred(x2) + x2 = [x2[l] + x3[l] for l in range(self.num_parallel)] + x2 = self.relu(x2) + x2 = self.mflow_conv_g3_pool(x2) + x2 = self.mflow_conv_g3_b(x2) + x2 = self.mflow_conv_g3_b3_joint_varout_dimred(x2) + x2 = [nn.Upsample(size=l1[0].size()[2:], mode='bilinear', align_corners=True)(x2_) for x2_ in x2] + + x1 = self.p_ims1d2_outl4_dimred(l1) + x1 = self.adapt_stage4_b(x1) + x1 = self.adapt_stage4_b2_joint_varout_dimred(x1) + x1 = [x1[l] + x2[l] for l in range(self.num_parallel)] + x1 = self.relu(x1) + x1 = self.mflow_conv_g4_pool(x1) + x1 = self.mflow_conv_g4_b(x1) + x1 = self.dropout(x1) + + out = self.clf_conv(x1) + + return out + + +def refinenet(num_layers, num_classes, num_parallel, bn_threshold): + refinnetClass = RefineNet + if int(num_layers) == 18: + layers = [2, 2, 2, 2] + block = BasicBlock + refinnetClass = RefineNet_Resnet18 + elif int(num_layers) == 50: + layers = [3, 4, 6, 3] + block = Bottleneck + elif int(num_layers) == 101: + layers = [3, 4, 23, 3] + block = Bottleneck + elif int(num_layers) == 152: + layers = [3, 8, 36, 3] + block = Bottleneck + else: + print('invalid num_layers') + + model = refinnetClass(block, layers, num_parallel, num_classes, bn_threshold) + return model + +def maybe_download(model_name, model_url, model_dir=None, map_location=None): + if model_dir is None: + torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) + model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) + if not os.path.exists(model_dir): + os.makedirs(model_dir) + filename = '{}.pth.tar'.format(model_name) + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + url = model_url + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + moves.urllib.request.urlretrieve(url, cached_file) + return torch.load(cached_file, map_location=map_location) + +def model_init(model, num_layers, num_parallel, imagenet=False): + if imagenet: + key = str(num_layers) + '_imagenet' + url = models_urls[key] + state_dict = maybe_download(key, url) + model_dict = expand_model_dict(model.state_dict(), state_dict, num_parallel) + model.load_state_dict(model_dict, strict=True) + return model + +def expand_model_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + bn = '.bn_%d' % i + replace = True if bn in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(bn, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict + +def get_params(model): + enc_params, dec_params, slim_params = [], [], [] + for name, param in model.named_parameters(): + if 
bool(re.match('.*conv1.*|.*bn1.*|.*layer.*', name)): + enc_params.append(param) + # if args.print_network: + # print(' Enc. parameter: {}'.format(name)) + else: + dec_params.append(param) + # if args.print_network: + # print(' Dec. parameter: {}'.format(name)) + if param.requires_grad and name.endswith('weight') and 'bn2' in name: + if len(slim_params) % 2 == 0: + slim_params.append(param[:len(param) // 2]) + else: + slim_params.append(param[len(param) // 2:]) + return enc_params, dec_params, slim_params + +def get_sup_loss_from_output(criterion, outputs, target): + target = target.squeeze(3) + loss = 0 + for output in outputs: + output = nn.functional.interpolate(output, size=target.size()[1:], + mode='bilinear', align_corners=False) + # soft_output = torch.nn.functional.log_softmax(output, dim = 1) + # Compute loss and backpropagate + # print(soft_output.shape, target.shape, "Shapes") + # print(soft_output[0, :, 0, 0], target[0, 0, 0], "values") + # loss += criterion(soft_output, target) + # print(output[0, :, 0, 0], target[0, 0, 0], "values") + loss += criterion(output, target) + return loss/len(outputs) + +def L1_penalty(var): + return torch.abs(var).sum()#.to(var.get_device()) \ No newline at end of file diff --git a/models/segmentation_models/segformer/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/segformer/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e51e19c91138564a0a74db9a74665de66ac3308f Binary files /dev/null and b/models/segmentation_models/segformer/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/segformer/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/segformer/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa7b716208170af233bd568c2baa27132aee679b Binary files /dev/null and b/models/segmentation_models/segformer/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/segformer/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/segformer/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebbfe99a252e1edb48dd8c60b16deb6d9201288a Binary files /dev/null and b/models/segmentation_models/segformer/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/segformer/mix_transformer.py b/models/segmentation_models/segformer/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0e50496d72c4f618e04b9fb4c133ddbf02e41002 --- /dev/null +++ b/models/segmentation_models/segformer/mix_transformer.py @@ -0,0 +1,453 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
+# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
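+        # Spatial-reduction attention: queries keep the full token resolution, while for sr_ratio > 1
+        # keys/values are computed from a feature map downsampled by a strided conv, reducing the
+        # attention cost by roughly sr_ratio**2. In this single-modality variant only the first
+        # parallel stream's keys/values are used (see `k, v = [kv[0][0]], [kv[0][1]]` in forward).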
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + # k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + k, v = [kv[0][0]], [kv[0][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +# class PredictorLG(nn.Module): +# """ Image to Patch Embedding from DydamicVit +# """ +# def __init__(self, embed_dim=384): +# super().__init__() +# self.score_nets = nn.ModuleList([nn.Sequential( +# nn.LayerNorm(embed_dim), +# nn.Linear(embed_dim, embed_dim), +# nn.GELU(), +# nn.Linear(embed_dim, embed_dim // 2), +# nn.GELU(), +# nn.Linear(embed_dim // 2, embed_dim // 4), +# nn.GELU(), +# nn.Linear(embed_dim // 4, 2), +# nn.LogSoftmax(dim=-1) +# ) for _ in range(num_parallel)]) + +# def forward(self, x): +# x = [self.score_nets[i](x[i]) for i in range(num_parallel)] +# return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, 
in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, 
global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x[0].shape[0] + outs0 = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + x = blk(x, H, W) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + x = blk(x, H, W) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + x = blk(x, H, W) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + x = blk(x, H, W) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + return [outs0] + + def forward(self, x): + x = self.forward_features(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b0, self).__init__( + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b1, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b2, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b3, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b4, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b5, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) diff --git 
a/models/segmentation_models/segformer/modules.py b/models/segmentation_models/segformer/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..c365cad5d30db1bb13df519700987470585b7ef1 --- /dev/null +++ b/models/segmentation_models/segformer/modules.py @@ -0,0 +1,29 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. +# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 1 + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/segformer/segformer.py b/models/segmentation_models/segformer/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4795bbc5bef3a79606d8fe441b140029327c419d --- /dev/null +++ b/models/segmentation_models/segformer/segformer.py @@ -0,0 +1,166 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class SegFormer(nn.Module): + def __init__(self, backbone, config, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)() + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + def get_params(self): #TODO: Check why norm gets a 0 
weight_decay and how that affects + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + b, c, h, w = data[0].shape #rgb is the 0th element + data = self.encoder(data) + pred = self.decoder(data[0]) #single modality model + pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True) + if not self.training: #Return the ensemble predictions for evaluation + return pred + else: # training + if get_sup_loss: + sup_loss = self.get_sup_loss(pred, gt, criterion) + return pred, sup_loss + else: + return pred + + def get_sup_loss(self, pred, gt, criterion): + pred = pred[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. + return criterion(pred, gt) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict + + +if __name__=="__main__": + # import torch.distributed as dist + # dist.init_process_group('gloo', init_method='file:///temp/somefile', rank=0, world_size=1) + pretrained_weights = torch.load('pretrained/mit_b1.pth') + wetr = SegFormer('mit_b1', num_classes=20, embedding_dim=256, pretrained=True).cuda() + wetr.get_param_groupsv() + dummy_input = torch.rand(2,3,512,512).cuda() + wetr(dummy_input) \ No newline at end of file diff --git a/models/segmentation_models/tokenfusion/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/tokenfusion/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91e360687224564b686b11d9128ecdd09c135b87 Binary files /dev/null and b/models/segmentation_models/tokenfusion/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusion/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/tokenfusion/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/tokenfusion/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusion/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/tokenfusion/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4498362672452edac73e9203ad7be18d2ce85fec Binary files /dev/null and 
b/models/segmentation_models/tokenfusion/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusion/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/tokenfusion/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8491c15c837477ee83cbe4a9972a7702176a9b13 Binary files /dev/null and b/models/segmentation_models/tokenfusion/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusion/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/tokenfusion/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/tokenfusion/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusion/mix_transformer.py b/models/segmentation_models/tokenfusion/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..85d8012e1a297d7bb2f4351dd3109bf79722708e --- /dev/null +++ b/models/segmentation_models/tokenfusion/mix_transformer.py @@ -0,0 +1,475 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, TokenExchange + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
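+        # Same spatial-reduction attention as in the plain SegFormer backbone, but run over two
+        # parallel modality streams. After the output projection, TokenExchange replaces tokens whose
+        # predictor score falls below the 0.02 threshold with the corresponding tokens from the
+        # other stream (see forward).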
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + if mask is not None: + x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x, mask, mask_threshold=0.02) + + return x + + +class PredictorLG(nn.Module): + """ Image to Patch Embedding from DydamicVit + """ + def __init__(self, embed_dim=384): + super().__init__() + self.score_nets = nn.ModuleList([nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, embed_dim), + nn.GELU(), + nn.Linear(embed_dim, embed_dim // 2), + nn.GELU(), + nn.Linear(embed_dim // 2, embed_dim // 4), + nn.GELU(), + nn.Linear(embed_dim // 4, 2), + nn.LogSoftmax(dim=-1) + ) for _ in range(num_parallel)]) + + def forward(self, x): + x = [self.score_nets[i](x[i]) for i in range(num_parallel)] + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W, mask)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, 
stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + 
return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x[0].shape[0] + outs0, outs1 = [], [] + masks = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + score = self.score_predictor[0](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + score = self.score_predictor[1](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + score = self.score_predictor[2](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + score = self.score_predictor[3](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + return [outs0, outs1], masks + + def forward(self, x): + x, masks = self.forward_features(x) + return x, masks + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b0, self).__init__( + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b1, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b2, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b3, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 
8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b4, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b5, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) diff --git a/models/segmentation_models/tokenfusion/mix_transformer_analysis.py b/models/segmentation_models/tokenfusion/mix_transformer_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..762baca2e8618192bcf4cf9d2065e157331367ea --- /dev/null +++ b/models/segmentation_models/tokenfusion/mix_transformer_analysis.py @@ -0,0 +1,475 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, TokenExchange + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
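The exchange call at the end of Attention.forward above (mask_threshold=0.02) is where tokens actually cross modalities: wherever a stream's predicted score for a token drops below the threshold, that token is replaced by the other stream's token at the same position. A self-contained sketch of the behaviour of TokenExchange (its definition appears in modules.py later in this diff); tensor sizes are arbitrary:

import torch

def token_exchange(x, mask, mask_threshold=0.02):
    # x: list of two [B, N, C] token tensors; mask: list of two [B, N] keep-scores
    x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1])
    x0[mask[0] >= mask_threshold] = x[0][mask[0] >= mask_threshold]
    x0[mask[0] < mask_threshold] = x[1][mask[0] < mask_threshold]   # low-score RGB tokens take the depth value
    x1[mask[1] >= mask_threshold] = x[1][mask[1] >= mask_threshold]
    x1[mask[1] < mask_threshold] = x[0][mask[1] < mask_threshold]   # low-score depth tokens take the RGB value
    return [x0, x1]

B, N, C = 2, 8, 16
x = [torch.randn(B, N, C), torch.randn(B, N, C)]
mask = [torch.rand(B, N), torch.rand(B, N)]   # keep-scores in [0, 1], as derived from the PredictorLG output
out = token_exchange(x, mask)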
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + if mask is not None: + x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x, mask, mask_threshold=0.02) + + return x + + +class PredictorLG(nn.Module): + """ Image to Patch Embedding from DydamicVit + """ + def __init__(self, embed_dim=384): + super().__init__() + self.score_nets = nn.ModuleList([nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, embed_dim), + nn.GELU(), + nn.Linear(embed_dim, embed_dim // 2), + nn.GELU(), + nn.Linear(embed_dim // 2, embed_dim // 4), + nn.GELU(), + nn.Linear(embed_dim // 4, 2), + nn.LogSoftmax(dim=-1) + ) for _ in range(num_parallel)]) + + def forward(self, x): + x = [self.score_nets[i](x[i]) for i in range(num_parallel)] + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W, mask)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, 
stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + 
return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x[0].shape[0] + outs0, outs1 = [], [] + masks = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + score = self.score_predictor[0](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + score = self.score_predictor[1](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + score = self.score_predictor[2](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + score = self.score_predictor[3](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + return [outs0, outs1], masks + + def forward(self, x): + x, masks = self.forward_features(x) + return x, masks + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b0, self).__init__( + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b1, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b2, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b3, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 
4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b4, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b5, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) diff --git a/models/segmentation_models/tokenfusion/modules.py b/models/segmentation_models/tokenfusion/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..709afc5dfdb1b52136fe4e1750ba6953aa4f71d4 --- /dev/null +++ b/models/segmentation_models/tokenfusion/modules.py @@ -0,0 +1,43 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. +# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class TokenExchange(nn.Module): + def __init__(self): + super(TokenExchange, self).__init__() + + def forward(self, x, mask, mask_threshold): + # x: [B, N, C], mask: [B, N, 1] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0[mask[0] >= mask_threshold] = x[0][mask[0] >= mask_threshold] + x0[mask[0] < mask_threshold] = x[1][mask[0] < mask_threshold] + x1[mask[1] >= mask_threshold] = x[1][mask[1] >= mask_threshold] + x1[mask[1] < mask_threshold] = x[0][mask[1] < mask_threshold] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/tokenfusion/segformer.py b/models/segmentation_models/tokenfusion/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e9981bafd97fcb73e84119bf67824f4d732607 --- /dev/null +++ b/models/segmentation_models/tokenfusion/segformer.py @@ -0,0 +1,192 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class WeTr(nn.Module): + def __init__(self, backbone, config, l1_lambda, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + self.l1_lambda = l1_lambda + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)() + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.alpha = 
nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + b, c, h, w = data[0].shape #rgb is the 0th element + x, masks = self.encoder(data) + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + + if not self.training: + return pred + else: # training + if get_sup_loss: + l1 = self.get_l1_loss(masks, data[0].get_device()) / b + l1_loss = self.l1_lambda * l1 + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss + l1_loss + else: + return pred + + def get_l1_loss(self, masks, device): + L1_loss = 0 + for mask in masks: + L1_loss += sum([L1_penalty(m, device) for m in mask]) + return L1_loss.to(device) + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. + # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict + +def L1_penalty(var, device): + return torch.abs(var).sum().to(device) + + +if __name__=="__main__": + # import torch.distributed as dist + # dist.init_process_group('gloo', init_method='file:///temp/somefile', rank=0, world_size=1) + pretrained_weights = torch.load('pretrained/mit_b1.pth') + wetr = WeTr('mit_b1', num_classes=20, embedding_dim=256, pretrained=True).cuda() + wetr.get_param_groupsv() + dummy_input = torch.rand(2,3,512,512).cuda() + wetr(dummy_input) \ No newline at end of file diff --git a/models/segmentation_models/tokenfusion/segformer_analysis.py b/models/segmentation_models/tokenfusion/segformer_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..9f8947eceed9d9633c32154e2631df3b61bc50ee --- /dev/null +++ b/models/segmentation_models/tokenfusion/segformer_analysis.py @@ -0,0 +1,189 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. 
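In WeTr.forward above, the two per-stream predictions are combined into a third, ensembled output using a learnable softmax weighting self.alpha; the per-stream logits are detached, so the ensemble path only trains the alpha weights. A minimal sketch of that weighting (note that F.softmax is called above without a dim argument, which relies on deprecated behaviour; dim=0 is the intended axis):

import torch
import torch.nn.functional as F

num_parallel = 2
alpha = torch.nn.Parameter(torch.ones(num_parallel))                 # learnable ensemble weights
pred = [torch.randn(2, 20, 64, 64) for _ in range(num_parallel)]     # per-stream logits [B, classes, H, W]

alpha_soft = F.softmax(alpha, dim=0)                                 # explicit dim avoids the deprecation warning
ens = sum(alpha_soft[l] * pred[l].detach() for l in range(num_parallel))
pred.append(ens)                                                     # pred[2] is the ensembled prediction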
+import torch +import torch.nn as nn +import torch.nn.functional as F +from . import mix_transformer_analysis as mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]) + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]) + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]) + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]) + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class WeTr(nn.Module): + def __init__(self, backbone, config, l1_lambda, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + self.l1_lambda = l1_lambda + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)() + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, 
num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + b, c, h, w = data[0].shape #rgb is the 0th element + x, masks = self.encoder(data) + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + if not self.training: #Return the ensemble predictions for evaluation + return pred, masks + + else: # training + if get_sup_loss: + l1_loss = self.l1_lambda * self.get_l1_loss(masks, data[0].get_device()) + sup_loss = self.get_sup_loss(pred, gt, criterion) + return pred, sup_loss + l1_loss + else: + return pred + + def get_l1_loss(self, masks, device): + L1_loss = 0 + for mask in masks: + L1_loss += sum([L1_penalty(m, device) for m in mask]) + return L1_loss.to(device) + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict + +def L1_penalty(var, device): + return torch.abs(var).sum().to(device) + + +if __name__=="__main__": + # import torch.distributed as dist + # dist.init_process_group('gloo', init_method='file:///temp/somefile', rank=0, world_size=1) + pretrained_weights = torch.load('pretrained/mit_b1.pth') + wetr = WeTr('mit_b1', num_classes=20, embedding_dim=256, pretrained=True).cuda() + wetr.get_param_groupsv() + dummy_input = torch.rand(2,3,512,512).cuda() + wetr(dummy_input) \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionbothmask/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/tokenfusionbothmask/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3f1e8c56a865fe574b05aed26450891d8660bd2 Binary files /dev/null and b/models/segmentation_models/tokenfusionbothmask/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionbothmask/__pycache__/mix_transformer_analysis.cpython-36.pyc 
b/models/segmentation_models/tokenfusionbothmask/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/tokenfusionbothmask/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionbothmask/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/tokenfusionbothmask/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d48cc2b0422f694e09e190c5d2e9c16f14a8c1ac Binary files /dev/null and b/models/segmentation_models/tokenfusionbothmask/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionbothmask/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/tokenfusionbothmask/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..194373d76af92bd78e3fafff601b7f2ca9b3db77 Binary files /dev/null and b/models/segmentation_models/tokenfusionbothmask/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionbothmask/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/tokenfusionbothmask/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/tokenfusionbothmask/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionbothmask/mix_transformer.py b/models/segmentation_models/tokenfusionbothmask/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4ced86c7dbe20e9a4728a3689799ca5da99df2 --- /dev/null +++ b/models/segmentation_models/tokenfusionbothmask/mix_transformer.py @@ -0,0 +1,498 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
+# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, TokenExchange + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
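This tokenfusionbothmask copy of the backbone differs from the tokenfusion one above mainly in its patch embedding: OverlapPatchEmbedAndMask, defined further below in this file, overwrites a random masking_ratio fraction of the stage-1 patch tokens in both streams with a single learned mask token, and only for the slice of the batch selected by range_batches_to_mask (the unlabeled images). A self-contained sketch of that masking step, with illustrative shapes:

import torch
import torch.nn as nn

embed_dim, masking_ratio = 64, 0.25
mask_token = nn.Parameter(torch.randn(embed_dim))       # one learned token shared by all masked positions

def mask_with_learnt_mask(x):
    # x: [N, L, D] patch tokens; roughly a masking_ratio fraction is replaced in place
    N, L, D = x.shape
    indices = torch.FloatTensor(N, L).uniform_() <= masking_ratio
    x[indices] = mask_token
    return x

tokens = torch.randn(4, 196, embed_dim)                 # e.g. 14x14 patches per image
masked = mask_with_learnt_mask(tokens.clone())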
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + if mask is not None: + x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x, mask, mask_threshold=0.02) + + return x + + +class PredictorLG(nn.Module): + """ Image to Patch Embedding from DydamicVit + """ + def __init__(self, embed_dim=384): + super().__init__() + self.score_nets = nn.ModuleList([nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, embed_dim), + nn.GELU(), + nn.Linear(embed_dim, embed_dim // 2), + nn.GELU(), + nn.Linear(embed_dim // 2, embed_dim // 4), + nn.GELU(), + nn.Linear(embed_dim // 4, 2), + nn.LogSoftmax(dim=-1) + ) for _ in range(num_parallel)]) + + def forward(self, x): + x = [self.score_nets[i](x[i]) for i in range(num_parallel)] + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W, mask)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbedAndMask(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, masking_ratio = 0.25, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + self.masking_ratio = masking_ratio + self.embed_dim = embed_dim + self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim), requires_grad = True)#None #When training in the SupOnly loop, unused params raise error in DDP. Hence instantiating mask_token only when masked training begins + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def mask_with_learnt_mask(self, x): + # if self.mask_token is None: #When training in the SupOnly loop, unused params raise error in DDP. 
Hence instantiating mask_token only when masked training begins + # self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim, device=x.device), requires_grad = True) + # print(self.mask_token[:10], x.device, "token") + N, L, D = x.shape # batch, length, dim + indicies = torch.FloatTensor(N, L).uniform_() <= self.masking_ratio + x[indicies] = self.mask_token + return x + + def forward(self, x, mask = False, range_batches_to_mask = None): + sum_mask = torch.sum(self.mask_token) + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + if mask: + assert range_batches_to_mask is not None, "expected the range of batches to mask to not mask the labeled images" + # x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_mean(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + x[0][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(x[0][range_batches_to_mask[0]:range_batches_to_mask[1]]) + x[1][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(x[1][range_batches_to_mask[0]:range_batches_to_mask[1]]) + # masking_branch = 1 + # x[masking_branch] = self.mask_with_mean(x[masking_branch]) + # x[masking_branch] = self.mask_with_learnt_mask(x[masking_branch]) + else: + x[0] = x[0] + 0*sum_mask #So that when training with SupOnly (and not using any masking), DDP doesn't raise an error that you have unused parameters. + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, masking_ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, 
drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, mask, range_batches_to_mask): + B = x[0].shape[0] + outs0, outs1 = [], [] + masks = [] + + # stage 1 + # x, H, W = self.patch_embed1(x) + x, H, W = self.patch_embed1(x, mask = mask, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block1): + score = self.score_predictor[0](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x, mask = False) + for i, blk in enumerate(self.block2): + score = self.score_predictor[1](x) + mask = [F.softmax(score_.reshape(B, -1, 2), 
dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x, mask = False) + for i, blk in enumerate(self.block3): + score = self.score_predictor[2](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x, mask = False) + for i, blk in enumerate(self.block4): + score = self.score_predictor[3](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + return [outs0, outs1], masks + + def forward(self, x, mask, range_batches_to_mask): + x, masks = self.forward_features(x, mask, range_batches_to_mask) + return x, masks + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b0, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b1, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b2, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b3, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b4, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, 
masking_ratio, **kwargs): + super(mit_b5, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionbothmask/modules.py b/models/segmentation_models/tokenfusionbothmask/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..709afc5dfdb1b52136fe4e1750ba6953aa4f71d4 --- /dev/null +++ b/models/segmentation_models/tokenfusionbothmask/modules.py @@ -0,0 +1,43 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. +# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class TokenExchange(nn.Module): + def __init__(self): + super(TokenExchange, self).__init__() + + def forward(self, x, mask, mask_threshold): + # x: [B, N, C], mask: [B, N, 1] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0[mask[0] >= mask_threshold] = x[0][mask[0] >= mask_threshold] + x0[mask[0] < mask_threshold] = x[1][mask[0] < mask_threshold] + x1[mask[1] >= mask_threshold] = x[1][mask[1] >= mask_threshold] + x1[mask[1] < mask_threshold] = x[0][mask[1] < mask_threshold] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionbothmask/segformer.py b/models/segmentation_models/tokenfusionbothmask/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9e925d52440c16400b4159f0a5ee5a2a1396594a --- /dev/null +++ b/models/segmentation_models/tokenfusionbothmask/segformer.py @@ -0,0 +1,184 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class TokenFusionBothMask(nn.Module): + def __init__(self, backbone, config, l1_lambda, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + self.l1_lambda = l1_lambda + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)(masking_ratio = config.masking_ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, 
num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None, mask = False, range_batches_to_mask = None): + b, c, h, w = data[0].shape #rgb is the 0th element + x, exchange_masks = self.encoder(data, mask = mask, range_batches_to_mask = range_batches_to_mask) + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + + if not self.training: + return pred + else: # training + if get_sup_loss: + l1 = self.get_l1_loss(exchange_masks, data[0].get_device()) / b + l1_loss = self.l1_lambda * l1 + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss + l1_loss + else: + return pred + + def get_l1_loss(self, masks, device): + L1_loss = 0 + for mask in masks: + L1_loss += sum([L1_penalty(m, device) for m in mask]) + return L1_loss.to(device) + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. + # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict + +def L1_penalty(var, device): + return torch.abs(var).sum().to(device) + + +if __name__=="__main__": + # import torch.distributed as dist + # dist.init_process_group('gloo', init_method='file:///temp/somefile', rank=0, world_size=1) + pretrained_weights = torch.load('pretrained/mit_b1.pth') + wetr = TokenFusionBothMask('mit_b1', num_classes=20, embedding_dim=256, pretrained=True).cuda() + wetr.get_param_groupsv() + dummy_input = torch.rand(2,3,512,512).cuda() + wetr(dummy_input) \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96bf225f9eb95d75c8d67bf543c38b79a4ecaf4e Binary files /dev/null and b/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/mix_transformer.cpython-36.pyc differ diff --git 
a/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53b15122f13bbcd5727e982e506e83c11f054465 Binary files /dev/null and b/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d21436a856cee8ad71e23719e4d8f3fdab946aa0 Binary files /dev/null and b/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/tokenfusionmaemaskedconsistency/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaemaskedconsistency/mix_transformer.py b/models/segmentation_models/tokenfusionmaemaskedconsistency/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..42320bf8c277595a2615714deb3f4d7f880ce57a --- /dev/null +++ b/models/segmentation_models/tokenfusionmaemaskedconsistency/mix_transformer.py @@ -0,0 +1,498 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
+# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, TokenExchange + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
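+ # Note: q, kv, proj and the dropouts are wrapped in ModuleParallel, so every op in
+ # this attention runs once per modality branch on the list of parallel token tensors.
+ # When sr_ratio > 1, keys/values are first spatially reduced with a strided conv
+ # (self.sr) before the attention matmul, i.e. SegFormer-style efficient attention.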
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + if mask is not None: + x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x, mask, mask_threshold=0.02) + + return x + + +class PredictorLG(nn.Module): + """ Image to Patch Embedding from DydamicVit + """ + def __init__(self, embed_dim=384): + super().__init__() + self.score_nets = nn.ModuleList([nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, embed_dim), + nn.GELU(), + nn.Linear(embed_dim, embed_dim // 2), + nn.GELU(), + nn.Linear(embed_dim // 2, embed_dim // 4), + nn.GELU(), + nn.Linear(embed_dim // 4, 2), + nn.LogSoftmax(dim=-1) + ) for _ in range(num_parallel)]) + + def forward(self, x): + x = [self.score_nets[i](x[i]) for i in range(num_parallel)] + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W, mask)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbedAndMask(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, masking_ratio = 0.25, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + self.masking_ratio = masking_ratio + self.embed_dim = embed_dim + self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim), requires_grad = True)#None #When training in the SupOnly loop, unused params raise error in DDP. Hence instantiating mask_token only when masked training begins + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def mask_with_learnt_mask(self, x): + # if self.mask_token is None: #When training in the SupOnly loop, unused params raise error in DDP. 
Hence instantiating mask_token only when masked training begins + # self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim, device=x.device), requires_grad = True) + # print(self.mask_token[:10], x.device, "token") + N, L, D = x.shape # batch, length, dim + indicies = torch.FloatTensor(N, L).uniform_() <= self.masking_ratio + x[indicies] = self.mask_token + return x + + def forward(self, x, masking_branch = -1, range_batches_to_mask = None): + assert masking_branch < num_parallel and masking_branch >= -1 + sum_mask = torch.sum(self.mask_token) + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + if not masking_branch == -1: + assert range_batches_to_mask is not None, "expected the range of batches to mask to not mask the labeled images" + # x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_mean(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + # masking_branch = 1 + # x[masking_branch] = self.mask_with_mean(x[masking_branch]) + # x[masking_branch] = self.mask_with_learnt_mask(x[masking_branch]) + else: + x[0] = x[0] + 0*sum_mask #So that when training with SupOnly (and not using any masking), DDP doesn't raise an error that you have unused parameters. + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, masking_ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], 
norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, masking_branch, range_batches_to_mask): + B = x[0].shape[0] + outs0, outs1 = [], [] + masks = [] + + # stage 1 + # x, H, W = self.patch_embed1(x) + x, H, W = self.patch_embed1(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block1): + score = self.score_predictor[0](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x, masking_branch = -1) + for i, blk in enumerate(self.block2): + score = self.score_predictor[1](x) + mask = [F.softmax(score_.reshape(B, -1, 
2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x, masking_branch = -1) + for i, blk in enumerate(self.block3): + score = self.score_predictor[2](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x, masking_branch = -1) + for i, blk in enumerate(self.block4): + score = self.score_predictor[3](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + return [outs0, outs1], masks + + def forward(self, x, masking_branch, range_batches_to_mask): + x, masks = self.forward_features(x, masking_branch, range_batches_to_mask) + return x, masks + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b0, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b1, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b2, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b3, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b4, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class 
mit_b5(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b5, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionmaemaskedconsistency/modules.py b/models/segmentation_models/tokenfusionmaemaskedconsistency/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..709afc5dfdb1b52136fe4e1750ba6953aa4f71d4 --- /dev/null +++ b/models/segmentation_models/tokenfusionmaemaskedconsistency/modules.py @@ -0,0 +1,43 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. +# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class TokenExchange(nn.Module): + def __init__(self): + super(TokenExchange, self).__init__() + + def forward(self, x, mask, mask_threshold): + # x: [B, N, C], mask: [B, N, 1] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0[mask[0] >= mask_threshold] = x[0][mask[0] >= mask_threshold] + x0[mask[0] < mask_threshold] = x[1][mask[0] < mask_threshold] + x1[mask[1] >= mask_threshold] = x[1][mask[1] >= mask_threshold] + x1[mask[1] < mask_threshold] = x[0][mask[1] < mask_threshold] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionmaemaskedconsistency/segformer.py b/models/segmentation_models/tokenfusionmaemaskedconsistency/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b10f4c31915cdb6c7f57aed1a8a73d673cb993c3 --- /dev/null +++ b/models/segmentation_models/tokenfusionmaemaskedconsistency/segformer.py @@ -0,0 +1,269 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class SegFormerReconstructionHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, **kwargs): + super(SegFormerReconstructionHead, self).__init__() + self.in_channels = in_channels + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + 
in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, 3, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + return x + + +class TokenFusionMAEMaskedConsistency(nn.Module): + def __init__(self, backbone, config, l1_lambda, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + self.l1_lambda = l1_lambda + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)(masking_ratio = config.masking_ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + self.decoder_reconstruct_rgb = SegFormerReconstructionHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.decoder_reconstruct_depth = SegFormerReconstructionHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None, mask = False, range_batches_to_mask = None): + b, c, h, w = data[0].shape #rgb is the 0th element + if not mask: + masking_branch = -1 + else: + masking_branch = int((torch.rand(1)<0.5)*1) + # masking_branch = 1 + x, exchange_masks = 
self.encoder(data, masking_branch, range_batches_to_mask) + if self.training: + #Reconstruction branch + reconstruction_criterion = nn.MSELoss() + + encoder_output_0 = [x[0][i][range_batches_to_mask[0]:range_batches_to_mask[1]] for i in range(len(x[0]))] + pred_reconstruct_0 = self.decoder_reconstruct_rgb(encoder_output_0) + pred_reconstruct_0 = F.interpolate(pred_reconstruct_0, size=(h, w), mode='bilinear', align_corners=True) + encoder_output_1 = [x[1][i][range_batches_to_mask[0]:range_batches_to_mask[1]] for i in range(len(x[1]))] + pred_reconstruct_1 = self.decoder_reconstruct_depth(encoder_output_1) + pred_reconstruct_1 = F.interpolate(pred_reconstruct_1, size=(h, w), mode='bilinear', align_corners=True) + reconstruction_loss = (1 - masking_branch) * reconstruction_criterion(pred_reconstruct_0, data[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + reconstruction_loss += masking_branch * reconstruction_criterion(pred_reconstruct_1, data[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + + if not self.training: + return pred + else: # training + if get_sup_loss: + l1 = self.get_l1_loss(exchange_masks, data[0].get_device()) / b + l1_loss = self.l1_lambda * l1 + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss + l1_loss, reconstruction_loss, masking_branch + else: + return pred, reconstruction_loss, masking_branch + + def get_l1_loss(self, masks, device): + L1_loss = 0 + for mask in masks: + L1_loss += sum([L1_penalty(m, device) for m in mask]) + return L1_loss.to(device) + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
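+ # Note: pred holds the two branch predictions plus the alpha-weighted ensemble
+ # appended in forward(), so this loop averages the criterion over all three outputs.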
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict + +def L1_penalty(var, device): + return torch.abs(var).sum().to(device) + + +if __name__=="__main__": + # import torch.distributed as dist + # dist.init_process_group('gloo', init_method='file:///temp/somefile', rank=0, world_size=1) + pretrained_weights = torch.load('pretrained/mit_b1.pth') + wetr = TokenFusionMAEMaskedConsistency('mit_b1', num_classes=20, embedding_dim=256, pretrained=True).cuda() + wetr.get_param_groupsv() + dummy_input = torch.rand(2,3,512,512).cuda() + wetr(dummy_input) \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10c9e34b2471b883a92ba5b4c08377c13e8adda5 Binary files /dev/null and b/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85577a0e04f3d48892404be33cd1167f2a2edfc4 Binary files /dev/null and b/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cc9b828d00df989bb13769f3ae2d52c5143318e Binary files /dev/null and b/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/tokenfusionmaskedconsistency/__pycache__/segformer_analysis.cpython-36.pyc differ 
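A minimal usage sketch for the TokenFusionMAEMaskedConsistency forward pass above, assuming the package is importable from the repo root; `cfg` is a hypothetical stand-in for the training config, and pretrained=False skips the backbone checkpoint load:

import torch
from types import SimpleNamespace
from models.segmentation_models.tokenfusionmaemaskedconsistency.segformer import TokenFusionMAEMaskedConsistency

# Hypothetical config object: only masking_ratio is read here (root_dir is used only when pretrained=True).
cfg = SimpleNamespace(masking_ratio=0.25, root_dir='.')
model = TokenFusionMAEMaskedConsistency('mit_b1', cfg, l1_lambda=1e-4, num_classes=20, pretrained=False)
model.train()

rgb, depth = torch.rand(4, 3, 64, 64), torch.rand(4, 3, 64, 64)  # batch of 4; last 2 treated as unlabeled
# mask=True masks a randomly chosen branch for samples 2..3 only and, in training mode,
# also returns the MAE-style reconstruction loss and the index of the masked branch.
pred, rec_loss, masked_branch = model([rgb, depth], mask=True, range_batches_to_mask=[2, 4])
# Passing get_sup_loss=True together with gt and a criterion additionally returns the supervised + L1 term.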
diff --git a/models/segmentation_models/tokenfusionmaskedconsistency/mix_transformer.py b/models/segmentation_models/tokenfusionmaskedconsistency/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..42320bf8c277595a2615714deb3f4d7f880ce57a --- /dev/null +++ b/models/segmentation_models/tokenfusionmaskedconsistency/mix_transformer.py @@ -0,0 +1,498 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, TokenExchange + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
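+ # The attention scale falls back to head_dim ** -0.5 when qk_scale is None; self.kv
+ # projects keys and values jointly (output width dim * 2) and is split apart in forward().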
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + if mask is not None: + x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x, mask, mask_threshold=0.02) + + return x + + +class PredictorLG(nn.Module): + """ Image to Patch Embedding from DydamicVit + """ + def __init__(self, embed_dim=384): + super().__init__() + self.score_nets = nn.ModuleList([nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, embed_dim), + nn.GELU(), + nn.Linear(embed_dim, embed_dim // 2), + nn.GELU(), + nn.Linear(embed_dim // 2, embed_dim // 4), + nn.GELU(), + nn.Linear(embed_dim // 4, 2), + nn.LogSoftmax(dim=-1) + ) for _ in range(num_parallel)]) + + def forward(self, x): + x = [self.score_nets[i](x[i]) for i in range(num_parallel)] + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W, mask)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbedAndMask(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, masking_ratio = 0.25, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + self.masking_ratio = masking_ratio + self.embed_dim = embed_dim + self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim), requires_grad = True)#None #When training in the SupOnly loop, unused params raise error in DDP. Hence instantiating mask_token only when masked training begins + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def mask_with_learnt_mask(self, x): + # if self.mask_token is None: #When training in the SupOnly loop, unused params raise error in DDP. 
Hence instantiating mask_token only when masked training begins + # self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim, device=x.device), requires_grad = True) + # print(self.mask_token[:10], x.device, "token") + N, L, D = x.shape # batch, length, dim + indicies = torch.FloatTensor(N, L).uniform_() <= self.masking_ratio + x[indicies] = self.mask_token + return x + + def forward(self, x, masking_branch = -1, range_batches_to_mask = None): + assert masking_branch < num_parallel and masking_branch >= -1 + sum_mask = torch.sum(self.mask_token) + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + if not masking_branch == -1: + assert range_batches_to_mask is not None, "expected the range of batches to mask to not mask the labeled images" + # x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_mean(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(x[masking_branch][range_batches_to_mask[0]:range_batches_to_mask[1]]) + # masking_branch = 1 + # x[masking_branch] = self.mask_with_mean(x[masking_branch]) + # x[masking_branch] = self.mask_with_learnt_mask(x[masking_branch]) + else: + x[0] = x[0] + 0*sum_mask #So that when training with SupOnly (and not using any masking), DDP doesn't raise an error that you have unused parameters. + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, masking_ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], 
norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, masking_branch, range_batches_to_mask): + B = x[0].shape[0] + outs0, outs1 = [], [] + masks = [] + + # stage 1 + # x, H, W = self.patch_embed1(x) + x, H, W = self.patch_embed1(x, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block1): + score = self.score_predictor[0](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x, masking_branch = -1) + for i, blk in enumerate(self.block2): + score = self.score_predictor[1](x) + mask = [F.softmax(score_.reshape(B, -1, 
2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x, masking_branch = -1) + for i, blk in enumerate(self.block3): + score = self.score_predictor[2](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x, masking_branch = -1) + for i, blk in enumerate(self.block4): + score = self.score_predictor[3](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + return [outs0, outs1], masks + + def forward(self, x, masking_branch, range_batches_to_mask): + x, masks = self.forward_features(x, masking_branch, range_batches_to_mask) + return x, masks + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b0, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b1, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b2, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b3, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b4, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class 
mit_b5(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b5, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionmaskedconsistency/modules.py b/models/segmentation_models/tokenfusionmaskedconsistency/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..709afc5dfdb1b52136fe4e1750ba6953aa4f71d4 --- /dev/null +++ b/models/segmentation_models/tokenfusionmaskedconsistency/modules.py @@ -0,0 +1,43 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. +# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class TokenExchange(nn.Module): + def __init__(self): + super(TokenExchange, self).__init__() + + def forward(self, x, mask, mask_threshold): + # x: [B, N, C], mask: [B, N, 1] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0[mask[0] >= mask_threshold] = x[0][mask[0] >= mask_threshold] + x0[mask[0] < mask_threshold] = x[1][mask[0] < mask_threshold] + x1[mask[1] >= mask_threshold] = x[1][mask[1] >= mask_threshold] + x1[mask[1] < mask_threshold] = x[0][mask[1] < mask_threshold] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionmaskedconsistency/segformer.py b/models/segmentation_models/tokenfusionmaskedconsistency/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5768432c0227a7d78b8ebf697044402a8419fade --- /dev/null +++ b/models/segmentation_models/tokenfusionmaskedconsistency/segformer.py @@ -0,0 +1,197 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class TokenFusionMaskedConsistency(nn.Module): + def __init__(self, backbone, config, l1_lambda, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + self.l1_lambda = l1_lambda + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)(masking_ratio = config.masking_ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + 
embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None, mask = False, range_batches_to_mask = None): + b, c, h, w = data[0].shape #rgb is the 0th element + if not mask: + masking_branch = -1 + else: + # masking_branch = int((torch.rand(1)<0.5)*1) + masking_branch = 1 + x, exchange_masks = self.encoder(data, masking_branch, range_batches_to_mask) + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + + if not self.training: + return pred + else: # training + if get_sup_loss: + l1 = self.get_l1_loss(exchange_masks, data[0].get_device()) / b + l1_loss = self.l1_lambda * l1 + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss + l1_loss, masking_branch + else: + return pred, masking_branch + + def get_l1_loss(self, masks, device): + L1_loss = 0 + for mask in masks: + L1_loss += sum([L1_penalty(m, device) for m in mask]) + return L1_loss.to(device) + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
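# --- Editor's sketch (not part of the original diff; sizes are illustrative only) ---
# How the supervised loss above behaves on a mixed batch: predictions are sliced to
# gt.shape[0], so only the labeled prefix of the batch contributes and the unlabeled
# tail is ignored.
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
num_classes, h, w = 20, 8, 8
pred = torch.randn(4, num_classes, h, w)       # 4 images: 2 labeled + 2 unlabeled
gt = torch.randint(0, num_classes, (2, h, w))  # ground truth exists for the first 2 only

loss = criterion(pred[: gt.shape[0]], gt)      # unlabeled samples carry no sup loss
print(float(loss))
# -------------------------------------------------------------------------------------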
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict + +def L1_penalty(var, device): + return torch.abs(var).sum().to(device) + + +if __name__=="__main__": + # import torch.distributed as dist + # dist.init_process_group('gloo', init_method='file:///temp/somefile', rank=0, world_size=1) + pretrained_weights = torch.load('pretrained/mit_b1.pth') + wetr = TokenFusionMaskedConsistency('mit_b1', num_classes=20, embedding_dim=256, pretrained=True).cuda() + wetr.get_param_groupsv() + dummy_input = torch.rand(2,3,512,512).cuda() + wetr(dummy_input) \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39c19de5c92f7f66c6e9c3881f1f51780353425a Binary files /dev/null and b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1524872c5b1ba507be8d97f8071a0a5283166ab Binary files /dev/null and b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a96feb02e16ca3a85354049958e87b4b38b416c Binary files /dev/null and b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and 
b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/mix_transformer.py b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d5221ff13af5d76061eb72d1918c6f901b12151a --- /dev/null +++ b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/mix_transformer.py @@ -0,0 +1,502 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel, TokenExchange + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W), self.dwconv(x[1], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
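# --- Editor's sketch (not part of the original diff; sizes are illustrative only) ---
# The sr_ratio > 1 branch that follows implements SegFormer-style spatial-reduction
# attention: keys/values are computed on an sr x sr strided conv of the token grid, so
# the attention map is [B, heads, N, N/sr^2] rather than [B, heads, N, N].
# Shape walk-through:
import torch
import torch.nn as nn

B, H, W, C, heads, sr = 2, 32, 32, 64, 1, 8
N = H * W
x = torch.randn(B, N, C)

sr_conv = nn.Conv2d(C, C, kernel_size=sr, stride=sr)
x_2d = x.permute(0, 2, 1).reshape(B, C, H, W)
x_red = sr_conv(x_2d).reshape(B, C, -1).permute(0, 2, 1)   # [B, N/sr^2, C]

q = torch.randn(B, heads, N, C // heads)
k = x_red.reshape(B, -1, heads, C // heads).permute(0, 2, 1, 3)
attn = q @ k.transpose(-2, -1)                             # [B, heads, N, N/sr^2]
print(attn.shape)  # torch.Size([2, 1, 1024, 16])
# -------------------------------------------------------------------------------------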
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + if mask is not None: + x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + x = self.exchange(x, mask, mask_threshold=0.02) + + return x + + +class PredictorLG(nn.Module): + """ Image to Patch Embedding from DydamicVit + """ + def __init__(self, embed_dim=384): + super().__init__() + self.score_nets = nn.ModuleList([nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, embed_dim), + nn.GELU(), + nn.Linear(embed_dim, embed_dim // 2), + nn.GELU(), + nn.Linear(embed_dim // 2, embed_dim // 4), + nn.GELU(), + nn.Linear(embed_dim // 4, 2), + nn.LogSoftmax(dim=-1) + ) for _ in range(num_parallel)]) + + def forward(self, x): + x = [self.score_nets[i](x[i]) for i in range(num_parallel)] + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, mask=None): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W, mask)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbedAndMask(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, masking_ratio = 0.25, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + self.masking_ratio = masking_ratio + self.embed_dim = embed_dim + self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim), requires_grad = True)#None #When training in the SupOnly loop, unused params raise error in DDP. Hence instantiating mask_token only when masked training begins + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def mask_with_learnt_mask(self, x, masking_branch): + # if self.mask_token is None: #When training in the SupOnly loop, unused params raise error in DDP. 
Hence instantiating mask_token only when masked training begins + # self.mask_token = nn.parameter.Parameter(torch.randn(self.embed_dim, device=x.device), requires_grad = True) + # print(self.mask_token[:10], x.device, "token") + _, N, L, D = x.shape # modality, batch, length, dim + N = torch.sum( torch.tensor(masking_branch) != -1) + indicies = torch.FloatTensor(N, L).uniform_() <= self.masking_ratio + # x[indicies] = self.mask_token + masking_branch = torch.tensor(masking_branch).to(x.device) + index = torch.stack([torch.tensor(masking_branch == 0), torch.tensor(masking_branch == 1)]).to(x.device) + xtemp = x[index] + xtemp[indicies] = self.mask_token + x[index] = xtemp + return x + + + def forward(self, x, mask, masking_branch = None, range_batches_to_mask = None): + sum_mask = torch.sum(self.mask_token) + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + if mask: + assert masking_branch is not None and range_batches_to_mask is not None, "expected the range of batches to mask to not mask the labeled images" + xstacked = torch.stack(x) + xstacked[:, range_batches_to_mask[0]:range_batches_to_mask[1]] = self.mask_with_learnt_mask(xstacked[:, range_batches_to_mask[0]:range_batches_to_mask[1]], masking_branch) + x = [xstacked[i] for i in range(xstacked.shape[0])] + else: + x[0] = x[0] + 0*sum_mask #So that when training with SupOnly (and not using any masking), DDP doesn't raise an error that you have unused parameters. + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, masking_ratio, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbedAndMask(masking_ratio = masking_ratio, img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, 
drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, mask, masking_branch, range_batches_to_mask): + B = x[0].shape[0] + outs0, outs1 = [], [] + masks = [] + + # stage 1 + # x, H, W = self.patch_embed1(x) + x, H, W = self.patch_embed1(x, mask = mask, masking_branch = masking_branch, range_batches_to_mask = range_batches_to_mask) + for i, blk in enumerate(self.block1): + score = self.score_predictor[0](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 2 + x, H, W = self.patch_embed2(x, mask = False) + for i, blk in enumerate(self.block2): + score = self.score_predictor[1](x) + 
mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 3 + x, H, W = self.patch_embed3(x, mask = False) + for i, blk in enumerate(self.block3): + score = self.score_predictor[2](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + # stage 4 + x, H, W = self.patch_embed4(x, mask = False) + for i, blk in enumerate(self.block4): + score = self.score_predictor[3](x) + mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + masks.append(mask) + x = blk(x, H, W, mask) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + outs1.append(x[1]) + + return [outs0, outs1], masks + + def forward(self, x, mask, masking_branch = None, range_batches_to_mask = None): + x, masks = self.forward_features(x, mask, masking_branch, range_batches_to_mask) + return x, masks + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b0, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b1, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b2, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b3, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b4, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, 
drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, masking_ratio, **kwargs): + super(mit_b5, self).__init__(masking_ratio = masking_ratio, + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/modules.py b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..709afc5dfdb1b52136fe4e1750ba6953aa4f71d4 --- /dev/null +++ b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/modules.py @@ -0,0 +1,43 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. +# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 2 + + +class TokenExchange(nn.Module): + def __init__(self): + super(TokenExchange, self).__init__() + + def forward(self, x, mask, mask_threshold): + # x: [B, N, C], mask: [B, N, 1] + x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) + x0[mask[0] >= mask_threshold] = x[0][mask[0] >= mask_threshold] + x0[mask[0] < mask_threshold] = x[1][mask[0] < mask_threshold] + x1[mask[1] >= mask_threshold] = x[1][mask[1] >= mask_threshold] + x1[mask[1] < mask_threshold] = x[0][mask[1] < mask_threshold] + return [x0, x1] + + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/segformer.py b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1b37634d109a63bbd485a373e7ca7af891a2a0 --- /dev/null +++ b/models/segmentation_models/tokenfusionmaskedconsistencymixbatch/segformer.py @@ -0,0 +1,193 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class TokenFusionMaskedConsistencyMixBatch(nn.Module): + def __init__(self, backbone, config, l1_lambda, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + self.l1_lambda = l1_lambda + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + + self.encoder = getattr(mix_transformer, backbone)(masking_ratio = config.masking_ratio) + self.in_channels = self.encoder.embed_dims + ## initilize encoder + if pretrained: + state_dict = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict.pop('head.weight') + state_dict.pop('head.bias') + state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) + self.encoder.load_state_dict(state_dict, strict=True) + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + 
embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + self.register_parameter('alpha', self.alpha) + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None, mask = False, range_batches_to_mask = None): + b, c, h, w = data[0].shape #rgb is the 0th element + masking_branch = None + if mask: + # masking_branch = [int((torch.rand(1)<0.5)*1) for _ in range(range_batches_to_mask[0], range_batches_to_mask[1])] + n = range_batches_to_mask[1] - range_batches_to_mask[0] + masking_branch = [random.choice([0, 1, -1]) for _ in range(n)] + + # mask = True + # masking_branch = [0 for _ in range(data[0].shape[0])] + # range_batches_to_mask = [0, data[0].shape[0]] + + x, exchange_masks = self.encoder(data, mask, masking_branch, range_batches_to_mask) + pred = [self.decoder(x[0]), self.decoder(x[1])] + ens = 0 + alpha_soft = F.softmax(self.alpha) + for l in range(self.num_parallel): + ens += alpha_soft[l] * pred[l].detach() + pred.append(ens) + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + + if not self.training: + return pred + else: # training + if get_sup_loss: + l1 = self.get_l1_loss(exchange_masks, data[0].get_device()) / b + l1_loss = self.l1_lambda * l1 + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss + l1_loss + else: + return pred + + def get_l1_loss(self, masks, device): + L1_loss = 0 + for mask in masks: + L1_loss += sum([L1_penalty(m, device) for m in mask]) + return L1_loss.to(device) + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
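# --- Editor's sketch (not part of the original diff; names/sizes are illustrative) ---
# The MixBatch variant above draws an independent masking decision per unlabeled sample:
# 0 masks the RGB stream, 1 masks the depth stream, -1 masks neither. This mirrors how
# mask_with_learnt_mask turns that list into a boolean index over the stacked
# [modality, batch, ...] token tensor.
import random
import torch

batch = 6
masking_branch = [random.choice([0, 1, -1]) for _ in range(batch)]
branch = torch.tensor(masking_branch)

# index[m, b] is True when modality m of sample b should be masked
index = torch.stack([branch == 0, branch == 1])   # [2, batch]
tokens = torch.randn(2, batch, 16, 8)             # [modality, B, N, C]
print(masking_branch, "->", tokens[index].shape)  # selected (modality, sample) rows
# -------------------------------------------------------------------------------------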
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict + +def L1_penalty(var, device): + return torch.abs(var).sum().to(device) \ No newline at end of file diff --git a/models/segmentation_models/unifiedrepresentation/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/unifiedrepresentation/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4584d150dd9778789ff380fdd50f2516fccef4a1 Binary files /dev/null and b/models/segmentation_models/unifiedrepresentation/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/unifiedrepresentation/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/unifiedrepresentation/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/unifiedrepresentation/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/unifiedrepresentation/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/unifiedrepresentation/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59c8cc41f53e72d0ebc2d118420f06a2effccca8 Binary files /dev/null and b/models/segmentation_models/unifiedrepresentation/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/unifiedrepresentation/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/unifiedrepresentation/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73677b2d7c382a9a4e48d35addf306e1534585ae Binary files /dev/null and b/models/segmentation_models/unifiedrepresentation/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/unifiedrepresentation/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/unifiedrepresentation/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/unifiedrepresentation/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/unifiedrepresentation/mix_transformer.py b/models/segmentation_models/unifiedrepresentation/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a2ab5f361b349abf8b399441791c747d3f2a35 --- /dev/null +++ b/models/segmentation_models/unifiedrepresentation/mix_transformer.py @@ -0,0 +1,463 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. 
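# --- Editor's sketch (not part of the original diff; a toy key mapping only) ---------
# expand_state_dict, defined in the segformer.py files above, maps a pretrained
# single-stream SegFormer checkpoint onto the parallel encoder: keys that match directly
# are copied, and each '.ln_<i>' key of LayerNormParallel is filled from the single
# LayerNorm entry of the checkpoint, so both streams start from the same weights.
import torch

pretrained = {"block1.0.norm1.weight": torch.ones(4)}
model_dict = {
    "block1.0.norm1.ln_0.weight": torch.zeros(4),
    "block1.0.norm1.ln_1.weight": torch.zeros(4),
}

for key in model_dict:
    stripped = key
    for i in range(2):                      # num_parallel = 2
        stripped = stripped.replace(f".ln_{i}", "")
    if stripped in pretrained:
        model_dict[key] = pretrained[stripped]

print(model_dict["block1.0.norm1.ln_0.weight"],
      model_dict["block1.0.norm1.ln_1.weight"])   # both copied from the single LayerNorm
# -------------------------------------------------------------------------------------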
+# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
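+        # Efficient self-attention as in SegFormer: when sr_ratio > 1, keys/values are spatially
+        # reduced by a strided sr_ratio x sr_ratio conv (self.sr below), so attention cost drops
+        # from O(N^2) to roughly O(N^2 / sr_ratio^2) for N patch tokens.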
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + # k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + k, v = [kv[0][0]], [kv[0][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + # x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + + return x + + +# class PredictorLG(nn.Module): +# """ Image to Patch Embedding from DydamicVit +# """ +# def __init__(self, embed_dim=384): +# super().__init__() +# self.score_nets = nn.ModuleList([nn.Sequential( +# nn.LayerNorm(embed_dim), +# nn.Linear(embed_dim, embed_dim), +# nn.GELU(), +# nn.Linear(embed_dim, embed_dim // 2), +# nn.GELU(), +# nn.Linear(embed_dim // 2, embed_dim // 4), +# nn.GELU(), +# nn.Linear(embed_dim // 4, 2), +# nn.LogSoftmax(dim=-1) +# ) for _ in range(num_parallel)]) + +# def forward(self, x): +# x = [self.score_nets[i](x[i]) for i in range(num_parallel)] +# return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, 
in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + # self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + 
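+        # NOTE: the classification head is commented out in __init__ above, so self.head is never set
+        # (and reset_classifier below refers to self.embed_dim, which is also undefined); these
+        # helpers are vestigial in this segmentation-only setup.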
return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x[0].shape[0] + outs0 = [] + # masks = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + # score = self.score_predictor[0](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + # score = self.score_predictor[1](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + # score = self.score_predictor[2](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + # score = self.score_predictor[3](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + return [outs0] + + def forward(self, x): + x = self.forward_features(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b0, self).__init__(patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b1, self).__init__(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b2, self).__init__(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b3, self).__init__(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + 
drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b4, self).__init__(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b5, self).__init__(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) diff --git a/models/segmentation_models/unifiedrepresentation/modules.py b/models/segmentation_models/unifiedrepresentation/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..af08c11a25f9d59d7bd9c5455bf3ee78fa7e404f --- /dev/null +++ b/models/segmentation_models/unifiedrepresentation/modules.py @@ -0,0 +1,28 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. +# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 1 + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/unifiedrepresentation/segformer.py b/models/segmentation_models/unifiedrepresentation/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..be92e49bb68ef09ba1fe3b2aa50dfadb0dd7b130 --- /dev/null +++ b/models/segmentation_models/unifiedrepresentation/segformer.py @@ -0,0 +1,185 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class UnifiedRepresentationNetwork(nn.Module): + def __init__(self, backbone, config, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + self.encoder0 = getattr(mix_transformer, backbone)() + self.encoder1 = getattr(mix_transformer, backbone)() + self.in_channels = self.encoder0.embed_dims + ## initilize encoder + if pretrained: + state_dict0 = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict0.pop('head.weight') + state_dict0.pop('head.bias') + state_dict0 = expand_state_dict(self.encoder0.state_dict(), state_dict0, self.num_parallel) + self.encoder0.load_state_dict(state_dict0, strict=True) + + state_dict1 = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict1.pop('head.weight') + 
state_dict1.pop('head.bias') + state_dict1 = expand_state_dict(self.encoder1.state_dict(), state_dict1, self.num_parallel) + self.encoder1.load_state_dict(state_dict1, strict=True) + print("Load pretrained weights from " + config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + else: + print("Train from scratch") + + self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, + embedding_dim=self.embedding_dim, num_classes=self.num_classes) + + # self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) + # self.register_parameter('alpha', self.alpha) + + def get_params(self): + param_groups = [[], [], []] + for name, param in list(self.encoder0.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for name, param in list(self.encoder1.named_parameters()): + if "norm" in name: + param_groups[1].append(param) + else: + param_groups[0].append(param) + for param in list(self.decoder.parameters()): + param_groups[2].append(param) + return param_groups + + # def get_params(self): + # param_groups = [[], []] + # for param in list(self.encoder.parameters()): + # param_groups[0].append(param) + # for param in list(self.decoder.parameters()): + # param_groups[1].append(param) + # return param_groups + + def forward(self, data, get_sup_loss = False, gt = None, criterion = None): + b, c, h, w = data[0].shape #rgb is the 0th element + x0 = self.encoder0([data[0]])[0] + x1 = self.encoder1([data[1]])[0] + encoded = [] + for enc0, enc1 in zip(x0, x1): + # encoded.append(enc1) + encoded.append((enc0 + enc1) / 2) + + pred = [self.decoder(encoded)] + for i in range(len(pred)): + pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) + + if not self.training: + return pred + else: # training + if get_sup_loss: + sup_loss = self.get_sup_loss(pred, gt, criterion) + # print(sup_loss, l1, l1_loss, sup_loss + l1_loss, "losses") + return pred, sup_loss + else: + return pred + + def get_sup_loss(self, pred, gt, criterion): + sup_loss = 0 + for p in pred: + p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. 
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict diff --git a/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/mix_transformer.cpython-36.pyc b/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/mix_transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f49e5cfecce0dfb1cb5c124bb665bfd98442d935 Binary files /dev/null and b/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/mix_transformer.cpython-36.pyc differ diff --git a/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/mix_transformer_analysis.cpython-36.pyc b/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/mix_transformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0f57fb8a12137d616b9938cb1e3ce5a34427ae Binary files /dev/null and b/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/mix_transformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/modules.cpython-36.pyc b/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1309904c3b9e5223382c33687a9bcb94bba7f4ca Binary files /dev/null and b/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/modules.cpython-36.pyc differ diff --git a/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/segformer.cpython-36.pyc b/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/segformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c7769689073081f6afdcc9af3573212177240bb Binary files /dev/null and b/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/segformer.cpython-36.pyc differ diff --git a/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/segformer_analysis.cpython-36.pyc b/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/segformer_analysis.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ae1d644be2f775d66c2a120c03b001d2aab80f1 Binary files /dev/null and b/models/segmentation_models/unifiedrepresentationmoddrop/__pycache__/segformer_analysis.cpython-36.pyc differ diff --git a/models/segmentation_models/unifiedrepresentationmoddrop/mix_transformer.py b/models/segmentation_models/unifiedrepresentationmoddrop/mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a2ab5f361b349abf8b399441791c747d3f2a35 --- /dev/null +++ b/models/segmentation_models/unifiedrepresentationmoddrop/mix_transformer.py @@ -0,0 +1,463 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. 
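+# NOTE: this file is a byte-identical copy of models/segmentation_models/unifiedrepresentation/mix_transformer.py
+# (same blob hash d7a2ab5f36... in both diff headers); see the notes added there.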
+# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .modules import ModuleParallel, LayerNormParallel, num_parallel + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ModuleParallel(nn.Linear(in_features, hidden_features)) + self.dwconv = DWConv(hidden_features) + self.act = ModuleParallel(act_layer()) + self.fc2 = ModuleParallel(nn.Linear(hidden_features, out_features)) + self.drop = ModuleParallel(nn.Dropout(drop)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = [self.dwconv(x[0], H, W)] + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = ModuleParallel(nn.Linear(dim, dim, bias=qkv_bias)) + self.kv = ModuleParallel(nn.Linear(dim, dim * 2, bias=qkv_bias)) + self.attn_drop = ModuleParallel(nn.Dropout(attn_drop)) + self.proj = ModuleParallel(nn.Linear(dim, dim)) + self.proj_drop = ModuleParallel(nn.Dropout(proj_drop)) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = ModuleParallel(nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)) + self.norm = LayerNormParallel(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x[0].shape + q = self.q(x) + q = [q_.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) for q_ in q] + + if self.sr_ratio > 1: + x = [x_.permute(0, 2, 1).reshape(B, C, H, W) for x_ in x] + x = self.sr(x) + x = [x_.reshape(B, C, -1).permute(0, 2, 1) for x_ in x] + x = self.norm(x) + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + else: + kv = self.kv(x) + kv = [kv_.reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) for kv_ in kv] + # k, v = [kv[0][0], kv[1][0]], [kv[0][1], kv[1][1]] + k, v = [kv[0][0]], [kv[0][1]] + + attn = [(q_ @ k_.transpose(-2, -1)) * self.scale for (q_, k_) in zip(q, k)] + attn = [attn_.softmax(dim=-1) for attn_ in attn] + attn = self.attn_drop(attn) + + x = [(attn_ @ v_).transpose(1, 2).reshape(B, N, C) for (attn_, v_) in zip(attn, v)] + x = self.proj(x) + x = self.proj_drop(x) + + # x = [x_ * mask_.unsqueeze(2) for (x_, mask_) in zip(x, mask)] + + return x + + +# class PredictorLG(nn.Module): +# """ Image to Patch Embedding from DydamicVit +# """ +# def __init__(self, embed_dim=384): +# super().__init__() +# self.score_nets = nn.ModuleList([nn.Sequential( +# nn.LayerNorm(embed_dim), +# nn.Linear(embed_dim, embed_dim), +# nn.GELU(), +# nn.Linear(embed_dim, embed_dim // 2), +# nn.GELU(), +# nn.Linear(embed_dim // 2, embed_dim // 4), +# nn.GELU(), +# nn.Linear(embed_dim // 4, 2), +# nn.LogSoftmax(dim=-1) +# ) for _ in range(num_parallel)]) + +# def forward(self, x): +# x = [self.score_nets[i](x[i]) for i in range(num_parallel)] +# return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormParallel, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + # self.score = PredictorLG(dim) + + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = ModuleParallel(DropPath(drop_path)) if drop_path > 0. 
else ModuleParallel(nn.Identity()) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.exchange = TokenExchange() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B = x[0].shape[0] + # norm1 = self.norm1(x) + # score = self.score(norm1) + # mask = [F.gumbel_softmax(score_.reshape(B, -1, 2), hard=True)[:, :, 0] for score_ in score] + # if mask is not None: + # norm = [norm_ * mask_.unsqueeze(2) for (norm_, mask_) in zip(norm, mask)] + f = self.drop_path(self.attn(self.norm1(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + f = self.drop_path(self.mlp(self.norm2(x), H, W)) + x = [x_ + f_ for (x_, f_) in zip (x, f)] + # if mask is not None: + # x = self.exchange(x, mask, mask_threshold=0.02) + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = ModuleParallel(nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2))) + self.norm = LayerNormParallel(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x[0].shape + x = [x_.flatten(2).transpose(1, 2) for x_ in x] + x = self.norm(x) + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNormParallel, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, 
in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # predictor_list = [PredictorLG(embed_dims[i]) for i in range(len(depths))] + # self.score_predictor = nn.ModuleList(predictor_list) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + ''' + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + ''' + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + 
return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x[0].shape[0] + outs0 = [] + # masks = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + # score = self.score_predictor[0](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm1(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + # score = self.score_predictor[1](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm2(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + # score = self.score_predictor[2](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm3(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + # score = self.score_predictor[3](x) + # mask = [F.softmax(score_.reshape(B, -1, 2), dim=2)[:, :, 0] for score_ in score] # mask_: [B, N] + # masks.append(mask) + x = blk(x, H, W) + x = self.norm4(x) + x = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for x_ in x] + outs0.append(x[0]) + + return [outs0] + + def forward(self, x): + x = self.forward_features(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2).contiguous() + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b0, self).__init__(patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b1, self).__init__(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b2, self).__init__(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b3, self).__init__(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + 
drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b4(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b4, self).__init__(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b5(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b5, self).__init__(patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=LayerNormParallel, depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) diff --git a/models/segmentation_models/unifiedrepresentationmoddrop/modules.py b/models/segmentation_models/unifiedrepresentationmoddrop/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..af08c11a25f9d59d7bd9c5455bf3ee78fa7e404f --- /dev/null +++ b/models/segmentation_models/unifiedrepresentationmoddrop/modules.py @@ -0,0 +1,28 @@ +#Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License. +# +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details. + +import torch.nn as nn +import torch + +num_parallel = 1 + +class ModuleParallel(nn.Module): + def __init__(self, module): + super(ModuleParallel, self).__init__() + self.module = module + + def forward(self, x_parallel): + return [self.module(x) for x in x_parallel] + + +class LayerNormParallel(nn.Module): + def __init__(self, num_features): + super(LayerNormParallel, self).__init__() + for i in range(num_parallel): + setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) + + def forward(self, x_parallel): + return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] \ No newline at end of file diff --git a/models/segmentation_models/unifiedrepresentationmoddrop/segformer.py b/models/segmentation_models/unifiedrepresentationmoddrop/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..956443a3b4034942cfa8ee8af5e2bffd24296719 --- /dev/null +++ b/models/segmentation_models/unifiedrepresentationmoddrop/segformer.py @@ -0,0 +1,199 @@ +# 2022.06.08-Changed for implementation of TokenFusion +# Huawei Technologies Co., Ltd. +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from . 
import mix_transformer +from mmcv.cnn import ConvModule +from .modules import num_parallel + + +class MLP(nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.proj(x) + return x + + +class SegFormerHead(nn.Module): + """ + SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers + """ + def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): + super(SegFormerHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + #decoder_params = kwargs['decoder_params'] + #embedding_dim = decoder_params['embed_dim'] + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) + self.dropout = nn.Dropout2d(0.1) + + self.linear_fuse = ConvModule( + in_channels=embedding_dim*4, + out_channels=embedding_dim, + kernel_size=1, + norm_cfg=dict(type='BN', requires_grad=True) + ) + + self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) + + def forward(self, x): + c1, c2, c3, c4 = x + + ############## MLP decoder on C1-C4 ########### + n, _, h, w = c4.shape + + _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]).contiguous() + _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]).contiguous() + _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]).contiguous() + _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + + _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]).contiguous() + + _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) + x = self.dropout(_c) + x = self.linear_pred(x) + + return x + + +class UnifiedRepresentationNetworkModDrop(nn.Module): + def __init__(self, backbone, config, num_classes=20, embedding_dim=256, pretrained=True): + super().__init__() + self.num_classes = num_classes + self.embedding_dim = embedding_dim + self.feature_strides = [4, 8, 16, 32] + self.num_parallel = num_parallel + #self.in_channels = [32, 64, 160, 256] + #self.in_channels = [64, 128, 320, 512] + self.encoder0 = getattr(mix_transformer, backbone)() + self.encoder1 = getattr(mix_transformer, backbone)() + self.in_channels = self.encoder0.embed_dims + ## initilize encoder + if pretrained: + state_dict0 = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict0.pop('head.weight') + state_dict0.pop('head.bias') + state_dict0 = expand_state_dict(self.encoder0.state_dict(), state_dict0, self.num_parallel) + self.encoder0.load_state_dict(state_dict0, strict=True) + + state_dict1 = torch.load(config.root_dir+'/data/pytorch-weight/' + backbone + '.pth') + state_dict1.pop('head.weight') + 
state_dict1.pop('head.bias')
+            state_dict1 = expand_state_dict(self.encoder1.state_dict(), state_dict1, self.num_parallel)
+            self.encoder1.load_state_dict(state_dict1, strict=True)
+            print("Load pretrained weights from " + config.root_dir + '/data/pytorch-weight/' + backbone + '.pth')
+        else:
+            print("Train from scratch")
+
+        self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels,
+                                     embedding_dim=self.embedding_dim, num_classes=self.num_classes)
+
+        # self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True))
+        # self.register_parameter('alpha', self.alpha)
+
+    def get_params(self):
+        param_groups = [[], [], []]
+        for name, param in list(self.encoder0.named_parameters()):
+            if "norm" in name:
+                param_groups[1].append(param)
+            else:
+                param_groups[0].append(param)
+        for name, param in list(self.encoder1.named_parameters()):
+            if "norm" in name:
+                param_groups[1].append(param)
+            else:
+                param_groups[0].append(param)
+        for param in list(self.decoder.parameters()):
+            param_groups[2].append(param)
+        return param_groups
+
+    # def get_params(self):
+    #     param_groups = [[], []]
+    #     for param in list(self.encoder.parameters()):
+    #         param_groups[0].append(param)
+    #     for param in list(self.decoder.parameters()):
+    #         param_groups[1].append(param)
+    #     return param_groups
+
+    def forward(self, data, mask=False, range_batches_to_mask=None, get_sup_loss=False, gt=None, criterion=None):
+        b, c, h, w = data[0].shape  # rgb is the 0th element
+        x0 = self.encoder0([data[0]])[0]
+        x1 = self.encoder1([data[1]])[0]
+        encoded = []
+        if not self.training or not mask:
+            for enc0, enc1 in zip(x0, x1):
+                enc = (enc0 + enc1) / 2
+                # enc = enc1
+                encoded.append(enc)
+        else:
+            # modality dropout currently assumes the whole batch is masked
+            assert range_batches_to_mask[1] == data[0].shape[0] and range_batches_to_mask[0] == 0, \
+                "range_batches_to_mask must cover the full batch, i.e. [0, batch_size]"
+            # masking_branch = torch.tensor([random.choice([0, 1, -1]) for _ in range(data[0].shape[0])]).to(data[0].device)
+            # per-sample choice of which encoder branch to drop
+            masking_branch = torch.tensor([random.choice([0, 1]) for _ in range(data[0].shape[0])]).to(data[0].device)
+            for enc0, enc1 in zip(x0, x1):
+                index0 = masking_branch == 0
+                index1 = masking_branch == 1
+                enc = (enc0 + enc1) / 2
+                enc[index0] = enc1[index0]  # branch 0 (rgb encoder) dropped: fused feature is just enc1
+                enc[index1] = enc0[index1]  # branch 1 dropped: fused feature is just enc0
+                encoded.append(enc)
+
+        pred = [self.decoder(encoded)]
+        for i in range(len(pred)):
+            pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True)
+
+        if not self.training:
+            return pred
+        else:  # training
+            if get_sup_loss:
+                sup_loss = self.get_sup_loss(pred, gt, criterion)
+                # print(sup_loss, "sup_loss")
+                return pred, sup_loss
+            else:
+                return pred
+
+    def get_sup_loss(self, pred, gt, criterion):
+        sup_loss = 0
+        for p in pred:
+            p = p[:gt.shape[0]]  # loss only for the examples in the batch where gt exists; no supervised loss for unlabeled data
+ # soft_output = nn.LogSoftmax()(p) + sup_loss += criterion(p, gt) + return sup_loss / len(pred) + + +def expand_state_dict(model_dict, state_dict, num_parallel): + model_dict_keys = model_dict.keys() + state_dict_keys = state_dict.keys() + for model_dict_key in model_dict_keys: + model_dict_key_re = model_dict_key.replace('module.', '') + if model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + for i in range(num_parallel): + ln = '.ln_%d' % i + replace = True if ln in model_dict_key_re else False + model_dict_key_re = model_dict_key_re.replace(ln, '') + if replace and model_dict_key_re in state_dict_keys: + model_dict[model_dict_key] = state_dict[model_dict_key_re] + return model_dict diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/__pycache__/__init__.cpython-36.pyc b/utils/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7111958154cf9a07a5af853249e6cae2096321b2 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-36.pyc differ diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6de58a2cfa3d9109a2b73b604feee86469ec3a22 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/constants.cpython-36.pyc b/utils/__pycache__/constants.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e66d89a1b45f342a77736e7d0db459835d733340 Binary files /dev/null and b/utils/__pycache__/constants.cpython-36.pyc differ diff --git a/utils/__pycache__/constants.cpython-38.pyc b/utils/__pycache__/constants.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc9f31819d8fd01b48ce6f9653bed30a9422375b Binary files /dev/null and b/utils/__pycache__/constants.cpython-38.pyc differ diff --git a/utils/__pycache__/img_utils.cpython-36.pyc b/utils/__pycache__/img_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb55eaa02d086e1c782c7c4f340308550e7332c3 Binary files /dev/null and b/utils/__pycache__/img_utils.cpython-36.pyc differ diff --git a/utils/__pycache__/img_utils.cpython-38.pyc b/utils/__pycache__/img_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebff051cad7cce043a40c1ee49c54284e56b2db8 Binary files /dev/null and b/utils/__pycache__/img_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/init_utils.cpython-36.pyc b/utils/__pycache__/init_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83f977c3bc03e4afa296a4fb64c0cc2e4cd5ce0e Binary files /dev/null and b/utils/__pycache__/init_utils.cpython-36.pyc differ diff --git a/utils/__pycache__/lr_policy.cpython-36.pyc b/utils/__pycache__/lr_policy.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53a3cce1773fde29d5e283a10e7b6e77fbc5b927 Binary files /dev/null and b/utils/__pycache__/lr_policy.cpython-36.pyc differ diff --git a/utils/__pycache__/pyt_utils.cpython-36.pyc b/utils/__pycache__/pyt_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abff2de6ec0d19f8debf84773ed48681d78110ef Binary files /dev/null and b/utils/__pycache__/pyt_utils.cpython-36.pyc differ diff --git a/utils/constants.py 
b/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..2549f963fbdb3f814fe3939affc03190a300b54e --- /dev/null +++ b/utils/constants.py @@ -0,0 +1,5 @@ +import numpy as np + +class Constants: + pytorch_mean = np.array([0.485, 0.456, 0.406]) + pytorch_std = np.array([0.229, 0.224, 0.225]) \ No newline at end of file diff --git a/utils/img_utils.py b/utils/img_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf771bde1a641b5c287e73a37285f8286f475db --- /dev/null +++ b/utils/img_utils.py @@ -0,0 +1,164 @@ +import random +import cv2 +import collections +import numpy as np + +def random_mirror(imgs): + outputs = {} + if random.random() > 0.5: + for key, img in imgs.items(): + if img is not None: + outputs[key] = cv2.flip(img, 1) + else: + outputs[key] = None + else: + outputs = imgs + return outputs + +#DOESN'T HANDLE GT CUTOUT (VALUE SHOULD BE 255 FOR GT) +# def cutout(imgs, imgsize, keys_to_cutout, cutoutsize = 50): +# outputs = {} +# h0 = random.randrange(imgsize[0] - cutoutsize) +# w0 = random.randrange(imgsize[1] - cutoutsize) +# for key, img in imgs.items(): +# if key in keys_to_cutout: +# if img is not None: +# avg = np.mean(img, axis = (0, 1)) +# img[h0:h0+cutoutsize, w0:w0 + cutoutsize] = avg +# outputs[key] = img +# else: +# outputs[key] = None +# return outputs + + +def random_scale(imgs, scale_array, orig_size): + scale = random.choice(scale_array) + sh = int(orig_size[0] * scale) + sw = int(orig_size[1] * scale) + outputs = {} + for key, img in imgs.items(): + if img is not None: + if key == 'rgb': + outputs[key] = resizergb(img, (sw, sh)) + elif key == 'depth': + outputs[key] = resizedepth(img, (sw, sh)) + elif key == 'gt': + outputs[key] = resizegt(img, (sw, sh)) + else: + raise Exception(key, "not supported in random_scale") + else: + outputs[key] = None + return outputs, scale + + +def get_2dshape(shape, *, zero=True): + if not isinstance(shape, collections.Iterable): + shape = int(shape) + shape = (shape, shape) + else: + h, w = map(int, shape) + shape = (h, w) + if zero: + minv = 0 + else: + minv = 1 + + assert min(shape) >= minv, 'invalid shape: {}'.format(shape) + return shape + + +def generate_random_crop_pos(ori_size, crop_size): + ori_size = get_2dshape(ori_size) + h, w = ori_size + + crop_size = get_2dshape(crop_size) + crop_h, crop_w = crop_size + + pos_h, pos_w = 0, 0 + + if h > crop_h: + pos_h = random.randint(0, h - crop_h + 1) + + if w > crop_w: + pos_w = random.randint(0, w - crop_w + 1) + + return pos_h, pos_w + + +def pad_image_to_shape(img, shape, border_mode, value): + # print("enter pad image", img.shape, np.mean(img[:, :, 3]), np.max(img[:, :, 3]), np.mean(img[:, :, 0]), np.max(img[:, :, 0])) + margin = np.zeros(4, np.uint32) + shape = get_2dshape(shape) + pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0 + pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0 + + margin[0] = pad_height // 2 + margin[1] = pad_height // 2 + pad_height % 2 + margin[2] = pad_width // 2 + margin[3] = pad_width // 2 + pad_width % 2 + + img = cv2.copyMakeBorder(img, margin[0], margin[1], margin[2], margin[3], + border_mode, value=value) + + return img, margin + + +def random_crop_pad_to_shape(imgs, img_size, crop_size): + crop_pos = generate_random_crop_pos(img_size, crop_size) + h, w = img_size + start_crop_h, start_crop_w = crop_pos + assert ((start_crop_h < h) and (start_crop_h >= 0)) + assert ((start_crop_w < w) and (start_crop_w >= 0)) + + crop_size = 
get_2dshape(crop_size) + crop_h, crop_w = crop_size + + outputs = {} + for key, img in imgs.items(): + if img is not None: + img_crop = img[start_crop_h:start_crop_h + crop_h, + start_crop_w:start_crop_w + crop_w, ...] + if key == 'rgb': + pad_label_value = 0 + elif key == 'depth': + pad_label_value = 0 + elif key == 'gt': + pad_label_value = 255 + else: + raise Exception(f"pad_label_value not defined for {key} in random_crop_pad_to_shape") + + img, margin = pad_image_to_shape(img_crop, crop_size, cv2.BORDER_CONSTANT, + pad_label_value) + outputs[key] = img + else: + outputs[key] = None + return outputs, margin + + +def normalize(img, mean, std): + # pytorch pretrained model need the input range: 0-1 + img = img.astype(np.float32) / 255.0 + img = img - mean + img = img / std + + return img + +def normalizedepth(img): + # pytorch pretrained model need the input range: 0-1 + img = img.astype(np.float32) / 255.0 + return img + +def tfnyu_normalizedepth(img): + # pytorch pretrained model need the input range: 0-1 + img = img.astype(np.float32) / 5000. + return img + + +def resizergb(rgb, expectedshape): + return cv2.resize(rgb, expectedshape, interpolation=cv2.INTER_LINEAR) + +def resizedepth(depth, expectedshape): + return cv2.resize(depth, expectedshape, interpolation=cv2.INTER_NEAREST) + +def resizegt(gt, expectedshape): + return cv2.resize(gt, expectedshape, interpolation=cv2.INTER_NEAREST) \ No newline at end of file diff --git a/utils/init_utils.py b/utils/init_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ddf826fdbb832cdc5a3cd6dfa7f41f02fc041920 --- /dev/null +++ b/utils/init_utils.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# @Time : 2018/9/28 δΈ‹εˆ12:13 +# @Author : yuchangqian +# @Contact : changqian_yu@163.com +# @File : init_func.py.py +import math +import warnings +import torch +import torch.nn as nn +from utils.seg_opr.conv_2_5d import Conv2_5D_depth, Conv2_5D_disp + + +def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, + **kwargs): + for name, m in feature.named_modules(): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + conv_init(m.weight, **kwargs) + elif isinstance(m, Conv2_5D_depth): + conv_init(m.weight_0, **kwargs) + conv_init(m.weight_1, **kwargs) + conv_init(m.weight_2, **kwargs) + elif isinstance(m, Conv2_5D_disp): + conv_init(m.weight_0, **kwargs) + conv_init(m.weight_1, **kwargs) + conv_init(m.weight_2, **kwargs) + elif isinstance(m, norm_layer): + m.eps = bn_eps + m.momentum = bn_momentum + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, + **kwargs): + if isinstance(module_list, list): + for feature in module_list: + __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, + **kwargs) + else: + __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, + **kwargs) + + +def group_weight(weight_group, module, norm_layer, lr): + group_decay = [] + group_no_decay = [] + for m in module.modules(): + if isinstance(m, nn.Linear): + group_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + group_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, Conv2_5D_depth): + group_decay.append(m.weight_0) + group_decay.append(m.weight_1) + group_decay.append(m.weight_2) + if m.bias is not None: + 
group_no_decay.append(m.bias) + elif isinstance(m, Conv2_5D_disp): + group_decay.append(m.weight_0) + group_decay.append(m.weight_1) + group_decay.append(m.weight_2) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, norm_layer) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) \ + or isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm): + if m.weight is not None: + group_no_decay.append(m.weight) + if m.bias is not None: + group_no_decay.append(m.bias) + elif isinstance(m, nn.Parameter): + group_decay.append(m) + elif isinstance(m, nn.Embedding): + group_decay.append(m) + # else: + # print(m, norm_layer) + # print(module.modules) + # print( len(list(module.parameters())) , 'HHHHHHHHHHHHHHHHH', len(group_decay) + len( + # group_no_decay)) + assert len(list(module.parameters())) == len(group_decay) + len( + group_no_decay) + weight_group.append(dict(params=group_decay, lr=lr)) + weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr)) + return weight_group + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/utils/pyt_utils.py b/utils/pyt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6abff37986c3c73c61d73a1cf8e8d1891a8af5f8 --- /dev/null +++ b/utils/pyt_utils.py @@ -0,0 +1,54 @@ +import os +import sys +import time +import random +import argparse +from collections import OrderedDict, defaultdict + +import torch +import torch.utils.model_zoo as model_zoo + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def load_model(model, model_file, is_restore=False): + t_start = time.time() + + if model_file is None: + return model + + if isinstance(model_file, str): + state_dict = torch.load(model_file) + if 'model' in state_dict.keys(): + state_dict = state_dict['model'] + else: + state_dict = model_file + t_ioend = time.time() + + if is_restore: + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = 'module.' + k + new_state_dict[name] = v + state_dict = new_state_dict + + + model.load_state_dict(state_dict, strict=False) + ckpt_keys = set(state_dict.keys()) + own_keys = set(model.state_dict().keys()) + missing_keys = own_keys - ckpt_keys + unexpected_keys = ckpt_keys - own_keys + + del state_dict + t_end = time.time() + + return model + + + \ No newline at end of file diff --git a/utils/seg_opr/__pycache__/conv_2_5d.cpython-36.pyc b/utils/seg_opr/__pycache__/conv_2_5d.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19e4bd0c0737e011e683f53534681a4815c4d330 Binary files /dev/null and b/utils/seg_opr/__pycache__/conv_2_5d.cpython-36.pyc differ diff --git a/utils/seg_opr/__pycache__/loss_func.cpython-36.pyc b/utils/seg_opr/__pycache__/loss_func.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9bcf731b7c11e019524a7a0cae970aeec8cbdac Binary files /dev/null and b/utils/seg_opr/__pycache__/loss_func.cpython-36.pyc differ diff --git a/utils/seg_opr/__pycache__/lovasz_losses.cpython-36.pyc b/utils/seg_opr/__pycache__/lovasz_losses.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed92543a9c39b9d6250d0f057a0c7dbe98949c88 Binary files /dev/null and b/utils/seg_opr/__pycache__/lovasz_losses.cpython-36.pyc differ diff --git a/utils/seg_opr/__pycache__/metrics.cpython-36.pyc b/utils/seg_opr/__pycache__/metrics.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ec6dfef39652ae8d2115c82839ee06b7f8488ca Binary files /dev/null and b/utils/seg_opr/__pycache__/metrics.cpython-36.pyc differ diff --git a/utils/seg_opr/conv_2_5d.py b/utils/seg_opr/conv_2_5d.py new file mode 100644 index 0000000000000000000000000000000000000000..0c85453cc4d63ce994647c2711a37b6dbee28fd3 --- /dev/null +++ b/utils/seg_opr/conv_2_5d.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2019-03-04 20:52 +# @Author : Jingbo Wang +# 
@E-mail : wangjingbo1219@foxmail.com & wangjingbo@megvii.com +# @File : conv_2.5d.py +# @Software: PyCharm +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter + + + +def _ntuple(n): + def parse(x): + if isinstance(x, list) or isinstance(x, tuple): + return x + return tuple([x]*n) + return parse +_pair = _ntuple(2) + + +class Conv2_5D_disp(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, + pixel_size=16): + super(Conv2_5D_disp, self).__init__() + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.kernel_size_prod = self.kernel_size[0] * self.kernel_size[1] + self.stride = stride + self.padding = padding + self.dilation = dilation + self.pixel_size = pixel_size + assert self.kernel_size_prod % 2 == 1 + + self.weight_0 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size)) + self.weight_1 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size)) + self.weight_2 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size)) + if bias: + self.bias = Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + + def forward(self, x, disp, camera_params): + N, C, H, W = x.size(0), x.size(1), x.size(2), x.size(3) + out_H = (H + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1 + out_W = (W + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1 + intrinsic, extrinsic = camera_params['intrinsic'], camera_params['extrinsic'] + + x_col = F.unfold(x, self.kernel_size, dilation=self.dilation, padding=self.padding, + stride=self.stride) # (N, C*kh*kw, out_H*out_W) + x_col = x_col.view(N, C, self.kernel_size_prod, out_H * out_W) + + disp_col = F.unfold(disp, self.kernel_size, dilation=self.dilation, padding=self.padding, + stride=self.stride) # (N, kh*kw, out_H*out_W) + valid_mask = 1 - disp_col.eq(0.).to(torch.float32) + valid_mask *= valid_mask[:, self.kernel_size_prod // 2, :].view(N, 1, out_H * out_W) + disp_col *= valid_mask + depth_col = (extrinsic['baseline'] * intrinsic['fx']).view(N, 1, 1).cuda() / torch.clamp(disp_col, 0.01, 256) + valid_mask = valid_mask.view(N, 1, self.kernel_size_prod, out_H * out_W) + + center_depth = depth_col[:, self.kernel_size_prod // 2, :].view(N, 1, out_H * out_W) + grid_range = self.pixel_size * self.dilation[0] * center_depth / intrinsic['fx'].view(N, 1, 1).cuda() + + mask_0 = torch.abs(depth_col - (center_depth + grid_range)).le(grid_range / 2).view(N, 1, self.kernel_size_prod, + out_H * out_W).to( + torch.float32) + mask_1 = torch.abs(depth_col - (center_depth)).le(grid_range / 2).view(N, 1, self.kernel_size_prod, + out_H * out_W).to(torch.float32) + mask_1 = (mask_1 + 1 - valid_mask).clamp(min=0., max=1.) 
+ mask_2 = torch.abs(depth_col - (center_depth - grid_range)).le(grid_range / 2).view(N, 1, self.kernel_size_prod, + out_H * out_W).to( + torch.float32) + + output = torch.matmul(self.weight_0.view(-1, C * self.kernel_size_prod), + (x_col * mask_0).view(N, C * self.kernel_size_prod, out_H * out_W)) + output += torch.matmul(self.weight_1.view(-1, C * self.kernel_size_prod), + (x_col * mask_1).view(N, C * self.kernel_size_prod, out_H * out_W)) + output += torch.matmul(self.weight_2.view(-1, C * self.kernel_size_prod), + (x_col * mask_2).view(N, C * self.kernel_size_prod, out_H * out_W)) + output = output.view(N, -1, out_H, out_W) + if self.bias: + output += self.bias.view(1, -1, 1, 1) + return output + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.bias is None: + s += ', bias=False' + return s.format(**self.__dict__) + + +class Conv2_5D_depth(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=False, + pixel_size=1, is_graph=False): + super(Conv2_5D_depth, self).__init__() + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.kernel_size_prod = self.kernel_size[0] * self.kernel_size[1] + self.stride = stride + self.padding = padding + self.dilation = dilation + self.pixel_size = pixel_size + assert self.kernel_size_prod % 2 == 1 + self.is_graph = is_graph + if self.is_graph: + self.weight_0 = Parameter(torch.Tensor(out_channels, 1, *kernel_size)) + self.weight_1 = Parameter(torch.Tensor(out_channels, 1, *kernel_size)) + self.weight_2 = Parameter(torch.Tensor(out_channels, 1, *kernel_size)) + else: + self.weight_0 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size)) + self.weight_1 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size)) + self.weight_2 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size)) + if bias: + self.bias = Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + + def forward(self, x, depth, camera_params): + # if self.is_graph: + # weight_0 = self.weight_0.expand(self.out_channels, self.in_channels, *self.kernel_size).contiguous() + # weight_1 = self.weight_1.expand(self.out_channels, self.in_channels, *self.kernel_size).contiguous() + # weight_2 = self.weight_2.expand(self.out_channels, self.in_channels, *self.kernel_size).contiguous() + # else: + # weight_0 = self.weight_0 + # weight_1 = self.weight_1 + # weight_2 = self.weight_2 + N, C, H, W = x.size(0), x.size(1), x.size(2), x.size(3) + out_H = (H + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1 + out_W = (W + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1 + intrinsic = camera_params['intrinsic'] + x_col = F.unfold(x, self.kernel_size, dilation=self.dilation, padding=self.padding, + stride=self.stride) # N*(C*kh*kw)*(out_H*out_W) + x_col = x_col.view(N, C, self.kernel_size_prod, out_H * out_W) + depth_col = F.unfold(depth, self.kernel_size, dilation=self.dilation, padding=self.padding, + stride=self.stride) # N*(kh*kw)*(out_H*out_W) + center_depth = depth_col[:, self.kernel_size_prod // 2, :] + 
#print(depth_col.size()) + center_depth = center_depth.view(N, 1, out_H * out_W) + # grid_range = self.pixel_size * center_depth / (intrinsic['fx'].view(N,1,1) * camera_params['scale'].view(N,1,1)) + grid_range = self.pixel_size * center_depth / intrinsic['fx'].cuda().view(N, 1, 1) + + mask_0 = torch.abs(depth_col - (center_depth + grid_range)).le(grid_range / 2).view(N, 1, self.kernel_size_prod, + out_H * out_W).to( + torch.float32) + mask_1 = torch.abs(depth_col - (center_depth)).le(grid_range / 2).view(N, 1, self.kernel_size_prod, + out_H * out_W).to(torch.float32) + mask_2 = torch.abs(depth_col - (center_depth - grid_range)).le(grid_range / 2).view(N, 1, self.kernel_size_prod, + out_H * out_W).to( + torch.float32) + output = torch.matmul(self.weight_0.view(-1, C * self.kernel_size_prod), + (x_col * mask_0).view(N, C * self.kernel_size_prod, out_H * out_W)) + output += torch.matmul(self.weight_1.view(-1, C * self.kernel_size_prod), + (x_col * mask_1).view(N, C * self.kernel_size_prod, out_H * out_W)) + output += torch.matmul(self.weight_2.view(-1, C * self.kernel_size_prod), + (x_col * mask_2).view(N, C * self.kernel_size_prod, out_H * out_W)) + output = output.view(N, -1, out_H, out_W) + if self.bias: + output += self.bias.view(1, -1, 1, 1) + return output + + + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.bias is None: + s += ', bias=False' + return s.format(**self.__dict__) diff --git a/utils/seg_opr/loss_func.py b/utils/seg_opr/loss_func.py new file mode 100644 index 0000000000000000000000000000000000000000..bec1b509051679641fa27e5188428656b3394989 --- /dev/null +++ b/utils/seg_opr/loss_func.py @@ -0,0 +1,108 @@ +import numpy as np +import scipy.ndimage as nd + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from utils.seg_opr.lovasz_losses import lovasz_softmax + +class JSD(nn.Module): + def __init__(self): + super(JSD, self).__init__() + self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True) + + def forward(self, p: torch.tensor, q: torch.tensor): + p = F.softmax(p, dim=1) + q = F.softmax(q, dim=1) + m = (0.5 * (p + q)).log() + return 0.5 * (self.kl(m, p.log()) + self.kl(m, q.log())) + +class MSE(nn.Module): + def __init__(self): + super(MSE, self).__init__() + self.mse = nn.MSELoss(reduction="mean") + + def forward(self, p: torch.tensor, q: torch.tensor): + p = F.softmax(p, dim=1) + q = F.softmax(q, dim=1) + return self.mse(p, q) + +class ProbOhemCrossEntropy2d(nn.Module): + def __init__(self, ignore_label, reduction='mean', thresh=0.6, min_kept=256, + down_ratio=1, use_weight=False): + super(ProbOhemCrossEntropy2d, self).__init__() + self.ignore_label = ignore_label + self.thresh = float(thresh) + self.min_kept = int(min_kept) + self.down_ratio = down_ratio + if use_weight: + weight = torch.FloatTensor( + [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, + 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, + 1.0865, 1.1529, 1.0507]) + self.criterion = torch.nn.CrossEntropyLoss(reduction=reduction, + weight=weight, + ignore_index=ignore_label) + else: + self.criterion = torch.nn.CrossEntropyLoss(reduction=reduction, + ignore_index=ignore_label) + + def forward(self, pred, target): + b, c, h, w = pred.size() + target = target.view(-1) + valid_mask = target.ne(self.ignore_label) + target = 
target * valid_mask.long() + num_valid = valid_mask.sum() + + prob = F.softmax(pred, dim=1) + prob = (prob.transpose(0, 1)).reshape(c, -1) + + if self.min_kept > num_valid: + print('Labels: {} < {}'.format(num_valid, self.min_kept)) + elif num_valid > 0: + prob = prob.masked_fill_(~valid_mask, 1) + mask_prob = prob[ + target, torch.arange(len(target), dtype=torch.long)] + threshold = self.thresh + if self.min_kept > 0: + index = mask_prob.argsort() + threshold_index = index[min(len(index), self.min_kept) - 1] + if mask_prob[threshold_index] > self.thresh: + threshold = mask_prob[threshold_index] + kept_mask = mask_prob.le(threshold) # ζ¦‚ηŽ‡ε°δΊŽι˜ˆε€Όηš„ζŒ–ε‡Ίζ₯ (The probability is less than the threshold to be dug out) + target = target * kept_mask.long() + valid_mask = valid_mask * kept_mask + # logger.info('Valid Mask: {}'.format(valid_mask.sum())) + + target = target.masked_fill_(~valid_mask, self.ignore_label) + target = target.view(b, h, w) + + return self.criterion(pred, target) + +class FocalLoss(nn.Module): + def __init__(self, gamma=2, alpha=None, ignore_label=255, size_average=True): + super(FocalLoss, self).__init__() + self.gamma = gamma + self.size_average = size_average + self.CE_loss = nn.CrossEntropyLoss(reduce=False, ignore_index=ignore_label, weight=alpha) + + def forward(self, output, target): + logpt = self.CE_loss(output, target) + pt = torch.exp(-logpt) + loss = ((1-pt)**self.gamma) * logpt + if self.size_average: + return loss.mean() + return loss.sum() + +class LovaszSoftmax(nn.Module): + def __init__(self, classes='present', per_image=False, ignore_index=255): + super(LovaszSoftmax, self).__init__() + self.smooth = classes + self.per_image = per_image + self.ignore_index = ignore_index + + def forward(self, output, target): + logits = F.softmax(output, dim=1) + loss = lovasz_softmax(logits, target, ignore=self.ignore_index) + return loss \ No newline at end of file diff --git a/utils/seg_opr/lovasz_losses.py b/utils/seg_opr/lovasz_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9b4269b915896fe0c24330e46de0c778f185a7 --- /dev/null +++ b/utils/seg_opr/lovasz_losses.py @@ -0,0 +1,250 @@ +""" +Lovasz-Softmax and Jaccard hinge loss in PyTorch +Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) +https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py +""" + +from __future__ import print_function, division + +import torch +from torch.autograd import Variable +import torch.nn.functional as F +import numpy as np +try: + from itertools import ifilterfalse +except ImportError: # py3k + from itertools import filterfalse as ifilterfalse + + +def lovasz_grad(gt_sorted): + """ + Computes gradient of the Lovasz extension w.r.t sorted errors + See Alg. 1 in paper + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. 
- intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): + """ + IoU for foreground class + binary: 1 foreground, 0 background + """ + if not per_image: + preds, labels = (preds,), (labels,) + ious = [] + for pred, label in zip(preds, labels): + intersection = ((label == 1) & (pred == 1)).sum() + union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() + if not union: + iou = EMPTY + else: + iou = float(intersection) / float(union) + ious.append(iou) + iou = mean(ious) # mean accross images if per_image + return 100 * iou + + +def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): + """ + Array of IoU for each (non ignored) class + """ + if not per_image: + preds, labels = (preds,), (labels,) + ious = [] + for pred, label in zip(preds, labels): + iou = [] + for i in range(C): + if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) + intersection = ((label == i) & (pred == i)).sum() + union = ((label == i) | ((pred == i) & (label != ignore))).sum() + if not union: + iou.append(EMPTY) + else: + iou.append(float(intersection) / float(union)) + ious.append(iou) + ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image + return 100 * np.array(ious) + + +# --------------------------- BINARY LOSSES --------------------------- + +def lovasz_hinge(logits, labels, per_image=True, ignore=None): + """ + Binary Lovasz hinge loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + per_image: compute the loss per image instead of per batch + ignore: void class id + """ + if per_image: + loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) + for log, lab in zip(logits, labels)) + else: + loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) + return loss + + +def lovasz_hinge_flat(logits, labels): + """ + Binary Lovasz hinge loss + logits: [P] Variable, logits at each prediction (between -\infty and +\infty) + labels: [P] Tensor, binary ground truth labels (0 or 1) + ignore: label to ignore + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. 
- logits * Variable(signs)) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), Variable(grad)) + return loss + + +def flatten_binary_scores(scores, labels, ignore=None): + """ + Flattens predictions in the batch (binary case) + Remove labels equal to 'ignore' + """ + scores = scores.view(-1) + labels = labels.view(-1) + if ignore is None: + return scores, labels + valid = (labels != ignore) + vscores = scores[valid] + vlabels = labels[valid] + return vscores, vlabels + + +class StableBCELoss(torch.nn.modules.Module): + def __init__(self): + super(StableBCELoss, self).__init__() + def forward(self, input, target): + neg_abs = - input.abs() + loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() + return loss.mean() + + +def binary_xloss(logits, labels, ignore=None): + """ + Binary Cross entropy loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + ignore: void class id + """ + logits, labels = flatten_binary_scores(logits, labels, ignore) + loss = StableBCELoss()(logits, Variable(labels.float())) + return loss + + +# --------------------------- MULTICLASS LOSSES --------------------------- + + +def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): + """ + Multi-class Lovasz-Softmax loss + probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). + Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. + labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) + classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + per_image: compute the loss per image instead of per batch + ignore: void class labels + """ + if per_image: + loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) + for prob, lab in zip(probas, labels)) + else: + loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) + return loss + + +def lovasz_softmax_flat(probas, labels, classes='present'): + """ + Multi-class Lovasz-Softmax loss + probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) + labels: [P] Tensor, ground truth labels (between 0 and C - 1) + classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + """ + if probas.numel() == 0: + # only void pixels, the gradients should be 0 + return probas * 0. 
+ C = probas.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes is 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probas[:, 0] + else: + class_pred = probas[:, c] + errors = (Variable(fg) - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) + return mean(losses) + + +def flatten_probas(probas, labels, ignore=None): + """ + Flattens predictions in the batch + """ + if probas.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probas.size() + probas = probas.view(B, 1, H, W) + B, C, H, W = probas.size() + probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C + labels = labels.view(-1) + if ignore is None: + return probas, labels + valid = (labels != ignore) + vprobas = probas[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobas, vlabels + +def xloss(logits, labels, ignore=None): + """ + Cross entropy loss + """ + return F.cross_entropy(logits, Variable(labels), ignore_index=255) + + +# --------------------------- HELPER FUNCTIONS --------------------------- +def isnan(x): + return x != x + + +def mean(l, ignore_nan=False, empty=0): + """ + nanmean compatible with generators. + """ + l = iter(l) + if ignore_nan: + l = ifilterfalse(isnan, l) + try: + n = 1 + acc = next(l) + except StopIteration: + if empty == 'raise': + raise ValueError('Empty mean') + return empty + for n, v in enumerate(l, 2): + acc += v + if n == 1: + return acc + return acc / n \ No newline at end of file diff --git a/utils/seg_opr/metrics.py b/utils/seg_opr/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..6f70df0e821958752356684a5954c0815ebd17ce --- /dev/null +++ b/utils/seg_opr/metrics.py @@ -0,0 +1,89 @@ +# encoding: utf-8 + +import numpy as np + +np.seterr(divide='ignore', invalid='ignore') + + +# voc cityscapes metric +def hist_info(n_cl, pred, gt): + assert (pred.shape == gt.shape) + k = (gt >= 0) & (gt < n_cl) + labeled = np.sum(k) + correct = np.sum((pred[k] == gt[k])) + + return np.bincount(n_cl * gt[k].astype(int) + pred[k].astype(int), + minlength=n_cl ** 2).reshape(n_cl, + n_cl), labeled, correct + + +def compute_score(hist, correct, labeled): + iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) + acc = np.diag(hist) / hist.sum(1) + mean_acc = np.nanmean(acc) + mean_IU = np.nanmean(iu) + mean_IU_no_back = np.nanmean(iu[1:]) + freq = hist.sum(1) / hist.sum() + freq_IU = (iu[freq > 0] * freq[freq > 0]).sum() + mean_pixel_acc = correct / labeled + + return iu, mean_IU, mean_IU_no_back, mean_pixel_acc, mean_acc + + +# ade metric +def meanIoU(area_intersection, area_union): + iou = 1.0 * np.sum(area_intersection, axis=1) / np.sum(area_union, axis=1) + meaniou = np.nanmean(iou) + meaniou_no_back = np.nanmean(iou[1:]) + + return iou, meaniou, meaniou_no_back + + +def intersectionAndUnion(imPred, imLab, numClass): + # Remove classes from unlabeled pixels in gt image. + # We should not penalize detections in unlabeled portions of the image. 
+ imPred = imPred * (imLab >= 0) + + # Compute area intersection: + intersection = imPred * (imPred == imLab) + (area_intersection, _) = np.histogram(intersection, bins=numClass, + range=(1, numClass)) + + # Compute area union: + (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) + (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) + area_union = area_pred + area_lab - area_intersection + + return area_intersection, area_union + + +def mean_pixel_accuracy(pixel_correct, pixel_labeled): + mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / ( + np.spacing(1) + np.sum(pixel_labeled)) + + return mean_pixel_accuracy + + +def pixelAccuracy(imPred, imLab): + # Remove classes from unlabeled pixels in gt image. + # We should not penalize detections in unlabeled portions of the image. + pixel_labeled = np.sum(imLab >= 0) + pixel_correct = np.sum((imPred == imLab) * (imLab >= 0)) + pixel_accuracy = 1.0 * pixel_correct / pixel_labeled + + return pixel_accuracy, pixel_correct, pixel_labeled + +def compute_metrics(results, num_classes): + hist = np.zeros((num_classes, num_classes)) + correct = 0 + labeled = 0 + count = 0 + for d in results: + hist += d['hist'] + correct += d['correct'] + labeled += d['labeled'] + count += d['count'] + + _, mean_IU, _, mean_pixel_acc, mean_acc = compute_score(hist, correct, + labeled) + return mean_IU, mean_pixel_acc, mean_acc \ No newline at end of file
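
Usage sketches for the utilities added above follow; everything in them that does not appear in the diff (shapes, hyperparameters, import paths, toy data) is an assumption. First, a minimal sketch of the key mapping performed by expand_state_dict from the first hunk: a single-branch checkpoint entry fills both the shared parameters and every per-branch `.ln_<i>` copy in the target model. The import path below is hypothetical; use wherever the helper is actually defined.

import torch
from train import expand_state_dict  # hypothetical import path; the helper is defined in the hunk above

# Target model has per-branch copies of some parameters (suffix .ln_0 / .ln_1)
# plus a 'module.' prefix from DataParallel-style wrapping.
model_dict = {
    'module.patch_embed.proj.weight': torch.zeros(3),
    'module.block1.ln_0.weight': torch.zeros(3),
    'module.block1.ln_1.weight': torch.zeros(3),
}
# Single-branch checkpoint without the prefix or the .ln_<i> suffixes.
state_dict = {
    'patch_embed.proj.weight': torch.ones(3),
    'block1.weight': 2 * torch.ones(3),
}

merged = expand_state_dict(model_dict, state_dict, num_parallel=2)
# merged['module.patch_embed.proj.weight']  -> ones (direct match after stripping 'module.')
# merged['module.block1.ln_0.weight']       -> twos (matched after stripping '.ln_0')
# merged['module.block1.ln_1.weight']       -> twos (matched after stripping '.ln_1')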
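
Next, a minimal sketch of how the augmentation helpers in utils/img_utils.py compose for one RGB-D training sample. The synthetic arrays, scale set, and crop size are assumptions; note that get_2dshape relies on collections.Iterable, so this path needs Python < 3.10 (matching the committed 3.6/3.8 bytecode).

import numpy as np
from utils.constants import Constants
from utils.img_utils import (random_mirror, random_scale,
                             random_crop_pad_to_shape, normalize, normalizedepth)

# Stand-in sample: 8-bit RGB, 8-bit depth, and a 13-class label map (all assumptions).
rgb = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
depth = np.random.randint(0, 256, (480, 640), dtype=np.uint8)
gt = np.random.randint(0, 13, (480, 640), dtype=np.uint8)

sample = {'rgb': rgb, 'depth': depth, 'gt': gt}
sample = random_mirror(sample)                                   # horizontal flip with p=0.5
sample, scale = random_scale(sample, [0.75, 1.0, 1.25, 1.5], (480, 640))
sample, margin = random_crop_pad_to_shape(sample, sample['rgb'].shape[:2], (480, 640))

img = normalize(sample['rgb'], Constants.pytorch_mean, Constants.pytorch_std)
dep = normalizedepth(sample['depth'])                            # only rescales to [0, 1]
img = img.transpose(2, 0, 1)                                     # HWC -> CHW before torch.from_numpy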
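
A sketch of the intended pairing of init_weight and group_weight from utils/init_utils.py: Kaiming-initialize the convolutions, set BatchNorm eps/momentum in the same pass, then split parameters into a weight-decayed group and a decay-free group (norm parameters and biases) for the optimizer. The toy network and hyperparameters are assumptions.

import torch
import torch.nn as nn
from utils.init_utils import init_weight, group_weight

# Toy head standing in for a real segmentation network (assumption).
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 13, 1),
)

# Kaiming init for convs; BatchNorm eps/momentum are set during the same traversal.
init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d,
            bn_eps=1e-5, bn_momentum=0.1, mode='fan_in', nonlinearity='relu')

# Two param groups: decayed conv/linear weights vs. decay-free norm params and biases.
param_groups = group_weight([], model, nn.BatchNorm2d, lr=1e-2)
optimizer = torch.optim.SGD(param_groups, lr=1e-2, momentum=0.9, weight_decay=1e-4)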
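
A sketch of a single depth-aware 2.5D convolution from utils/seg_opr/conv_2_5d.py on random tensors. The shapes, depth scale, and focal length are assumptions. Two caveats visible in the code itself: forward moves intrinsic['fx'] to CUDA, so a GPU is required, and the three weight tensors are allocated uninitialized, so they must be initialized (for example via init_weight above) before use. The sketch uses Conv2_5D_depth with its default bias=False, because forward's `if self.bias:` test raises on a multi-element bias tensor as written.

import torch
import torch.nn as nn
from utils.seg_opr.conv_2_5d import Conv2_5D_depth

conv = Conv2_5D_depth(in_channels=16, out_channels=32, kernel_size=3, padding=1).cuda()
for w in (conv.weight_0, conv.weight_1, conv.weight_2):          # uninitialized Parameters
    nn.init.kaiming_normal_(w)

x = torch.randn(2, 16, 60, 80, device='cuda')                    # feature map (N, C, H, W)
depth = 10.0 * torch.rand(2, 1, 60, 80, device='cuda')           # positive depth map, arbitrary scale
camera_params = {'intrinsic': {'fx': torch.full((2,), 519.0)}}   # per-image focal length (assumed value)

out = conv(x, depth, camera_params)                              # -> (2, 32, 60, 80)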
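
A sketch of the OHEM cross-entropy from utils/seg_opr/loss_func.py on toy logits: pixels whose predicted probability for the ground-truth class exceeds the threshold are treated as easy and dropped, at least roughly min_kept of the hardest pixels are retained, and label 255 is ignored throughout. Shapes and hyperparameters are assumptions.

import torch
from utils.seg_opr.loss_func import ProbOhemCrossEntropy2d

num_classes = 13
criterion = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.7, min_kept=5000)

logits = torch.randn(2, num_classes, 60, 80, requires_grad=True)  # raw network output
labels = torch.randint(0, num_classes, (2, 60, 80))               # ground-truth class indices
labels[:, :5, :] = 255                                            # pretend a border strip is unlabeled

loss = criterion(logits, labels)
loss.backward()                                                   # hard-example-mined CE gradient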
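
A sketch of the Lovasz-Softmax loss from utils/seg_opr/lovasz_losses.py. It consumes class probabilities, so the logits go through softmax first (the LovaszSoftmax wrapper in loss_func.py does the same); shapes and the 255 ignore label are assumptions.

import torch
import torch.nn.functional as F
from utils.seg_opr.lovasz_losses import lovasz_softmax

logits = torch.randn(2, 13, 60, 80, requires_grad=True)
labels = torch.randint(0, 13, (2, 60, 80))
labels[:, :5, :] = 255                                  # void pixels are filtered out via ignore=

probas = F.softmax(logits, dim=1)                       # probabilities, not logits
loss = lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=255)
loss.backward()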
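
Finally, a sketch of evaluation with utils/seg_opr/metrics.py: accumulate a per-image confusion matrix with hist_info, then reduce with compute_metrics. The predictions here are random stand-ins and the class count is an assumption.

import numpy as np
from utils.seg_opr.metrics import hist_info, compute_metrics

num_classes = 13
results = []
for _ in range(4):                                      # pretend four validation images
    pred = np.random.randint(0, num_classes, (60, 80))
    gt = np.random.randint(0, num_classes, (60, 80))
    hist, labeled, correct = hist_info(num_classes, pred, gt)
    results.append({'hist': hist, 'labeled': labeled, 'correct': correct, 'count': 1})

mean_IU, mean_pixel_acc, mean_acc = compute_metrics(results, num_classes)
print('mIoU %.4f | pixel acc %.4f | mean acc %.4f' % (mean_IU, mean_pixel_acc, mean_acc))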