# HakimAiV2 / inference_utils / inference.py
# scdrand23's picture
# not working version
# 814a594
import torch
import numpy as np
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
#from utils.visualizer import Visualizer
# from detectron2.utils.colormap import random_color
# from detectron2.data import MetadataCatalog
# from detectron2.structures import BitMasks
from modeling.language.loss import vl_similarity
from utilities.constants import BIOMED_CLASSES
#from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
# import cv2
# import os
# import glob
# import subprocess
from PIL import Image
import random
# Preprocessing pipeline applied to every input image before inference:
# a single bicubic resize to a fixed 1024x1024 resolution.
t = [transforms.Resize((1024, 1024), interpolation=Image.BICUBIC)]
transform = transforms.Compose(t)
#metadata = MetadataCatalog.get('coco_2017_train_panoptic')
# Class-name table: 'background' first, then every BIOMED_CLASSES entry with
# the '-other' / '-merged' suffixes stripped, then a trailing "others" bucket.
all_classes = ['background']
all_classes.extend(
    name.replace('-other', '').replace('-merged', '') for name in BIOMED_CLASSES
)
all_classes.append("others")
# colors_list = [(np.array(color['color'])/255).tolist() for color in COCO_CATEGORIES] + [[1, 1, 1]]
# use color list from matplotlib
import matplotlib.colors as mcolors
# Merge matplotlib's Tableau palette with its single-letter base colors
# (a base-color entry wins if a key ever clashes), preserving insertion order.
colors = {**mcolors.TABLEAU_COLORS, **mcolors.BASE_COLORS}
# Keep the first 16 color values as the mask palette. Slice once instead of
# rebuilding the full value list for every index.
# NOTE(review): per the matplotlib docs the two tables hold 10 + 8 distinct
# keys, so the slice always yields 16 entries — confirm if matplotlib changes.
colors_list = list(colors.values())[:16]
from .output_processing import mask_stats, combine_masks
@torch.no_grad()
def interactive_infer_image(model, image, prompts):
    """Run text-grounded segmentation on one image and return mask probabilities.

    The image is resized with the module-level ``transform``, passed through
    ``model.model.evaluate_demo`` with only the grounding task enabled, and
    the predicted masks that best match the text embeddings are resized back
    to the original image resolution.

    Args:
        model: Wrapper object exposing ``model.model`` with ``task_switch``,
            ``evaluate_demo`` and
            ``sem_seg_head.predictor.lang_encoder.logit_scale``.
        image: PIL image. ``image.size`` supplies the output (width, height).
        prompts: Text prompt(s), forwarded unchanged as ``data['text']``.
            NOTE(review): presumably a list of strings — confirm against
            ``evaluate_demo``.

    Returns:
        ``numpy.ndarray`` of sigmoid mask probabilities with one
        ``(height, width)`` map per matched prediction (presumably one per
        prompt — confirm against ``vl_similarity``'s output layout).
    """
    # ``permute(2, 0, 1)`` below needs an H x W x C array, but single-channel
    # and palette PIL modes yield 2-D arrays and would crash there.
    # NOTE(review): assumes the model expects 3-channel RGB input — confirm.
    if image.mode != "RGB":
        image = image.convert("RGB")

    width, height = image.size
    resized = np.asarray(transform(image))
    # Copy before wrapping in a tensor; presumably because the array view of
    # a PIL image can be read-only.
    image_tensor = torch.from_numpy(resized.copy()).permute(2, 0, 1)
    data = {"image": image_tensor, 'text': prompts, "height": height, "width": width}

    # Initialize task: text grounding only.
    model.model.task_switch['spatial'] = False
    model.model.task_switch['visual'] = False
    model.model.task_switch['audio'] = False
    model.model.task_switch['grounding'] = True

    results, _image_size, extra = model.model.evaluate_demo([data])
    pred_masks = results['pred_masks'][0]
    v_emb = results['pred_captions'][0]
    t_emb = extra['grounding_class']

    # L2-normalize both embedding sets; the epsilon guards against a zero norm.
    t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
    v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

    temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
    out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

    # Index of the highest-scoring entry along dim 0 for each remaining column.
    matched_id = out_prob.max(0)[1]
    pred_masks_pos = pred_masks[matched_id, :, :]

    # Interpolate the mask logits to the original image size, then squash to
    # probabilities. The output is already exactly (height, width), so no
    # further cropping is needed.
    mask_logits = F.interpolate(
        pred_masks_pos[None,], (height, width), mode='bilinear'
    )[0]
    return mask_logits.sigmoid().cpu().numpy()
# def interactive_infer_panoptic_biomedseg(model, image, tasks, reftxt=None):
# image_ori = transform(image)
# #mask_ori = image['mask']
# width = image_ori.size[0]
# height = image_ori.size[1]
# image_ori = np.asarray(image_ori)
# visual = Visualizer(image_ori, metadata=metadata)
# images = torch.from_numpy(image_ori.copy()).permute(2,0,1)
# data = {"image": images, "height": height, "width": width}
# if len(tasks) == 0:
# tasks = ["Panoptic"]
# # initialize task
# model.model.task_switch['spatial'] = False
# model.model.task_switch['visual'] = False
# model.model.task_switch['grounding'] = False
# model.model.task_switch['audio'] = False
# # check if reftxt is list of strings
# assert isinstance(reftxt, list), f"reftxt should be a list of strings, but got {type(reftxt)}"
# model.model.task_switch['grounding'] = True
# predicts = {}
# for i, txt in enumerate(reftxt):
# data['text'] = txt
# batch_inputs = [data]
# results,image_size,extra = model.model.evaluate_demo(batch_inputs)
# pred_masks = results['pred_masks'][0]
# v_emb = results['pred_captions'][0]
# t_emb = extra['grounding_class']
# t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
# v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
# temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
# out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
# matched_id = out_prob.max(0)[1]
# pred_masks_pos = pred_masks[matched_id,:,:]
# pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
# # interpolate mask to ori size
# #pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']] > 0.0).float().cpu().numpy()
# # masks.append(pred_masks_pos[0])
# # mask = pred_masks_pos[0]
# # masks.append(mask)
# # interpolate mask to ori size
# pred_mask_prob = F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
# #pred_masks_pos = 1*(pred_mask_prob > 0.5)
# predicts[txt] = pred_mask_prob[0]
# masks = combine_masks(predicts)
# predict_mask_stats = {}
# print(masks.keys())
# for i, txt in enumerate(masks):
# mask = masks[txt]
# demo = visual.draw_binary_mask(mask, color=colors_list[i], text=txt)
# predict_mask_stats[txt] = mask_stats((predicts[txt]*255), image_ori)
# res = demo.get_image()
# torch.cuda.empty_cache()
# # return Image.fromarray(res), stroke_inimg, stroke_refimg
# return Image.fromarray(res), None, predict_mask_stats