pandagpt-vicuna-v0-7b / code /pytorchvideo /tests /test_losses_soft_target_cross_entropy.py
mvsoom's picture
Upload folder using huggingface_hub
3133fdb
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import itertools
import unittest
from typing import Iterator, Tuple

import torch
import torch.nn.functional as F

from pytorchvideo.losses.soft_target_cross_entropy import SoftTargetCrossEntropyLoss
class TestSoftTargetCrossEntropyLoss(unittest.TestCase):
    """Unit tests for ``SoftTargetCrossEntropyLoss``."""

    def setUp(self):
        super().setUp()
        # Seed the global torch RNG so the random inputs produced by
        # _get_inputs are reproducible. torch.manual_seed already sets the
        # global generator's state, so re-applying that same state via
        # torch.set_rng_state is unnecessary.
        torch.manual_seed(42)

    def test_soft_target_cross_entropy_loss(self):
        """
        Test the soft target cross entropy loss.

        For every combination of batch size, number of classes and target
        layout (1D class indices vs. 2D per-class targets), check that the
        loss reduces to a scalar. When the target is already a valid
        distribution (1D indices, or 2D rows that each sum to 1), also check
        that it matches a direct evaluation of
        ``mean(sum(-target * log_softmax(input), dim=-1))``.
        """
        for batch_size, num_class, use_1D_target in itertools.product(
            (1, 8), (2, 10), (True, False)
        ):
            # subTest reports which parameter combination failed and keeps
            # running the remaining combinations.
            with self.subTest(
                batch_size=batch_size,
                num_class=num_class,
                use_1D_target=use_1D_target,
            ):
                loss = SoftTargetCrossEntropyLoss()
                # Test forwarding.
                for input_tensor, target_tensor in self._get_inputs(
                    batch_size=batch_size,
                    num_class=num_class,
                    use_1D_target=use_1D_target,
                ):
                    output_tensor = loss(input_tensor, target_tensor)
                    output_shape = output_tensor.shape
                    self.assertEqual(
                        output_shape,
                        torch.Size([]),
                        "Output shape {} is different from expected.".format(
                            output_shape
                        ),
                    )

                    # Only compare against the direct formula when the target
                    # is already normalized; unnormalized 2D targets are left
                    # to the loss's own handling.
                    is_normalized = target_tensor.ndim == 1 or bool(
                        torch.all(target_tensor.sum(dim=-1) == 1)
                    )
                    if not is_normalized:
                        continue

                    dense_target = target_tensor
                    if target_tensor.ndim == 1:
                        # Expand class indices into a (batch_size, num_class)
                        # one-hot matrix; the integer result is promoted to
                        # float when multiplied by log_softmax below.
                        dense_target = F.one_hot(target_tensor, num_class)
                    expected_tensor = torch.sum(
                        -dense_target * F.log_softmax(input_tensor, dim=-1), dim=-1
                    ).mean()
                    self.assertAlmostEqual(
                        output_tensor.item(), expected_tensor.item(), delta=1e-6
                    )

    @staticmethod
    def _get_inputs(
        batch_size: int = 16, num_class: int = 400, use_1D_target: bool = True
    ) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Provide different tensors as test cases.

        Args:
            batch_size: number of samples in the batch.
            num_class: number of classes.
            use_1D_target: if True, targets are integer class indices of shape
                ``(batch_size,)``; otherwise targets have shape
                ``(batch_size, num_class)`` with random integer entries in
                ``[0, num_class)``, which are not necessarily normalized.

        Yield:
            (torch.Tensor, torch.Tensor): an ``(input, target)`` pair, where
            ``input`` has shape ``(batch_size, num_class)`` with values drawn
            by ``torch.rand``.
        """
        # Prepare random tensor as test cases.
        target_shape = (batch_size,) if use_1D_target else (batch_size, num_class)
        input_shape = (batch_size, num_class)
        # The input is drawn before the target, matching the RNG consumption
        # order the seeded test relies on.
        yield torch.rand(input_shape), torch.randint(num_class, target_shape)