|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""This file contains the unit tests for the utils.py file.""" |
|
|
|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
|
|
|
|
from modeling.model import utils |
|
|
|
|
|
def test_scoremap2bbox():
  """Test the scoremap2bbox function.

  Four adjacent 4x4 regions with scores 1..4 tile rows/cols 1-8 of a 10x10
  map. The test expects exactly one box whose four entries are (1, 1, 9, 9).
  """
  scoremap = np.zeros((10, 10))
  # (row slice, column slice, score) for each quadrant of the non-zero area.
  regions = (
      (slice(1, 5), slice(1, 5), 1),
      (slice(5, 9), slice(5, 9), 2),
      (slice(5, 9), slice(1, 5), 3),
      (slice(1, 5), slice(5, 9), 4),
  )
  for rows, cols, score in regions:
    scoremap[rows, cols] = score

  # NOTE(review): 0.5 is presumably a threshold (possibly relative to the
  # map's maximum) -- confirm against utils.scoremap2bbox.
  bbox, len_bboxes = utils.scoremap2bbox(scoremap, 0.5)

  assert len_bboxes == 1
  for column, expected in enumerate((1, 1, 9, 9)):
    assert bbox[0, column] == expected
|
|
|
|
|
def test_mask2chw():
  """Test the mask2chw function.

  Builds a 10x10 mask whose non-zero area is four 4x4 quadrants (values
  1..4), wraps it in a torch tensor, and checks the returned center, height
  and width.
  """
  mask_values = np.zeros((10, 10))
  # (row slice, column slice, value) for each quadrant of the mask.
  for rows, cols, value in (
      (slice(1, 5), slice(1, 5), 1),
      (slice(5, 9), slice(5, 9), 2),
      (slice(5, 9), slice(1, 5), 3),
      (slice(1, 5), slice(5, 9), 4),
  ):
    mask_values[rows, cols] = value
  mask = torch.tensor(mask_values)

  center, height, width = utils.mask2chw(mask)

  assert len(center) == 2
  assert center[0] == 2
  assert center[1] == 2
  assert height == 4
  assert width == 4
|
|
|
|
|
def test_unpad():
  """Test the unpad function.

  Checks that unpadding a (10, 10, 1) image with pad=(1, 1, 8, 8) yields an
  8x8 spatial result, and that pad=None returns an image equal to the input.
  """
  image = np.zeros((10, 10, 1))
  image[1:5, 1:5] = 1
  image[5:9, 5:9] = 2
  image[5:9, 1:5] = 3
  image[1:5, 5:9] = 4

  unpad_image = utils.unpad(image, pad=(1, 1, 8, 8))
  # The image is laid out as (height, width, channels). Read the spatial
  # axes from the shape: len(unpad_image[i]) is the length of row i, which
  # is always the width, so it can never verify the height.
  height, width = unpad_image.shape[:2]
  assert width == 8, 'The width of the image is not 8.'
  assert height == 8, 'The height of the image is not 8.'

  unpad_image = utils.unpad(image, None)
  # All 10 * 10 * 1 elements must match the untouched input.
  assert (unpad_image == image).sum() == 100
|
|
|
|
|
def test_apply_visual_prompts():
  """Test the apply_visual_prompts function with a 'circle' prompt.

  A 5x5 all-ones image is prompted with thickness 1. The result must match
  `target` at all 25 positions. `target` is 1 everywhere except a
  diamond-shaped ring of 255 values around the center pixel.
  """
  image = np.ones((5, 5))

  mask = np.zeros((5, 5))
  mask[2, 2] = 1.0

  target = np.array([
      [1, 1, 255, 1, 1],
      [1, 255, 1, 255, 1],
      [255, 1, 1, 1, 255],
      [1, 255, 1, 255, 1],
      [1, 1, 255, 1, 1],
  ])

  # NOTE(review): this overwrites the single-pixel mask above, so the mask
  # actually passed in is the 4x4 block at rows/cols 1-4. Confirm intent.
  mask[1:5, 1:5] = 1

  prompted_array = np.array(
      utils.apply_visual_prompts(
          image, mask, visual_prompt_type='circle', thickness=1
      )
  )

  assert (prompted_array == target).sum() == 25
|
|
|
|
|
def test_reshape_transform():
  """Test the reshape_transform function.

  Feeds a zero tensor of shape (101, 10, 32) with height=width=10 and
  expects a 4-D result of shape (10, 32, 10, 10).
  """
  # NOTE(review): 101 == 1 + 10 * 10, which suggests one extra leading
  # token is dropped along dim 0 -- confirm against utils.reshape_transform.
  inputs = torch.zeros((101, 10, 32))
  outputs = utils.reshape_transform(inputs, height=10, width=10)

  batch, channels, out_h, out_w = outputs.shape
  assert batch == 10
  assert channels == 32
  assert out_h == 10
  assert out_w == 10
|
|
|
|
|
def test_img_ms_and_flip():
  """Test the img_ms_and_flip function.

  A 120x150 float image is processed with a single scale of 1.2 and patch
  size 16. The first returned element must have a spatial size equal to
  the scaled size rounded up to a multiple of the patch size.
  """
  height, width = 120, 150
  scale, patch_size = 1.2, 16

  pixels = np.zeros((height, width))
  pixels[1:5, 1:5] = 1
  pixels[5:9, 5:9] = 2
  pixels[5:9, 1:5] = 3
  pixels[1:5, 5:9] = 4

  outputs = utils.img_ms_and_flip(
      Image.fromarray(pixels), height, width,
      scales=[scale], patch_size=patch_size,
  )
  out_h, out_w = outputs[0].shape[-2:]

  def _round_up_to_patch(size):
    # Scaled size rounded up to the next multiple of the patch size.
    return int(np.ceil(scale * size / patch_size) * patch_size)

  assert out_h == _round_up_to_patch(height)
  assert out_w == _round_up_to_patch(width)
|
|
|
|
|
if __name__ == '__main__':
  # Run every test in this module, in declaration order, without pytest.
  for test_fn in (
      test_scoremap2bbox,
      test_mask2chw,
      test_unpad,
      test_apply_visual_prompts,
      test_reshape_transform,
      test_img_ms_and_flip,
  ):
    test_fn()
|
|