File size: 1,373 Bytes
e26e560 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import pytest
import torch
from mmdet.models.utils import (LearnedPositionalEncoding,
SinePositionalEncoding)
def test_sine_positional_encoding(num_feats=16, batch_size=2):
    """Exercise SinePositionalEncoding construction and output shape.

    Checks that a tuple ``scale`` combined with ``normalize=True`` raises
    ``AssertionError``. It also checks that, with and without normalization,
    encoding a boolean mask of shape ``(batch_size, h, w)`` yields a tensor of
    shape ``(batch_size, num_feats * 2, h, w)``.
    """
    # A tuple is not an accepted scale value when normalization is on.
    with pytest.raises(AssertionError):
        SinePositionalEncoding(num_feats, scale=(3., ), normalize=True)

    # Default construction leaves normalization disabled.
    encoder = SinePositionalEncoding(num_feats)
    height, width = 10, 6
    # Random boolean mask; only the output shape is asserted, so no seed.
    mask = torch.rand(batch_size, height, width) > 0.5
    expected_shape = (batch_size, num_feats * 2, height, width)

    assert not encoder.normalize
    assert encoder(mask).shape == expected_shape

    # Same mask, normalization explicitly enabled.
    encoder = SinePositionalEncoding(num_feats, normalize=True)
    assert encoder.normalize
    assert encoder(mask).shape == expected_shape
def test_learned_positional_encoding(num_feats=16,
                                     row_num_embed=10,
                                     col_num_embed=10,
                                     batch_size=2):
    """Exercise LearnedPositionalEncoding embedding tables and output shape.

    Checks that ``row_embed`` and ``col_embed`` weights have shape
    ``(row_num_embed, num_feats)`` and ``(col_num_embed, num_feats)``. It also
    checks that encoding a boolean mask of shape ``(batch_size, h, w)`` yields
    a tensor of shape ``(batch_size, num_feats * 2, h, w)``.
    """
    encoder = LearnedPositionalEncoding(num_feats, row_num_embed,
                                        col_num_embed)

    # One weight row per embeddable position, num_feats columns each.
    expected_tables = ((encoder.row_embed, row_num_embed),
                       (encoder.col_embed, col_num_embed))
    for table, num_positions in expected_tables:
        assert table.weight.shape == (num_positions, num_feats)

    # NOTE(review): height/width are fixed rather than derived from
    # row_num_embed/col_num_embed; presumably they must not exceed those
    # counts -- confirm before changing the defaults.
    height, width = 10, 6
    mask = torch.rand(batch_size, height, width) > 0.5
    encoded = encoder(mask)
    assert encoded.shape == (batch_size, num_feats * 2, height, width)
|