"""Unit test for the language sampling distribution computed by ``UnimaxSampler``."""

import unittest

import torch

from utils.unimax_sampler.unimax_sampler import UnimaxSampler


class TestUnimaxSampler(unittest.TestCase):
    """Checks ``UnimaxSampler.p`` against a hand-specified expected distribution."""

    def test_sampling_probabilities(self):
        """``sampler.p`` should match the expected per-language probabilities."""
        # Per-language character counts, in the order passed to the sampler.
        language_character_counts = [100, 200, 300, 400, 500]
        total_character_budget = 1000
        num_epochs = 2

        sampler = UnimaxSampler(
            language_character_counts, total_character_budget, num_epochs
        )

        # The expected output depends on the specific Unimax implementation.
        # NOTE(review): per Algorithm 1 of the UniMax paper (Chung et al., 2023),
        # every language here appears to receive the same 200-character share
        # (budget / remaining languages never exceeds counts * epochs), which
        # would give a uniform [0.2] * 5. Confirm these values against the
        # UnimaxSampler implementation.
        expected_output = torch.tensor(
            [0.1, 0.2, 0.3, 0.2, 0.2], dtype=torch.float64
        )

        # torch.allclose requires both inputs to share a dtype; converting here
        # also accepts a plain list or NumPy array for ``sampler.p``.
        actual = torch.as_tensor(sampler.p, dtype=torch.float64)

        # atol is the maximum absolute per-element difference allowed for the
        # test to pass (rtol keeps torch.allclose's default).
        self.assertTrue(
            torch.allclose(actual, expected_output, atol=1e-6),
            msg=f"sampler.p={actual.tolist()} expected={expected_output.tolist()}",
        )


if __name__ == "__main__":
    unittest.main()