File size: 7,406 Bytes
022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 d47736e 022c628 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import unittest
from unittest.mock import patch, MagicMock
import os
from PIL import Image
from io import BytesIO
import numpy as np
from handler import EndpointHandler
class TestEndpointHandler(unittest.TestCase):
    """Unit tests for ``EndpointHandler`` with its external services mocked.

    ``handler.RealESRGANer`` and ``handler.boto3`` are replaced by mocks for
    the whole duration of every test, and ``handler.requests.get`` is patched
    per test, so no model, S3 or HTTP call is performed for real.

    Every test inspects the dict returned by calling the handler, which is
    expected to expose the keys ``image_url``, ``image_key`` and ``error``.
    """

    # Environment variables installed for each test and restored afterwards.
    _ENV_VARS = {
        'TILING_SIZE': '0',
        'AWS_ACCESS_KEY_ID': 'test_key',
        'AWS_SECRET_ACCESS_KEY': 'test_secret',
        'S3_BUCKET_NAME': 'test-bucket',
    }

    # URL sent in every request; never fetched because requests.get is mocked.
    _IMAGE_URL = "http://example.com/test.png"

    def setUp(self):
        """Create a handler wired to mocked model and S3 dependencies.

        Patches are started here and stopped through ``addCleanup`` so that
        they stay active during the test body (not only while the handler is
        constructed) and are always undone, even when a test fails.
        """
        # patch.dict restores os.environ to its previous content on stop(),
        # so the test credentials never leak into other tests.
        env_patcher = patch.dict(os.environ, self._ENV_VARS)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        mock_boto3 = self._start_patch('handler.boto3')
        mock_RealESRGANer = self._start_patch('handler.RealESRGANer')

        self.handler = EndpointHandler()
        # Instances the handler obtains by calling the patched factories.
        self.mock_model = mock_RealESRGANer.return_value
        self.mock_s3 = mock_boto3.client.return_value

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _start_patch(self, target):
        """Patch *target* until the end of the current test and return the mock."""
        patcher = patch(target)
        mocked = patcher.start()
        self.addCleanup(patcher.stop)
        return mocked

    def image_to_bytes(self, image):
        """Helper method to convert PIL Image to PNG-encoded bytes."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return buffered.getvalue()

    def _serve_image(self, mock_get, mode, size=(100, 100)):
        """Make the mocked ``requests.get`` return a blank PNG.

        Args:
            mock_get: The mock standing in for ``handler.requests.get``.
            mode: PIL image mode of the generated image (e.g. ``'RGB'``).
            size: ``(width, height)`` of the generated image.
        """
        response = MagicMock()
        response.content = self.image_to_bytes(Image.new(mode, size))
        mock_get.return_value = response

    def _set_model_output(self, *shape):
        """Make the mocked model's ``enhance`` return a zero array of *shape*.

        ``enhance`` is mocked to return a 2-tuple whose second item is ``None``.
        """
        self.mock_model.enhance.return_value = (
            np.zeros(shape, dtype=np.uint8),
            None,
        )

    def _upscale(self, outscale=2):
        """Call the handler with the standard test URL and given *outscale*."""
        return self.handler({
            "inputs": {
                "image_url": self._IMAGE_URL,
                "outscale": outscale,
            }
        })

    def _assert_success(self, result):
        """Assert *result* carries an image URL and key and no error."""
        self.assertIsNotNone(result["image_url"])
        self.assertIsNotNone(result["image_key"])
        self.assertIsNone(result["error"])

    def _assert_failure(self, result, message):
        """Assert *result* carries no image and an error containing *message*."""
        self.assertIsNone(result["image_url"])
        self.assertIsNone(result["image_key"])
        self.assertIn(message, result["error"])

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    @patch('handler.requests.get')
    def test_successful_upscale(self, mock_get):
        """Test successful image upscaling"""
        self._serve_image(mock_get, 'RGB')
        self._set_model_output(200, 200, 3)

        result = self._upscale(outscale=2)

        self._assert_success(result)

    @patch('handler.requests.get')
    def test_invalid_outscale(self, mock_get):
        """Test handling of invalid outscale values"""
        self._serve_image(mock_get, 'RGB')

        result = self._upscale(outscale=0.5)  # Too small

        self._assert_failure(result, "Outscale must be between 1 and 10")

    @patch('handler.requests.get')
    def test_download_failure(self, mock_get):
        """Test handling of failed image downloads"""
        mock_get.side_effect = Exception("Download failed")

        result = self._upscale(outscale=2)

        self._assert_failure(result, "Failed to download image")

    @patch('handler.requests.get')
    def test_large_image_no_tiling(self, mock_get):
        """Test handling of large images when tiling is disabled"""
        # 1500x1500 is meant to exceed the handler's maximum image size
        # (the limit itself lives in the handler -- TODO confirm its value).
        self._serve_image(mock_get, 'RGB', size=(1500, 1500))

        result = self._upscale(outscale=2)

        self._assert_failure(result, "Image is too large")

    @patch('handler.requests.get')
    def test_s3_upload_failure(self, mock_get):
        """Test handling of S3 upload failures"""
        self._serve_image(mock_get, 'RGB')
        self._set_model_output(200, 200, 3)
        # Mock S3 upload failure
        self.mock_s3.upload_fileobj.side_effect = Exception("Upload failed")

        result = self._upscale(outscale=2)

        self._assert_failure(result, "Failed to upload image to s3")

    def test_missing_image_url(self):
        """Test handling of missing image URL"""
        input_data = {
            "inputs": {
                "outscale": 2
            }
        }

        result = self.handler(input_data)

        # Check if result contains all required keys
        for key in ("image_url", "image_key", "error"):
            self.assertIn(key, result)
        # Check if values are as expected
        self._assert_failure(result, "Failed to get inputs")

    @patch('handler.requests.get')
    def test_grayscale_image(self, mock_get):
        """Test handling of grayscale images"""
        self._serve_image(mock_get, 'L')
        # Single-channel model output: no trailing channel axis.
        self._set_model_output(200, 200)

        result = self._upscale(outscale=2)

        self._assert_success(result)

    @patch('handler.requests.get')
    def test_rgba_image(self, mock_get):
        """Test handling of RGBA images"""
        self._serve_image(mock_get, 'RGBA')
        # Four-channel model output to match the alpha-carrying input.
        self._set_model_output(200, 200, 4)

        result = self._upscale(outscale=2)

        self._assert_success(result)
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()