# M3L / app.py
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
import os
import random
import pickle as pkl
from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion
from datasets.preprocessors import RGBDValPre
from utils.constants import Constants as C
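# Minimal stand-in for the training-time argument namespace; only the fields the
# segmentation models read here (ratio and masking_ratio) are populated.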
class Arguments:
def __init__(self, ratio):
self.ratio = ratio
self.masking_ratio = 1.0
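# Per-class colors for visualization and the two teacher checkpoints
# (Mean Teacher and M3L), loaded on CPU and switched to eval mode for inference.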
with open('./colors.pkl', 'rb') as f:
    colors = pkl.load(f)
args = Arguments(ratio = 0.8)
mtmodel = WeTrLinearFusion("mit_b2", args, num_classes=13, pretrained=False)
mtmodelpath = './checkpoints/sid_1-500_mtteacher.pth'
mtmodel.load_state_dict(torch.load(mtmodelpath, map_location=torch.device('cpu')))
mtmodel.eval()
m3lmodel = LinearFusionMaskedConsistencyMixBatch("mit_b2", args, num_classes=13, pretrained=False)
m3lmodelpath = './checkpoints/sid_1-500_m3lteacher.pth'
m3lmodel.load_state_dict(torch.load(m3lmodelpath, map_location=torch.device('cpu')))
m3lmodel.eval()
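# Student-teacher wrapper used during training: the teacher is a gradient-free copy of the
# student whose weights are updated as an exponential moving average (EMA) of the student's.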
class MaskStudentTeacher(nn.Module):
def __init__(self, student, teacher, ema_alpha, mode = 'train'):
super(MaskStudentTeacher, self).__init__()
self.student = student
self.teacher = teacher
self.teacher = self._detach_teacher(self.teacher)
self.ema_alpha = ema_alpha
self.mode = mode
def forward(self, data, student = True, teacher = True, mask = False, range_batches_to_mask = None, **kwargs):
ret = []
if student:
if self.mode == 'train':
ret.append(self.student(data, mask = mask, range_batches_to_mask = range_batches_to_mask, **kwargs))
elif self.mode == 'val':
ret.append(self.student(data, mask = False, **kwargs))
else:
                raise ValueError(f"Mode '{self.mode}' not supported")
if teacher:
            ret.append(self.teacher(data, mask = False, **kwargs))  # Loss is never computed for the teacher, but its output is returned in the same format as the student's
return ret
def _detach_teacher(self, model):
for param in model.parameters():
param.detach_()
return model
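    # EMA update: teacher = alpha * teacher + (1 - alpha) * student, with alpha
    # ramped from 0 towards ema_alpha over the first training steps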
def update_teacher_models(self, global_step):
alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
return
def copy_student_to_teacher(self):
for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
ema_param.data.mul_(0).add_(param.data)
return
    def get_params(self):
        # Only the student's parameters are optimized; the teacher is updated via EMA
        return self.student.get_params()
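# Applies the validation-time preprocessing and converts the resulting arrays to contiguous float tensors.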
def preprocess_data(rgb, depth, dataset_settings):
    # rgb: np.ndarray in RGB order
    # depth: np.ndarray, min-max normalized and scaled to [0, 255]
preprocess = RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
rgb, depth = preprocess(rgb, depth)
if rgb is not None:
rgb = torch.from_numpy(np.ascontiguousarray(rgb)).float()
if depth is not None:
depth = torch.from_numpy(np.ascontiguousarray(depth)).float()
return rgb, depth
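# Maps per-pixel class indices to their display colors and returns an image in [0, 1];
# the final channel reversal suggests the stored colors are in BGR order.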
def visualize(colors, pred, num_classes, dataset_settings):
pred = pred.transpose(1, 2, 0)
predvis = np.zeros((dataset_settings['orig_height'], dataset_settings['orig_width'], 3))
for i in range(num_classes):
color = colors[i]
predvis = np.where(pred == i, color, predvis)
predvis /= 255.0
predvis = predvis[:,:,::-1]
return predvis
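# Gradio callback: preprocesses the RGB/depth pair and runs whichever models are ticked in the checkbox group.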
def predict(rgb, depth, check):
dataset_settings = {}
dataset_settings['image_height'], dataset_settings['image_width'] = 540, 540
dataset_settings['orig_height'], dataset_settings['orig_width'] = 540,540
rgb, depth = preprocess_data(rgb, depth, dataset_settings)
if rgb is not None:
rgb = rgb.unsqueeze(dim = 0)
if depth is not None:
depth = depth.unsqueeze(dim = 0)
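    # Output slots: [Mean Teacher prediction, M3L prediction, class-color legend image]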
ret = [None, None, './classcolors.png']
if "Mean Teacher" in check:
if rgb is None:
rgb = torch.zeros_like(depth)
if depth is None:
depth = torch.zeros_like(rgb)
scores = mtmodel([rgb, depth])[2]
scores = torch.nn.functional.interpolate(scores, size = (dataset_settings["orig_height"], dataset_settings["orig_width"]), mode = 'bilinear', align_corners = True)
prob = scores.detach()
_, pred = torch.max(prob, dim=1)
pred = pred.numpy()
predvis = visualize(colors, pred, num_classes=13, dataset_settings=dataset_settings)
ret[0] = predvis
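    # M3L handles a missing modality explicitly: the corresponding branch is masked
    # inside the model (masking_branch 0 = RGB, 1 = depth) instead of being fed zeros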
if "M3L" in check:
mask = False
masking_branch = None
if rgb is None:
mask = True
masking_branch = 0
if depth is None:
mask = True
masking_branch = 1
scores = m3lmodel([rgb, depth], mask = mask, masking_branch = masking_branch)[2]
scores = torch.nn.functional.interpolate(scores, size = (dataset_settings["orig_height"], dataset_settings["orig_width"]), mode = 'bilinear', align_corners = True)
prob = scores.detach()
_, pred = torch.max(prob, dim=1)
pred = pred.numpy()
predvis = visualize(colors, pred, num_classes=13, dataset_settings=dataset_settings)
ret[1] = predvis
return ret
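# Example gallery: RGB/depth pairs that share a filename across the two folders, shown in random order.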
imgs = os.listdir('./examples/rgb')
random.shuffle(imgs)
examples = []
for img in imgs:
examples.append([
'./examples/rgb/'+img, './examples/depth/'+img, ["M3L", "Mean Teacher"]
])
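# Demo layout: RGB and depth inputs, a checkbox group to pick the model(s), and
# three outputs (Mean Teacher prediction, M3L prediction, class legend).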
with gr.Blocks() as demo:
with gr.Row():
rgbinput = gr.Image(label="RGB Input").style(height=256, width=256)
depthinput = gr.Image(label="Depth Input").style(height=256, width=256)
with gr.Row():
modelcheck = gr.CheckboxGroup(["Mean Teacher", "M3L"], label="Predictions from", info="Predict using model trained with:")
with gr.Row():
submit_btn = gr.Button("Submit")
with gr.Row():
mtoutput = gr.Image(label="Mean Teacher Output").style(height=384, width=384)
m3loutput = gr.Image(label="M3L Output").style(height=384, width=384)
        classnameoutput = gr.Image(label="Classes").style(height=384, width=384)
with gr.Row():
examplesRow = gr.Examples(examples=examples, examples_per_page=10, inputs=[rgbinput, depthinput, modelcheck])
    submit_btn.click(fn = predict, inputs = [rgbinput, depthinput, modelcheck], outputs = [mtoutput, m3loutput, classnameoutput])
demo.launch()