English
Inference Endpoints
File size: 7,406 Bytes
022c628
 
d47736e
022c628
 
d47736e
022c628
 
 
 
d47736e
 
 
 
 
 
 
 
 
 
022c628
d47736e
022c628
d47736e
 
022c628
 
d47736e
022c628
d47736e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
022c628
 
d47736e
 
022c628
d47736e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
022c628
 
 
d47736e
 
022c628
d47736e
 
 
 
 
 
 
 
 
 
 
 
022c628
 
 
d47736e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
022c628
d47736e
 
 
 
 
 
 
 
 
 
 
 
 
 
022c628
d47736e
022c628
 
d47736e
 
 
 
 
 
 
 
 
 
022c628
 
d47736e
 
 
022c628
d47736e
 
 
 
 
 
 
 
 
022c628
d47736e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
022c628
 
 
d47736e
 
022c628
 
d47736e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
022c628
 
 
d47736e
 
 
022c628
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import unittest
from unittest.mock import patch, MagicMock
import os
from PIL import Image
from io import BytesIO
import numpy as np
from handler import EndpointHandler

class TestEndpointHandler(unittest.TestCase):
    """Unit tests for ``handler.EndpointHandler``.

    ``handler.RealESRGANer`` and ``handler.boto3`` are patched for the whole
    duration of every test, and ``handler.requests.get`` is patched per test,
    so no model weights, network access or AWS credentials are used.
    """

    # Environment the handler is built under; restored after each test so the
    # fake credentials do not leak into other tests in the same process.
    TEST_ENV = {
        'TILING_SIZE': '0',
        'AWS_ACCESS_KEY_ID': 'test_key',
        'AWS_SECRET_ACCESS_KEY': 'test_secret',
        'S3_BUCKET_NAME': 'test-bucket',
    }
    TEST_IMAGE_URL = "http://example.com/test.png"

    def setUp(self):
        """Set up test environment before each test.

        Patches are started here and stopped through ``addCleanup`` rather than
        applied as ``@patch`` decorators on ``setUp``: decorator patches are
        undone as soon as ``setUp`` returns, which would leave the real
        ``RealESRGANer``/``boto3`` in place while the test body runs.
        """
        env_patcher = patch.dict(os.environ, self.TEST_ENV)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        mock_RealESRGANer = self._start_patch('handler.RealESRGANer')
        mock_boto3 = self._start_patch('handler.boto3')

        self.handler = EndpointHandler()
        self.mock_model = mock_RealESRGANer.return_value
        self.mock_s3 = mock_boto3.client.return_value

    def _start_patch(self, target):
        """Patch *target* until the end of the current test and return the mock."""
        patcher = patch(target)
        mock = patcher.start()
        self.addCleanup(patcher.stop)
        return mock

    def image_to_bytes(self, image):
        """Helper method to convert PIL Image to PNG-encoded bytes."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return buffered.getvalue()

    def _mock_download(self, mock_get, image):
        """Make the patched ``requests.get`` return *image* as PNG content."""
        mock_response = MagicMock()
        mock_response.content = self.image_to_bytes(image)
        mock_get.return_value = mock_response

    def _make_input(self, outscale=2):
        """Build a request payload pointing at the test image URL."""
        return {
            "inputs": {
                "image_url": self.TEST_IMAGE_URL,
                "outscale": outscale,
            }
        }

    def _assert_success(self, result):
        """Assert *result* carries an uploaded image and no error."""
        self.assertIsNotNone(result["image_url"])
        self.assertIsNotNone(result["image_key"])
        self.assertIsNone(result["error"])

    def _assert_failure(self, result, message):
        """Assert *result* has no image and an error containing *message*."""
        self.assertIsNone(result["image_url"])
        self.assertIsNone(result["image_key"])
        self.assertIn(message, result["error"])

    @patch('handler.requests.get')
    def test_successful_upscale(self, mock_get):
        """Test successful image upscaling"""
        self._mock_download(mock_get, Image.new('RGB', (100, 100)))
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)

        result = self.handler(self._make_input(outscale=2))

        self._assert_success(result)

    @patch('handler.requests.get')
    def test_invalid_outscale(self, mock_get):
        """Test handling of invalid outscale values"""
        self._mock_download(mock_get, Image.new('RGB', (100, 100)))

        result = self.handler(self._make_input(outscale=0.5))  # Too small

        self._assert_failure(result, "Outscale must be between 1 and 10")

    @patch('handler.requests.get')
    def test_download_failure(self, mock_get):
        """Test handling of failed image downloads"""
        mock_get.side_effect = Exception("Download failed")

        result = self.handler(self._make_input(outscale=2))

        self._assert_failure(result, "Failed to download image")

    @patch('handler.requests.get')
    def test_large_image_no_tiling(self, mock_get):
        """Test handling of large images when tiling is disabled"""
        # 1500x1500 is expected to exceed the handler's size limit when
        # TILING_SIZE is '0' (tiling disabled).
        self._mock_download(mock_get, Image.new('RGB', (1500, 1500)))

        result = self.handler(self._make_input(outscale=2))

        self._assert_failure(result, "Image is too large")

    @patch('handler.requests.get')
    def test_s3_upload_failure(self, mock_get):
        """Test handling of S3 upload failures"""
        self._mock_download(mock_get, Image.new('RGB', (100, 100)))
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)
        # Only effective because the boto3 patch is still active during the test.
        self.mock_s3.upload_fileobj.side_effect = Exception("Upload failed")

        result = self.handler(self._make_input(outscale=2))

        self._assert_failure(result, "Failed to upload image to s3")

    def test_missing_image_url(self):
        """Test handling of missing image URL"""
        input_data = {
            "inputs": {
                "outscale": 2
            }
        }

        result = self.handler(input_data)

        # Even on error the response must expose every key.
        self.assertIn("image_url", result)
        self.assertIn("image_key", result)
        self.assertIn("error", result)

        self._assert_failure(result, "Failed to get inputs")

    @patch('handler.requests.get')
    def test_grayscale_image(self, mock_get):
        """Test handling of grayscale images"""
        self._mock_download(mock_get, Image.new('L', (100, 100)))
        self.mock_model.enhance.return_value = (np.zeros((200, 200), dtype=np.uint8), None)

        result = self.handler(self._make_input(outscale=2))

        self._assert_success(result)

    @patch('handler.requests.get')
    def test_rgba_image(self, mock_get):
        """Test handling of RGBA images"""
        self._mock_download(mock_get, Image.new('RGBA', (100, 100)))
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 4), dtype=np.uint8), None)

        result = self.handler(self._make_input(outscale=2))

        self._assert_success(result)

if __name__ == "__main__":
    # Allow running this module directly, e.g. `python test_handler.py`.
    unittest.main()