File size: 9,375 Bytes
3bbb319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmpose.models import build_loss


def test_rle_loss():
    """Smoke-test ``RLELoss`` under each of its configuration switches.

    Predictions carry four values per keypoint and labels carry two
    (shapes ``(1, 3, 4)`` and ``(1, 3, 2)`` below). RLE has no simple
    closed-form value to compare against here, so every call only checks
    that the loss evaluates to finite values.
    """

    def _check(loss, fake_pred, fake_label, *args):
        # A bare call would let NaN/inf results pass silently.
        output = loss(fake_pred, fake_label, *args)
        assert torch.isfinite(output).all()
        return output

    # test RLELoss without target weight(default None)
    loss = build_loss(dict(type='RLELoss'))
    _check(loss, torch.zeros((1, 3, 4)), torch.zeros((1, 3, 2)))

    # test RLELoss with Q(error) changed to "Gaussian"(default "Laplace")
    loss = build_loss(dict(type='RLELoss', q_dis='gaussian'))
    _check(loss, torch.zeros((1, 3, 4)), torch.zeros((1, 3, 2)))

    # test RLELoss._apply(fn): ``Module.cpu()`` is implemented via
    # ``Module._apply``. Use the default config so this case does not
    # duplicate the size_average=False case below.
    loss = build_loss(dict(type='RLELoss'))
    loss.cpu()
    _check(loss, torch.zeros((1, 3, 4)), torch.zeros((1, 3, 2)))

    # test RLELoss with size_average(default True) changed to False
    loss = build_loss(dict(type='RLELoss', size_average=False))
    _check(loss, torch.zeros((1, 3, 4)), torch.zeros((1, 3, 2)))

    # test RLELoss with residual(default True) changed to False
    loss = build_loss(dict(type='RLELoss', residual=False))
    _check(loss, torch.zeros((1, 3, 4)), torch.zeros((1, 3, 2)))

    # test RLELoss with target weight, for zero and non-zero predictions
    loss = build_loss(dict(type='RLELoss', use_target_weight=True))
    for fake_pred in (torch.zeros((1, 3, 4)), torch.ones((1, 3, 4))):
        fake_label = torch.zeros((1, 3, 2))
        _check(loss, fake_pred, fake_label,
               torch.ones_like(fake_label))


def test_smooth_l1_loss():
    """Check ``SmoothL1Loss`` on zero and unit errors.

    The unweighted loss, and the loss built with ``use_target_weight=True``
    and fed all-one weights, must both return 0 for a zero error. They must
    return 0.5 for a unit error on every element.
    """
    shape = (1, 3, 2)
    # (factory for the prediction, expected loss against an all-zero label)
    cases = ((torch.zeros, 0.), (torch.ones, .5))

    # test SmoothL1Loss without target weight(default None)
    plain = build_loss(dict(type='SmoothL1Loss'))
    for make_pred, expected in cases:
        result = plain(make_pred(shape), torch.zeros(shape))
        assert torch.allclose(result, torch.tensor(expected))

    # test SmoothL1Loss with target weight
    weighted = build_loss(dict(type='SmoothL1Loss', use_target_weight=True))
    for make_pred, expected in cases:
        label = torch.zeros(shape)
        result = weighted(make_pred(shape), label, torch.ones_like(label))
        assert torch.allclose(result, torch.tensor(expected))


def test_wing_loss():
    """Check ``WingLoss`` on zero and unit errors.

    A zero error must give a zero loss. A unit error on every element must
    give a loss strictly greater than 0.5. Both properties are checked
    without weights and with all-one target weights.
    """
    shape = (1, 3, 2)

    # test WingLoss without target weight(default None)
    unweighted = build_loss(dict(type='WingLoss'))
    result = unweighted(torch.zeros(shape), torch.zeros(shape))
    assert torch.allclose(result, torch.tensor(0.))
    result = unweighted(torch.ones(shape), torch.zeros(shape))
    assert torch.gt(result, torch.tensor(.5))

    # test WingLoss with target weight
    weighted = build_loss(dict(type='WingLoss', use_target_weight=True))
    target = torch.zeros(shape)
    result = weighted(torch.zeros(shape), target, torch.ones_like(target))
    assert torch.allclose(result, torch.tensor(0.))
    target = torch.zeros(shape)
    result = weighted(torch.ones(shape), target, torch.ones_like(target))
    assert torch.gt(result, torch.tensor(.5))


def test_soft_wing_loss():
    """Check ``SoftWingLoss`` on zero and unit errors.

    A zero error must give exactly 0, and a unit error on every element
    must give more than 0.5. Both are checked without weights and with
    all-one target weights.
    """

    def _zeros():
        # Fresh all-zero (1, 3, 2) tensor for each use.
        return torch.zeros((1, 3, 2))

    def _ones():
        # Fresh all-one (1, 3, 2) tensor for each use.
        return torch.ones((1, 3, 2))

    # test SoftWingLoss without target weight(default None)
    criterion = build_loss(dict(type='SoftWingLoss'))
    assert torch.allclose(criterion(_zeros(), _zeros()), torch.tensor(0.))
    assert torch.gt(criterion(_ones(), _zeros()), torch.tensor(.5))

    # test SoftWingLoss with target weight
    criterion = build_loss(dict(type='SoftWingLoss', use_target_weight=True))
    assert torch.allclose(
        criterion(_zeros(), _zeros(), torch.ones_like(_zeros())),
        torch.tensor(0.))
    assert torch.gt(
        criterion(_ones(), _zeros(), torch.ones_like(_zeros())),
        torch.tensor(.5))


def test_mse_regression_loss():
    """Check ``MSELoss`` on zero and unit errors.

    Against an all-zero label, an all-zero prediction must give 0 and an
    all-one prediction must give 1. Both are checked without weights and
    with all-one target weights.
    """
    shape = (1, 3, 3)
    # (fill value of the prediction, expected loss)
    cases = ((0., 0.), (1., 1.))

    # w/o target weight(default None)
    loss = build_loss(dict(type='MSELoss'))
    for fill, expected in cases:
        output = loss(torch.full(shape, fill), torch.zeros(shape))
        assert torch.allclose(output, torch.tensor(expected))

    # w/ target weight
    loss = build_loss(dict(type='MSELoss', use_target_weight=True))
    for fill, expected in cases:
        label = torch.zeros(shape)
        output = loss(torch.full(shape, fill), label, torch.ones_like(label))
        assert torch.allclose(output, torch.tensor(expected))


def test_bone_loss():
    """Check ``BoneLoss`` on a three-joint chain.

    With ``joint_parents=[0, 0, 1]``, joint 0 lists itself as parent and
    joints 1 and 2 point at 0 and 1 respectively. The chain prediction
    places joints at (0,0,0), (1,1,1) and (2,2,2), so consecutive joints are
    sqrt(3) apart. The label doubles every coordinate. The expected result
    sqrt(3) is consistent with averaging absolute differences of
    child-parent distances (presumably the loss definition; not visible
    here).
    """
    parents = [0, 0, 1]

    def _chain():
        # Fresh copy of the diagonal three-joint pose.
        return torch.tensor([[[0, 0, 0], [1, 1, 1], [2, 2, 2]]],
                            dtype=torch.float32)

    configs = (
        # w/o target weight(default None)
        (dict(type='BoneLoss', joint_parents=parents), False),
        # w/ target weight; one weight per non-root joint presumably,
        # hence shape (1, 2)
        (dict(type='BoneLoss', joint_parents=parents,
              use_target_weight=True), True),
    )
    for cfg, weighted in configs:
        loss = build_loss(cfg)

        extra = (torch.ones((1, 2)), ) if weighted else ()
        flat = loss(torch.zeros((1, 3, 3)), torch.zeros((1, 3, 3)), *extra)
        assert torch.allclose(flat, torch.tensor(0.))

        pose = _chain()
        extra = (torch.ones((1, 2)), ) if weighted else ()
        stretched = loss(pose, pose * 2, *extra)
        assert torch.allclose(stretched, torch.tensor(3**0.5))


def test_semi_supervision_loss():
    """Check ``SemiSupervisionLoss`` warmup and its zero-error case.

    The 2D target is built by projecting the unlabeled pose plus its
    trajectory with the loss's own ``project_joints``. The labeled pose is
    an exact copy of the unlabeled pose. Once warmup is over, both reported
    terms are therefore expected to be zero.
    """
    loss_cfg = dict(
        type='SemiSupervisionLoss',
        joint_parents=[0, 0, 1],
        warmup_iterations=1)
    loss = build_loss(loss_cfg)

    # Seeded generator keeps the random pose reproducible across runs.
    generator = torch.Generator().manual_seed(0)
    unlabeled_pose = torch.rand((1, 3, 3), generator=generator)
    # (1, 1, 3) trajectory broadcasts over all three joints when added.
    unlabeled_traj = torch.ones((1, 1, 3))
    labeled_pose = unlabeled_pose.clone()
    fake_pred = dict(
        labeled_pose=labeled_pose,
        unlabeled_pose=unlabeled_pose,
        unlabeled_traj=unlabeled_traj)

    # NOTE(review): nine values assumed to be camera intrinsics in the
    # layout project_joints expects -- TODO confirm against the loss.
    intrinsics = torch.tensor([[1, 1, 1, 1, 0.1, 0.1, 0.1, 0, 0]],
                              dtype=torch.float32)
    unlabeled_target_2d = loss.project_joints(unlabeled_pose + unlabeled_traj,
                                              intrinsics)
    fake_label = dict(
        unlabeled_target_2d=unlabeled_target_2d, intrinsics=intrinsics)

    # test warmup: with warmup_iterations=1 the first call yields no losses
    losses = loss(fake_pred, fake_label)
    assert not losses

    # test semi-supervised loss
    losses = loss(fake_pred, fake_label)
    assert torch.allclose(losses['proj_loss'], torch.tensor(0.))
    assert torch.allclose(losses['bone_loss'], torch.tensor(0.))


def test_soft_weight_smooth_l1_loss():
    """Check ``SoftWeightSmoothL1Loss`` values, weighting and reductions.

    Unweighted case (beta=0.5): zero error gives 0 and unit error gives
    0.75. Weighted case: the weights are 0..5 and the per-element error is
    1. The expected 1.25 (supervise_empty=True) and 1.5
    (supervise_empty=False) are consistent with averaging a weighted
    per-element loss of 0.5 over all six elements and over the five
    non-zero-weight elements, respectively. ``smooth_l1_loss`` must reject
    an unknown reduction with ``ValueError`` and support ``'sum'``.
    """
    shape = (1, 3, 2)

    unweighted = build_loss(
        dict(type='SoftWeightSmoothL1Loss', use_target_weight=False,
             beta=0.5))
    assert torch.allclose(
        unweighted(torch.zeros(shape), torch.zeros(shape)), torch.tensor(0.))
    assert torch.allclose(
        unweighted(torch.ones(shape), torch.zeros(shape)), torch.tensor(.75))

    fake_pred = torch.ones(shape)
    fake_label = torch.zeros(shape)
    fake_weight = torch.arange(6).reshape(shape).float()

    # (supervise_empty flag, expected weighted mean)
    for supervise_empty, expected in ((True, 1.25), (False, 1.5)):
        loss = build_loss(
            dict(type='SoftWeightSmoothL1Loss', use_target_weight=True,
                 supervise_empty=supervise_empty))
        assert torch.allclose(
            loss(fake_pred, fake_label, fake_weight), torch.tensor(expected))

    # `loss` is now the supervise_empty=False instance.
    with pytest.raises(ValueError):
        loss.smooth_l1_loss(fake_pred, fake_label, reduction='fake')

    summed = loss.smooth_l1_loss(fake_pred, fake_label, reduction='sum')
    assert torch.allclose(summed, torch.tensor(3.0))