|
import pytest |
|
import torch |
|
|
|
from mmdet.models.utils import (LearnedPositionalEncoding, |
|
SinePositionalEncoding) |
|
|
|
|
|
def test_sine_positional_encoding(num_feats=16, batch_size=2):
    """Check SinePositionalEncoding construction checks and output shape.

    Constructing the module with a tuple ``scale`` together with
    ``normalize=True`` is expected to trip an ``AssertionError``. Both the
    default (unnormalized) module and a ``normalize=True`` module must map a
    random boolean mask of shape ``(batch_size, h, w)`` to a tensor of shape
    ``(batch_size, num_feats * 2, h, w)``.
    """
    # A tuple scale with normalization enabled must be rejected at init.
    with pytest.raises(AssertionError):
        SinePositionalEncoding(num_feats, scale=(3., ), normalize=True)

    height, width = 10, 6
    expected_shape = (batch_size, num_feats * 2, height, width)

    # Default configuration: normalization is off.
    encoder = SinePositionalEncoding(num_feats)
    padding_mask = torch.rand(batch_size, height, width) > 0.5
    assert not encoder.normalize
    assert encoder(padding_mask).shape == expected_shape

    # Normalized configuration reuses the same mask; the shape is unchanged.
    normalized_encoder = SinePositionalEncoding(num_feats, normalize=True)
    assert normalized_encoder.normalize
    assert normalized_encoder(padding_mask).shape == expected_shape
|
|
|
|
|
def test_learned_positional_encoding(num_feats=16,
                                     row_num_embed=10,
                                     col_num_embed=10,
                                     batch_size=2):
    """Check LearnedPositionalEncoding embedding tables and output shape.

    The row and column embedding weights must have shapes
    ``(row_num_embed, num_feats)`` and ``(col_num_embed, num_feats)``.
    Encoding a random boolean mask of shape ``(batch_size, h, w)`` must
    yield a tensor of shape ``(batch_size, num_feats * 2, h, w)``.
    """
    # Build the module before drawing the random mask: embedding
    # initialisation may consume the global RNG, so keep this order.
    encoder = LearnedPositionalEncoding(num_feats, row_num_embed,
                                        col_num_embed)
    for embedding, num_rows in ((encoder.row_embed, row_num_embed),
                                (encoder.col_embed, col_num_embed)):
        assert embedding.weight.shape == (num_rows, num_feats)

    # NOTE(review): height/width presumably must not exceed
    # row_num_embed/col_num_embed respectively — TODO confirm.
    height, width = 10, 6
    padding_mask = torch.rand(batch_size, height, width) > 0.5
    encoded = encoder(padding_mask)
    assert encoded.shape == (batch_size, num_feats * 2, height, width)
|
|