|
import torch |
|
import torch.nn as nn |
|
import gradio as gr |
|
import numpy as np |
|
import os |
|
import random |
|
import pickle as pkl |
|
from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch |
|
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion |
|
from datasets.preprocessors import RGBDValPre |
|
from utils.constants import Constants as C |
|
|
|
class Arguments:
    """Plain hyper-parameter holder passed as ``args`` to the model constructors.

    Only ``ratio`` is configurable; ``masking_ratio`` is pinned to 1.0.
    """

    def __init__(self, ratio):
        self.masking_ratio = 1.0  # fixed for this demo
        self.ratio = ratio
|
|
|
# Per-class colour palette consumed by `visualize` (indexed by class id).
# NOTE(review): pickle can execute arbitrary code on load. This is acceptable only
# because colors.pkl is a trusted, bundled asset; never load user-supplied data here.
with open('./colors.pkl', 'rb') as colors_file:
    colors = pkl.load(colors_file)

args = Arguments(ratio = 0.8)


def _load_checkpoint(model, checkpoint_path):
    """Load a state dict from *checkpoint_path* onto CPU into *model* and switch it to eval mode.

    Returns the same *model* object for convenient chaining.
    """
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Baseline: linear-fusion SegFormer (mit_b2 backbone) trained with Mean Teacher.
mtmodel = WeTrLinearFusion("mit_b2", args, num_classes=13, pretrained=False)
mtmodelpath = './checkpoints/sid_1-500_mtteacher.pth'
_load_checkpoint(mtmodel, mtmodelpath)

# Same backbone trained with M3L; this variant accepts mask/masking_branch at call time.
m3lmodel = LinearFusionMaskedConsistencyMixBatch("mit_b2", args, num_classes=13, pretrained=False)
m3lmodelpath = './checkpoints/sid_1-500_m3lteacher.pth'
_load_checkpoint(m3lmodel, m3lmodelpath)
|
|
|
|
|
|
|
class MaskStudentTeacher(nn.Module):
    """Wrap a student network and a teacher network updated as an EMA of the student.

    Args:
        student: trainable network; called with ``mask``/``range_batches_to_mask``
            keyword arguments in 'train' mode.
        teacher: network with the same parameter layout as *student*; its parameters
            are detached and it is only changed via `update_teacher_models` /
            `copy_student_to_teacher`.
        ema_alpha: upper bound on the EMA decay used in `update_teacher_models`.
        mode: 'train' or 'val'; controls how the student is called in `forward`.
    """

    def __init__(self, student, teacher, ema_alpha, mode = 'train'):
        super(MaskStudentTeacher, self).__init__()
        self.student = student
        # The teacher is never optimised directly, so cut its parameters out of autograd.
        self.teacher = self._detach_teacher(teacher)
        self.ema_alpha = ema_alpha
        self.mode = mode

    def forward(self, data, student = True, teacher = True, mask = False, range_batches_to_mask = None, **kwargs):
        """Run the student and/or teacher on *data*.

        Returns:
            list with the student output first (if *student*) and then the teacher
            output (if *teacher*). The teacher is always run with ``mask=False``.

        Raises:
            ValueError: if ``self.mode`` is neither 'train' nor 'val' and *student* is True.
        """
        ret = []
        if student:
            if self.mode == 'train':
                ret.append(self.student(data, mask = mask, range_batches_to_mask = range_batches_to_mask, **kwargs))
            elif self.mode == 'val':
                # Validation never masks the student.
                ret.append(self.student(data, mask = False, **kwargs))
            else:
                raise ValueError('Mode not supported')
        if teacher:
            ret.append(self.teacher(data, mask = False, **kwargs))
        return ret

    def _detach_teacher(self, model):
        """Detach every parameter of *model* in place (no grads) and return *model*."""
        for param in model.parameters():
            param.detach_()
        return model

    def update_teacher_models(self, global_step):
        """EMA update: teacher = alpha * teacher + (1 - alpha) * student.

        alpha ramps up as 0, 1/2, 2/3, ... with *global_step* and is capped at
        ``self.ema_alpha``, so the teacher tracks the student closely early on.
        """
        alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
        for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            # Keyword `alpha=` replaces the deprecated add_(Number, Tensor) overload.
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

    def copy_student_to_teacher(self):
        """Overwrite the teacher's parameters with an exact copy of the student's."""
        for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            # copy_ (rather than mul_(0).add_()) also clears NaN/inf already in the teacher.
            ema_param.data.copy_(param.data)

    def get_params(self):
        """Return the student's parameter groups; the teacher is not optimised."""
        return self.student.get_params()
|
|
|
|
|
def preprocess_data(rgb, depth, dataset_settings):
    """Apply validation-time RGB-D preprocessing and convert each result to a float tensor.

    A fresh ``RGBDValPre`` built from ``C.pytorch_mean`` / ``C.pytorch_std`` and
    *dataset_settings* is applied to the pair. Any output that comes back as ``None``
    is passed through as ``None``; the others become ``torch.float32`` tensors.
    """
    transform = RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
    processed = transform(rgb, depth)

    def _as_float_tensor(array):
        if array is None:
            return None
        # Make a contiguous copy first so torch.from_numpy gets a plain buffer.
        return torch.from_numpy(np.ascontiguousarray(array)).float()

    return tuple(_as_float_tensor(item) for item in processed)
|
|
|
|
|
def visualize(colors, pred, num_classes, dataset_settings):
    """Render a class-index map as a colour image.

    Args:
        colors: indexable palette; ``colors[k]`` is a 3-component colour for class k.
        pred: integer label array whose first axis is moved last before painting
            (predict passes an array of shape (1, H, W)).
        num_classes: classes 0..num_classes-1 are painted; other labels stay black.
        dataset_settings: dict providing 'orig_height' and 'orig_width' of the output.

    Returns:
        float array of shape (orig_height, orig_width, 3): colours divided by 255,
        with the channel order reversed relative to *colors*.
    """
    labels = np.transpose(pred, (1, 2, 0))
    height = dataset_settings['orig_height']
    width = dataset_settings['orig_width']
    canvas = np.zeros((height, width, 3))
    for class_id in range(num_classes):
        canvas = np.where(labels == class_id, colors[class_id], canvas)
    # Scale to [0, 1] and flip the channel axis.
    return (canvas / 255.0)[:, :, ::-1]
|
|
|
def _scores_to_visualization(scores, dataset_settings):
    """Upsample logits to the original resolution and render the per-pixel argmax.

    *scores* must be 4-D (N, C, h, w) because bilinear interpolation requires it;
    the class axis is dim 1.
    """
    scores = torch.nn.functional.interpolate(
        scores,
        size = (dataset_settings["orig_height"], dataset_settings["orig_width"]),
        mode = 'bilinear',
        align_corners = True,
    )
    pred = scores.detach().argmax(dim=1)
    return visualize(colors, pred.cpu().numpy(), num_classes=13, dataset_settings=dataset_settings)


def _fill_missing_modality(rgb, depth):
    """Return (rgb, depth) with a missing (None) input replaced by zeros shaped like the other."""
    if rgb is None:
        rgb = torch.zeros_like(depth)
    if depth is None:
        depth = torch.zeros_like(rgb)
    return rgb, depth


def predict(rgb, depth, check):
    """Gradio callback: segment an RGB and/or depth image with the selected models.

    Args:
        rgb: RGB image from the Gradio input, or None if not provided.
        depth: depth image from the Gradio input, or None if not provided.
        check: selected model names ("Mean Teacher" and/or "M3L"); None is treated
            as an empty selection.

    Returns:
        [mean_teacher_vis, m3l_vis, './classcolors.png']. A visualisation slot is
        None when its model was not selected or when neither image was provided.
    """
    ret = [None, None, './classcolors.png']
    check = check or []
    # With no input at all there is nothing to preprocess or predict.
    if rgb is None and depth is None:
        return ret

    dataset_settings = {}
    dataset_settings['image_height'], dataset_settings['image_width'] = 540, 540
    dataset_settings['orig_height'], dataset_settings['orig_width'] = 540, 540

    rgb, depth = preprocess_data(rgb, depth, dataset_settings)
    # Add the batch dimension.
    if rgb is not None:
        rgb = rgb.unsqueeze(dim = 0)
    if depth is not None:
        depth = depth.unsqueeze(dim = 0)

    # Pure inference: skip building autograd graphs.
    with torch.no_grad():
        if "Mean Teacher" in check:
            # Mean Teacher has no masking support, so a missing modality is fed as zeros.
            # Fresh names keep rgb/depth intact so the M3L branch can still detect None.
            mt_rgb, mt_depth = _fill_missing_modality(rgb, depth)
            scores = mtmodel([mt_rgb, mt_depth])[2]
            ret[0] = _scores_to_visualization(scores, dataset_settings)

        if "M3L" in check:
            mask = False
            masking_branch = None
            if rgb is None:
                mask = True
                masking_branch = 0
            if depth is None:
                mask = True
                masking_branch = 1
            # The masked branch is handed zeros rather than None so the model always
            # receives tensors. NOTE(review): assumes the model ignores the content of
            # the masked branch — confirm against LinearFusionMaskedConsistencyMixBatch.
            m3l_rgb, m3l_depth = _fill_missing_modality(rgb, depth)
            scores = m3lmodel([m3l_rgb, m3l_depth], mask = mask, masking_branch = masking_branch)[2]
            ret[1] = _scores_to_visualization(scores, dataset_settings)

    return ret
|
|
|
# Example gallery: each RGB image in ./examples/rgb that has a same-named depth map in
# ./examples/depth, shown in random order with both models pre-selected. Unpaired or
# non-file entries are skipped so they cannot become broken examples.
_EXAMPLES_RGB_DIR = './examples/rgb'
_EXAMPLES_DEPTH_DIR = './examples/depth'

imgs = [
    name for name in os.listdir(_EXAMPLES_RGB_DIR)
    if os.path.isfile(os.path.join(_EXAMPLES_RGB_DIR, name))
    and os.path.isfile(os.path.join(_EXAMPLES_DEPTH_DIR, name))
]
random.shuffle(imgs)

examples = [
    [os.path.join(_EXAMPLES_RGB_DIR, name), os.path.join(_EXAMPLES_DEPTH_DIR, name), ["M3L", "Mean Teacher"]]
    for name in imgs
]
|
|
|
# Gradio UI: two image inputs, a model selector, and one output image per model plus
# the class-colour legend. Component sizes are passed to the constructors because
# `.style()` is deprecated in Gradio 3.x and removed in 4.x.
with gr.Blocks() as demo:
    with gr.Row():
        rgbinput = gr.Image(label="RGB Input", height=256, width=256)
        depthinput = gr.Image(label="Depth Input", height=256, width=256)
    with gr.Row():
        modelcheck = gr.CheckboxGroup(["Mean Teacher", "M3L"], label="Predictions from", info="Predict using model trained with:")
    with gr.Row():
        submit_btn = gr.Button("Submit")
    with gr.Row():
        mtoutput = gr.Image(label="Mean Teacher Output", height=384, width=384)
        m3loutput = gr.Image(label="M3L Output", height=384, width=384)
        classnameoutput = gr.Image(label="Classes", height=384, width=384)
    with gr.Row():
        examplesRow = gr.Examples(examples=examples, examples_per_page=10, inputs=[rgbinput, depthinput, modelcheck])
    # predict returns [mean_teacher_vis, m3l_vis, legend_path], matching outputs order.
    submit_btn.click(fn = predict, inputs = [rgbinput, depthinput, modelcheck], outputs = [mtoutput, m3loutput, classnameoutput])

demo.launch()
|
|