# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import unittest

import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData

from mmocr.structures import KIEDataSample
from mmocr.utils import bbox2poly
from mmocr.visualization import KIELocalVisualizer


class TestTextKIELocalVisualizer(unittest.TestCase):
    """Tests for :class:`KIELocalVisualizer.add_datasample`.

    Draws a randomly generated ``KIEDataSample`` (ground truth and
    predictions) and checks the shape of the image written to disk for the
    different ``draw_gt`` / ``draw_pred`` / ``is_openset`` combinations.
    """

    def setUp(self):
        h, w = 12, 10
        self.image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
        # Random 5x5 boolean edge matrix shared by gt and pred instances.
        edge_labels = torch.rand((5, 5)) > 0.5
        # gt_instances
        data_sample = KIEDataSample()
        gt_instances_data = dict(
            bboxes=self._rand_bboxes(5, h, w),
            polygons=self._rand_polys(5, h, w),
            labels=torch.zeros(5, ),
            texts=['text1', 'text2', 'text3', 'text4', 'text5'],
            edge_labels=edge_labels)
        gt_instances = InstanceData(**gt_instances_data)
        data_sample.gt_instances = gt_instances

        # pred_instances
        pred_instances_data = dict(
            bboxes=self._rand_bboxes(5, h, w),
            labels=torch.zeros(5, ),
            scores=torch.rand((5, )),
            texts=['text1', 'text2', 'text3', 'text4', 'text5'],
            edge_labels=edge_labels)
        pred_instances = InstanceData(**pred_instances_data)
        data_sample.pred_instances = pred_instances
        data_sample = data_sample.numpy()
        self.data_sample = data_sample

    @staticmethod
    def _categories():
        """Return a fresh copy of the category metadata used by the tests."""
        return [
            dict(id=0, name='bg'),
            dict(id=1, name='key'),
            dict(id=2, name='value'),
            dict(id=3, name='other')
        ]

    @staticmethod
    def _rand_bboxes(num_boxes, h, w):
        """Generate random boxes inside an ``h`` x ``w`` image.

        Args:
            num_boxes (int): Number of boxes to generate.
            h (int): Image height; y coordinates are clamped to ``[0, h]``.
            w (int): Image width; x coordinates are clamped to ``[0, w]``.

        Returns:
            torch.Tensor: Tensor of shape ``(num_boxes, 4)`` whose rows are
            ``[tl_x, tl_y, br_x, br_y]``.
        """
        # Sample normalized centre and size, then convert to corners.
        cx, cy, bw, bh = torch.rand(num_boxes, 4).T

        tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w).unsqueeze(0)
        tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h).unsqueeze(0)
        br_x = ((cx * w) + (w * bw / 2)).clamp(0, w).unsqueeze(0)
        br_y = ((cy * h) + (h * bh / 2)).clamp(0, h).unsqueeze(0)

        # (4, num_boxes) -> (num_boxes, 4)
        bboxes = torch.cat([tl_x, tl_y, br_x, br_y], dim=0).T

        return bboxes

    def _rand_polys(self, num_bboxes, h, w):
        """Return ``num_bboxes`` random polygons built with ``bbox2poly``."""
        bboxes = self._rand_bboxes(num_bboxes, h, w)
        bboxes = bboxes.tolist()
        polys = [bbox2poly(bbox) for bbox in bboxes]
        return polys

    def test_add_datasample(self):
        image = self.image
        h, w, c = image.shape

        visualizer = KIELocalVisualizer(is_openset=True)
        visualizer.dataset_meta = dict(category=self._categories())
        # Drawing without an output file must not raise.
        visualizer.add_datasample('image', image, self.data_sample)

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Each case writes to its own file, so a case that fails to
            # write cannot pass by re-reading an image from an earlier case.

            # Neither gt nor pred drawn: output is just the input image.
            out_file = osp.join(tmp_dir, 'no_gt_no_pred.jpg')
            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                out_file=out_file,
                draw_gt=False,
                draw_pred=False)
            self._assert_image_and_shape(out_file, (h, w, c))

            # Default: gt and pred both drawn. The expected layout is 2 rows
            # of 4 panels (presumably one row per gt/pred in open-set mode).
            out_file = osp.join(tmp_dir, 'gt_and_pred.jpg')
            visualizer.add_datasample(
                'image', image, self.data_sample, out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w * 4, c))

            out_file = osp.join(tmp_dir, 'pred_only.jpg')
            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                draw_gt=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 4, c))

            out_file = osp.join(tmp_dir, 'gt_only.jpg')
            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                draw_pred=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 4, c))

            # Closed-set mode: a single row is expected to be 3 panels wide.
            visualizer = KIELocalVisualizer(is_openset=False)
            visualizer.dataset_meta = dict(category=self._categories())
            out_file = osp.join(tmp_dir, 'closeset_gt_only.jpg')
            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                draw_pred=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 3, c))

    def _assert_image_and_shape(self, out_file, out_shape):
        """Assert ``out_file`` exists, decodes as an image of ``out_shape``.

        Args:
            out_file (str): Path of the image written by the visualizer.
            out_shape (tuple[int, int, int]): Expected ``(h, w, c)`` shape.
        """
        self.assertTrue(osp.exists(out_file), f'{out_file} was not written')
        drawn_img = cv2.imread(out_file)
        # cv2.imread returns None instead of raising on unreadable files.
        self.assertIsNotNone(drawn_img, f'{out_file} could not be decoded')
        self.assertEqual(drawn_img.shape, out_shape)