# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains the unit tests for the utils.py file."""

import numpy as np
from PIL import Image
import torch

# pylint: disable=g-bad-import-order
from modeling.model import utils


def test_scoremap2bbox():
  """Check that scoremap2bbox reports one box with corners (1, 1, 9, 9).

  The score map is zero except for four adjacent 4x4 quadrants that together
  fill rows/columns 1..8, and it is thresholded at 0.5.
  """
  scoremap = np.zeros((10, 10))
  scoremap[1:5, 1:5] = 1
  scoremap[5:9, 5:9] = 2
  scoremap[5:9, 1:5] = 3
  scoremap[1:5, 5:9] = 4

  bbox, len_bboxes = utils.scoremap2bbox(scoremap, 0.5)

  assert len_bboxes == 1
  assert bbox[0, 0] == 1
  assert bbox[0, 1] == 1
  assert bbox[0, 2] == 9
  assert bbox[0, 3] == 9


def test_mask2chw():
  """Check the center, height and width that mask2chw returns for a mask."""
  mask = np.zeros((10, 10))
  mask[1:5, 1:5] = 1
  mask[5:9, 5:9] = 2
  mask[5:9, 1:5] = 3
  mask[1:5, 5:9] = 4
  mask = torch.tensor(mask)

  mask_center, mask_height, mask_width = utils.mask2chw(mask)

  assert len(mask_center) == 2
  assert mask_center[0] == 2
  assert mask_center[1] == 2
  assert mask_height == 4
  assert mask_width == 4


def test_unpad():
  """Check unpad with an explicit pad tuple and with pad=None.

  With pad=(1, 1, 8, 8) the result is expected to be 8 wide and 8 high. With
  pad=None the result is expected to match the 10x10x1 input element-wise.
  """
  image = np.zeros((10, 10, 1))
  image[1:5, 1:5] = 1
  image[5:9, 5:9] = 2
  image[5:9, 1:5] = 3
  image[1:5, 5:9] = 4

  unpad_image = utils.unpad(image, pad=(1, 1, 8, 8))
  # The first axis indexes rows: the length of one row is the width, and the
  # number of rows is the height. Indexing row 1 instead of taking the outer
  # length would measure the width a second time and never check the height.
  assert len(unpad_image[0]) == 8, 'The width of the image is not 8.'
  assert len(unpad_image) == 8, 'The height of the image is not 8.'

  unpad_image = utils.unpad(image, None)
  # All 10 * 10 * 1 elements must compare equal to the input.
  assert (unpad_image == image).sum() == 100


def test_apply_visual_prompts():
  """Check the image produced by apply_visual_prompts for a 'circle' prompt."""
  image = np.ones((5, 5))
  mask = np.array([
      [0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0],
      [0, 0, 1.0, 0, 0],
      [0, 0, 0, 0, 0],
      [0, 0, 0, 0, 0],
  ])
  target = np.array([
      [1, 1, 255, 1, 1],
      [1, 255, 1, 255, 1],
      [255, 1, 1, 1, 255],
      [1, 255, 1, 255, 1],
      [1, 1, 255, 1, 1],
  ])
  # The mask actually passed to the function is the literal above with the
  # lower-right 4x4 region (rows/columns 1..4) set to 1.
  mask[1:5, 1:5] = 1

  prompted_image = utils.apply_visual_prompts(
      image, mask, visual_prompt_type='circle', thickness=1
  )
  prompted_array = np.array(prompted_image)

  # All 25 pixels must match the expected pattern.
  assert (prompted_array == target).sum() == 25


def test_reshape_transform():
  """Check the 4-D shape reshape_transform returns for a (101, 10, 32) input."""
  image = torch.zeros((101, 10, 32))

  image = utils.reshape_transform(image, height=10, width=10)

  b, c, h, w = image.shape
  assert b == 10
  assert c == 32
  assert h == 10
  assert w == 10


def test_img_ms_and_flip():
  """Check the output size of img_ms_and_flip for one scale of 1.2.

  The last two dimensions of the first returned item are expected to equal
  the scaled 120x150 size rounded up to a multiple of the patch size (16).
  """
  image = np.zeros((120, 150))
  image[1:5, 1:5] = 1
  image[5:9, 5:9] = 2
  image[5:9, 1:5] = 3
  image[1:5, 5:9] = 4
  image = Image.fromarray(image)

  image = utils.img_ms_and_flip(image, 120, 150, scales=[1.2], patch_size=16)
  image = image[0]

  h, w = image.shape[-2:]
  assert h == int(np.ceil(1.2 * 120 / 16) * 16)
  assert w == int(np.ceil(1.2 * 150 / 16) * 16)


if __name__ == '__main__':
  test_scoremap2bbox()
  test_mask2chw()
  test_unpad()
  test_apply_visual_prompts()
  test_reshape_transform()
  test_img_ms_and_flip()