show / mmpose-0.29.0 /tests /test_losses /test_bottom_up_losses.py
camenduru's picture
thanks to show ❤
3bbb319
raw
history blame contribute delete
No virus
5.89 kB
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
def test_multi_loss_factory():
    """Check the bottom-up losses obtained through ``build_loss``.

    Three registered loss types are exercised in order: ``HeatmapLoss``,
    ``AELoss`` and ``MultiLossFactory``. Inputs are small synthetic
    tensors; the test passes when every assertion holds and every
    expected exception is raised.
    """
    from mmpose.models import build_loss

    # test heatmap loss
    loss_cfg = dict(type='HeatmapLoss')
    loss = build_loss(loss_cfg)

    # Prediction and label disagree on the first dimension (2 vs 1).
    # The inputs are created outside the ``raises`` block so that only the
    # loss call itself can satisfy the expected AssertionError.
    fake_pred = torch.zeros((2, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    fake_mask = torch.zeros((1, 64, 64))
    with pytest.raises(AssertionError):
        loss(fake_pred, fake_label, fake_mask)

    # Identical prediction and label -> zero loss.
    fake_pred = torch.zeros((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    fake_mask = torch.zeros((1, 64, 64))
    assert torch.allclose(
        loss(fake_pred, fake_label, fake_mask), torch.tensor(0.))

    # Prediction differs from label, but the mask is all zeros -> zero loss.
    fake_pred = torch.ones((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    fake_mask = torch.zeros((1, 64, 64))
    assert torch.allclose(
        loss(fake_pred, fake_label, fake_mask), torch.tensor(0.))

    # Same prediction/label pair with an all-ones mask -> loss of 1.
    fake_pred = torch.ones((1, 3, 64, 64))
    fake_label = torch.zeros((1, 3, 64, 64))
    fake_mask = torch.ones((1, 64, 64))
    assert torch.allclose(
        loss(fake_pred, fake_label, fake_mask), torch.tensor(1.))

    # test AE loss
    # NOTE(review): the last axis of ``fake_joints`` presumably holds
    # (index into the tag tensor, visibility flag) - confirm against AELoss.
    fake_tags = torch.zeros((1, 18, 1))
    fake_joints = torch.zeros((1, 3, 2, 2), dtype=torch.int)

    loss_cfg = dict(type='AELoss', loss_type='exp')
    loss = build_loss(loss_cfg)
    # Evaluate the loss once and check both returned terms.
    ae_losses = loss(fake_tags, fake_joints)
    assert torch.allclose(ae_losses[0], torch.tensor(0.))
    assert torch.allclose(ae_losses[1], torch.tensor(0.))

    # First joint group: entries 0 and 10 with tag values 1 and 0.
    fake_tags[0, 0, 0] = 1.
    fake_tags[0, 10, 0] = 0.
    fake_joints[0, 0, 0, :] = torch.IntTensor((0, 1))
    fake_joints[0, 0, 1, :] = torch.IntTensor((10, 1))
    loss_cfg = dict(type='AELoss', loss_type='exp')
    loss = build_loss(loss_cfg)
    ae_losses = loss(fake_tags, fake_joints)
    assert torch.allclose(ae_losses[0], torch.tensor(0.))
    assert torch.allclose(ae_losses[1], torch.tensor(0.25))

    # Reset entry 0 and add a second joint group (entries 7 and 17) whose
    # tag values are both 1.
    fake_tags[0, 0, 0] = 0
    fake_tags[0, 7, 0] = 1.
    fake_tags[0, 17, 0] = 1.
    fake_joints[0, 1, 0, :] = torch.IntTensor((7, 1))
    fake_joints[0, 1, 1, :] = torch.IntTensor((17, 1))

    loss_cfg = dict(type='AELoss', loss_type='exp')
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(fake_tags, fake_joints)[1], torch.tensor(0.))

    loss_cfg = dict(type='AELoss', loss_type='max')
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(fake_tags, fake_joints)[0], torch.tensor(0.))

    # 'min' is not an accepted loss_type. Whether the ValueError comes from
    # construction or from the call is not visible from this test, so both
    # statements stay inside the ``raises`` block.
    loss_cfg = dict(type='AELoss', loss_type='min')
    with pytest.raises(ValueError):
        loss = build_loss(loss_cfg)
        loss(fake_tags, fake_joints)

    # test MultiLossFactory
    # Baseline config in which every per-stage option is a list. Each
    # negative case below replaces exactly one of those options with a bare
    # scalar and expects construction to fail with an AssertionError.
    list_valued_cfg = dict(
        type='MultiLossFactory',
        num_joints=2,
        num_stages=1,
        ae_loss_type='exp',
        with_ae_loss=[True],
        push_loss_factor=[0.001],
        pull_loss_factor=[0.001],
        with_heatmaps_loss=[True],
        heatmaps_loss_factor=[1.0])
    scalar_overrides = dict(
        with_ae_loss=True,
        push_loss_factor=0.001,
        pull_loss_factor=0.001,
        with_heatmaps_loss=True,
        heatmaps_loss_factor=1.0)
    for option, scalar_value in scalar_overrides.items():
        loss_cfg = dict(list_valued_cfg, **{option: scalar_value})
        with pytest.raises(AssertionError):
            build_loss(loss_cfg)

    # Both loss families disabled: each returned list holds a single None.
    loss_cfg = dict(
        type='MultiLossFactory',
        num_joints=17,
        num_stages=1,
        ae_loss_type='exp',
        with_ae_loss=[False],
        push_loss_factor=[0.001],
        pull_loss_factor=[0.001],
        with_heatmaps_loss=[False],
        heatmaps_loss_factor=[1.0])
    loss = build_loss(loss_cfg)
    fake_outputs = [torch.zeros((1, 34, 64, 64))]
    fake_heatmaps = [torch.zeros((1, 17, 64, 64))]
    fake_masks = [torch.ones((1, 64, 64))]
    fake_joints = [torch.zeros((1, 30, 17, 2))]

    heatmaps_losses, push_losses, pull_losses = loss(
        fake_outputs, fake_heatmaps, fake_masks, fake_joints)

    assert heatmaps_losses == [None]
    assert pull_losses == [None]
    assert push_losses == [None]

    # Both loss families enabled: one heatmap loss entry is returned.
    loss_cfg = dict(
        type='MultiLossFactory',
        num_joints=17,
        num_stages=1,
        ae_loss_type='exp',
        with_ae_loss=[True],
        push_loss_factor=[0.001],
        pull_loss_factor=[0.001],
        with_heatmaps_loss=[True],
        heatmaps_loss_factor=[1.0])
    loss = build_loss(loss_cfg)

    # Only the heatmap list is inspected; the unpacking still requires the
    # call to yield exactly three values.
    heatmaps_losses, _, _ = loss(
        fake_outputs, fake_heatmaps, fake_masks, fake_joints)
    assert len(heatmaps_losses) == 1