import unittest

from PIL import Image

from inference import generate_image


class TestGenerateImage(unittest.TestCase):
    """Smoke tests for ``inference.generate_image``.

    Each case calls ``generate_image`` with an image index and an (x, y)
    coordinate pair. It then checks only that the result is a
    ``PIL.Image.Image``, or that the expected exception is raised. Pixel
    contents are not inspected.
    """

    def _check_returns_pil_image(self, image_idx, x, y):
        """Call ``generate_image`` and assert it returns a ``PIL.Image.Image``."""
        result = generate_image(image_idx=image_idx, x=x, y=y)
        self.assertIsInstance(result, Image.Image)

    def test_generate_image_output_type(self):
        # Midpoint coordinates; presumably the valid range is [0, 1]
        # (see the edge/eps cases below) -- NOTE(review): confirm.
        self._check_returns_pil_image(0, 0.5, 0.5)

    def test_generate_image_valid_coordinates(self):
        # Asymmetric interior point, away from the midpoint.
        self._check_returns_pil_image(0, 0.1, 0.9)

    def test_generate_image_edge_coordinates(self):
        # Exact endpoints 0.0 and 1.0, exercised on the second image index.
        self._check_returns_pil_image(1, 0.0, 1.0)

    def test_generate_image_invalid_image_idx(self):
        # Index 2 is expected to raise KeyError. This suggests images are
        # looked up by key and only indices 0 and 1 exist.
        # NOTE(review): confirm against inference.generate_image.
        with self.assertRaises(KeyError):
            generate_image(image_idx=2, x=0.5, y=0.5)

    def test_generate_image_eps_boundary(self):
        # Points a tiny epsilon (1e-5) inside each end of the range. The test
        # name suggests generate_image has eps-related handling near the
        # boundaries -- TODO confirm.
        self._check_returns_pil_image(0, 1e-5, 1 - 1e-5)


if __name__ == "__main__":
    # Allow running this test module directly (e.g. `python <this_file>.py`)
    # in addition to discovery via `python -m unittest` or a test runner.
    unittest.main()