File size: 7,270 Bytes
d4ebf73 a805467 d4ebf73 b2d39bc d4ebf73 b2d39bc d4ebf73 b2d39bc d4ebf73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
import os
import random
import pickle as pkl
from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion
from datasets.preprocessors import RGBDValPre
from utils.constants import Constants as C
class Arguments:
    """Lightweight argument namespace handed to the segmentation model constructors.

    Stands in for a parsed command-line namespace; only the attributes set here
    are provided.
    """

    def __init__(self, ratio):
        # masking_ratio is pinned to 1.0 for this demo; only ratio is configurable.
        self.ratio, self.masking_ratio = ratio, 1.0
def _load_cpu_checkpoint(model, checkpoint_path):
    """Load a state dict from *checkpoint_path* onto CPU into *model*, set eval mode, return it.

    NOTE: torch.load unpickles the file, so only point this at trusted,
    locally shipped checkpoints.
    """
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Class-id -> color palette consumed by visualize(). The pickle is a local file
# bundled with the demo; the `with` block closes the handle right after loading
# (the previous bare open() leaked it).
with open('./colors.pkl', 'rb') as _colors_file:
    colors = pkl.load(_colors_file)

args = Arguments(ratio = 0.8)

# Linear-fusion model trained with the vanilla Mean Teacher framework.
mtmodelpath = './checkpoints/sid_1-500_mtteacher.pth'
mtmodel = _load_cpu_checkpoint(
    WeTrLinearFusion("mit_b2", args, num_classes=13, pretrained=False),
    mtmodelpath,
)

# Linear-fusion model trained with the M3L framework (supports modality masking).
m3lmodelpath = './checkpoints/sid_1-500_m3lteacher.pth'
m3lmodel = _load_cpu_checkpoint(
    LinearFusionMaskedConsistencyMixBatch("mit_b2", args, num_classes=13, pretrained=False),
    m3lmodelpath,
)
class MaskStudentTeacher(nn.Module):
    """Student/teacher wrapper whose teacher tracks the student by EMA.

    The teacher's parameters are detached at construction, so they receive no
    gradients. They change only through :meth:`update_teacher_models`
    (exponential moving average of the student) or
    :meth:`copy_student_to_teacher`.

    Args:
        student: module called as ``student(data, mask=..., **kwargs)``.
        teacher: module with the same parameter layout as ``student``.
        ema_alpha: upper bound for the EMA decay factor.
        mode: 'train' (masking arguments are forwarded to the student) or
            'val' (student always runs unmasked).
    """

    def __init__(self, student, teacher, ema_alpha, mode = 'train'):
        super().__init__()
        self.student = student
        self.teacher = self._detach_teacher(teacher)
        self.ema_alpha = ema_alpha
        self.mode = mode

    def forward(self, data, student = True, teacher = True, mask = False, range_batches_to_mask = None, **kwargs):
        """Run the student and/or the teacher on *data*.

        Returns:
            A list containing the student output (if ``student``) followed by
            the teacher output (if ``teacher``).

        Raises:
            ValueError: if the student is requested and ``self.mode`` is
                neither 'train' nor 'val'.
        """
        ret = []
        if student:
            if self.mode == 'train':
                ret.append(self.student(data, mask = mask, range_batches_to_mask = range_batches_to_mask, **kwargs))
            elif self.mode == 'val':
                ret.append(self.student(data, mask = False, **kwargs))
            else:
                # ValueError subclasses Exception, so existing `except Exception` callers still work.
                raise ValueError('Mode not supported')
        if teacher:
            # No loss is ever computed for the teacher; its output is returned
            # alongside the student's as if a loss had also been returned.
            ret.append(self.teacher(data, mask = False, **kwargs))
        return ret

    def _detach_teacher(self, model):
        """Detach every parameter of *model* in place so it takes no gradients; return *model*."""
        for param in model.parameters():
            param.detach_()
        return model

    def update_teacher_models(self, global_step):
        """EMA step: ``teacher = alpha * teacher + (1 - alpha) * student``.

        ``alpha`` ramps up as ``1 - 1 / (global_step + 1)`` and is capped at
        ``self.ema_alpha``, so at step 0 the teacher simply copies the student.
        """
        alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
        for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            # Keyword form of add_; the positional add_(scalar, tensor) overload is deprecated.
            ema_param.data.mul_(alpha).add_(param.data, alpha = 1 - alpha)

    def copy_student_to_teacher(self):
        """Overwrite every teacher parameter with the matching student parameter."""
        for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            ema_param.data.mul_(0).add_(param.data)

    def get_params(self):
        """Return the student's parameter groups; the teacher is EMA-updated, not optimised."""
        return self.student.get_params()
def preprocess_data(rgb, depth, dataset_settings):
    """Apply validation-time RGB-D preprocessing and convert results to float tensors.

    Args:
        rgb: RGB image as a numpy array, or None when the modality is absent.
        depth: depth map as a numpy array (per the author's note: min-max
            normalised and scaled by 255), or None when absent.
        dataset_settings: settings dict forwarded to ``RGBDValPre``.

    Returns:
        ``(rgb, depth)``, each a contiguous float32 tensor, or None when the
        preprocessor yielded None for that modality.
    """
    transform = RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
    rgb_arr, depth_arr = transform(rgb, depth)

    def _as_float_tensor(arr):
        # from_numpy needs a contiguous buffer; .float() yields float32.
        if arr is None:
            return None
        return torch.from_numpy(np.ascontiguousarray(arr)).float()

    return _as_float_tensor(rgb_arr), _as_float_tensor(depth_arr)
def visualize(colors, pred, num_classes, dataset_settings):
    """Turn an integer label map into a color image.

    Args:
        colors: palette indexable by class id; ``colors[i]`` is the color of
            class ``i``. Channel order is reversed on output, so the palette is
            presumably stored BGR — TODO confirm.
        pred: integer label array with axes moved as (1, 2, 0); the caller
            passes shape (1, H, W).
        num_classes: classes ``0 .. num_classes - 1`` are painted; any other
            label stays black (0).
        dataset_settings: dict whose 'orig_height'/'orig_width' fix the
            output size.

    Returns:
        float64 array of shape (orig_height, orig_width, 3): palette values
        divided by 255, with the channel order reversed.
    """
    labels = pred.transpose(1, 2, 0)
    canvas = np.zeros(
        (dataset_settings['orig_height'], dataset_settings['orig_width'], 3)
    )
    for class_id in range(num_classes):
        # The single label channel broadcasts against the 3-channel canvas, so
        # every channel of each matching pixel receives this class's color.
        canvas = np.where(labels == class_id, colors[class_id], canvas)
    # Scale 0-255 palette values down and flip the channel order for display.
    return canvas[:, :, ::-1] / 255.0
def _scores_to_visualization(scores, dataset_settings, num_classes=13):
    """Upsample class scores to the original resolution and colorize the argmax labels.

    Args:
        scores: 4-D score tensor from a segmentation model (presumably
            (1, num_classes, h, w) — TODO confirm).
        dataset_settings: dict providing 'orig_height' and 'orig_width'.
        num_classes: number of classes passed on to ``visualize``.

    Returns:
        The color image produced by ``visualize``.
    """
    size = (dataset_settings["orig_height"], dataset_settings["orig_width"])
    scores = torch.nn.functional.interpolate(scores, size = size, mode = 'bilinear', align_corners = True)
    _, pred = torch.max(scores.detach(), dim=1)
    return visualize(colors, pred.numpy(), num_classes=num_classes, dataset_settings=dataset_settings)


def predict(rgb, depth, check):
    """Gradio callback: segment an RGB-D pair with the selected model(s).

    Args:
        rgb: RGB image as a numpy array, or None if not provided.
        depth: depth image as a numpy array, or None if not provided.
        check: selected model names; may contain "Mean Teacher" and/or "M3L".
            None or empty means no model is selected.

    Returns:
        ``[mean_teacher_vis, m3l_vis, legend_path]``. A visualization is None
        when its model was not selected or no input was given. ``legend_path``
        is always './classcolors.png'.
    """
    ret = [None, None, './classcolors.png']
    # Nothing to compute without a model selection or without any modality.
    if not check or (rgb is None and depth is None):
        return ret

    dataset_settings = {
        'image_height': 540,
        'image_width': 540,
        'orig_height': 540,
        'orig_width': 540,
    }
    rgb, depth = preprocess_data(rgb, depth, dataset_settings)
    # Prepend the batch dimension.
    if rgb is not None:
        rgb = rgb.unsqueeze(dim = 0)
    if depth is not None:
        depth = depth.unsqueeze(dim = 0)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        if "Mean Teacher" in check:
            # Mean Teacher is fed zeros for a missing modality. Keep these
            # separate from rgb/depth: overwriting them would hide the missing
            # modality from the M3L branch below and disable its masking.
            mt_rgb = rgb if rgb is not None else torch.zeros_like(depth)
            mt_depth = depth if depth is not None else torch.zeros_like(mt_rgb)
            ret[0] = _scores_to_visualization(mtmodel([mt_rgb, mt_depth])[2], dataset_settings)
        if "M3L" in check:
            # M3L masks the missing branch instead of zero-filling it:
            # branch 0 corresponds to rgb, branch 1 to depth.
            mask = False
            masking_branch = None
            if rgb is None:
                mask, masking_branch = True, 0
            elif depth is None:
                mask, masking_branch = True, 1
            scores = m3lmodel([rgb, depth], mask = mask, masking_branch = masking_branch)[2]
            ret[1] = _scores_to_visualization(scores, dataset_settings)
    return ret
# Example gallery: every RGB file is paired with the depth map of the same
# name, presented in random order, with both models pre-selected.
imgs = os.listdir('./examples/rgb')
random.shuffle(imgs)
examples = [
    [f'./examples/rgb/{name}', f'./examples/depth/{name}', ["M3L", "Mean Teacher"]]
    for name in imgs
]
# Gradio UI. Components are laid out in the order they are created inside each
# context manager, so the statements below must keep their order.
# NOTE(review): `.style(...)` and `queue(concurrency_count=...)` are Gradio 3.x
# APIs; confirm the pinned gradio version before upgrading.
with gr.Blocks() as demo:
    # Title / description banner.
    with gr.Row():
        gr.Markdown(
            """
<center><h2>M3L</h2></center>
<center>Multi-modal teacher for Masked Modality Learning</center>
<br>
<center>Demo to visualize predictions from the Linear Fusion model trained with the vanilla Mean Teacher and the <a href='https://harshm121.github.io/projects/m3l.html'>M3L</a> framework when trained with 0.2% (98) labels. </center>
"""
        )
    # Inputs: either modality may be left empty; predict() handles a missing one.
    with gr.Row():
        rgbinput = gr.Image(label="RGB Input").style(height=256, width=256)
        depthinput = gr.Image(label="Depth Input").style(height=256, width=256)
    # Model selection; the chosen names are passed to predict() as `check`.
    with gr.Row():
        modelcheck = gr.CheckboxGroup(["Mean Teacher", "M3L"], label="Predictions from", info="Predict using model trained with:")
    with gr.Row():
        submit_btn = gr.Button("Submit")
    # Outputs, in the same order as the list predict() returns.
    with gr.Row():
        mtoutput = gr.Image(label="Mean Teacher Output").style(height=384, width=384)
        m3loutput = gr.Image(label="M3L Output").style(height=384, width=384)
        classnameouptut = gr.Image(label="Classes").style(height=384, width=384)
    # Clicking an example only fills the inputs (no fn/outputs are given here).
    with gr.Row():
        examplesRow = gr.Examples(examples=examples, examples_per_page=10, inputs=[rgbinput, depthinput, modelcheck])
    with gr.Row():
        gr.Markdown(
            """
Read more about [M3L](https://harshm121.github.io/projects/m3l.html)!
"""
        )
    # Wire the Submit button to predict(); it is registered inside the Blocks context.
    submit_btn.click(fn = predict, inputs = [rgbinput, depthinput, modelcheck], outputs = [mtoutput, m3loutput, classnameouptut])
# Queue requests (up to 3 processed concurrently), then start the server.
demo.queue(concurrency_count=3)
demo.launch()
|