English
Inference Endpoints
garg-aayush commited on
Commit
022c628
·
1 Parent(s): 6ead77d

add unit tests, create endpoint test notebook

Browse files
Files changed (2) hide show
  1. test_handler.ipynb +128 -0
  2. unit_tests.py +95 -0
test_handler.ipynb ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/usr/local/lib/python3.10/dist-packages/torchvision/transforms/functional_tensor.py:5: UserWarning: The torchvision.transforms.functional_tensor module is deprecated in 0.15 and will be **removed in 0.17**. Please don't rely on it. You probably just need to use APIs in torchvision.transforms.functional or in torchvision.transforms.v2.functional.\n",
13
+ " warnings.warn(\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "from handler import EndpointHandler\n",
19
+ "import base64\n",
20
+ "from io import BytesIO\n",
21
+ "from PIL import Image\n",
22
+ "import cv2\n",
23
+ "import random\n"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 2,
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "# helper decoder\n",
33
+ "def decode_base64_image(image_string):\n",
34
+ " base64_image = base64.b64decode(image_string)\n",
35
+ " buffer = BytesIO(base64_image)\n",
36
+ " return Image.open(buffer)"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 3,
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "# init handler\n",
46
+ "my_handler = EndpointHandler(path=\".\")"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 6,
52
+ "metadata": {},
53
+ "outputs": [
54
+ {
55
+ "name": "stdout",
56
+ "output_type": "stream",
57
+ "text": [
58
+ "image.size: (1200, 517), image.mode: RGBA, outscale: 10.0\n"
59
+ ]
60
+ },
61
+ {
62
+ "name": "stdout",
63
+ "output_type": "stream",
64
+ "text": [
65
+ "output.shape: (5170, 12000, 4)\n",
66
+ "out_image.size: (12000, 5170)\n",
67
+ "image.size: (1056, 1068), image.mode: RGB, outscale: 3.0\n",
68
+ "output.shape: (3204, 3168, 3)\n",
69
+ "out_image.size: (3168, 3204)\n",
70
+ "image.size: (1056, 1068), image.mode: L, outscale: 5.49\n",
71
+ "output.shape: (5863, 5797, 3)\n",
72
+ "out_image.size: (5797, 5863)\n"
73
+ ]
74
+ }
75
+ ],
76
+ "source": [
77
+ "img_dir = \"test_data/\"\n",
78
+ "img_names = [\"4121783.png\", \"FB_IMG_1725931665635.jpg\", \"FB_IMG_1725931665635_gray.jpg\"]\n",
79
+ "out_scales = [10, 3, 5.49]\n",
80
+ "for img_name, outscale in zip(img_names, out_scales):\n",
81
+ " image_path = img_dir + img_name\n",
82
+ " # create payload\n",
83
+ " with open(image_path, \"rb\") as i:\n",
84
+ " b64 = base64.b64encode(i.read())\n",
85
+ " b64 = b64.decode(\"utf-8\")\n",
86
+ " payload = {\n",
87
+ " \"inputs\": {\"image\": b64, \n",
88
+ " \"outscale\": outscale\n",
89
+ " }\n",
90
+ " }\n",
91
+ "\n",
92
+ "\n",
93
+ " output_payload = my_handler(payload)\n",
94
+ " out_image = decode_base64_image(output_payload[\"out_image\"])\n",
95
+ " print(f\"out_image.size: {out_image.size}\")\n",
96
+ " out_image.save(f\"test_data/outputs/{img_name.split('.')[0]}_outscale_{outscale}.png\")\n"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": []
105
+ }
106
+ ],
107
+ "metadata": {
108
+ "kernelspec": {
109
+ "display_name": "Python 3",
110
+ "language": "python",
111
+ "name": "python3"
112
+ },
113
+ "language_info": {
114
+ "codemirror_mode": {
115
+ "name": "ipython",
116
+ "version": 3
117
+ },
118
+ "file_extension": ".py",
119
+ "mimetype": "text/x-python",
120
+ "name": "python",
121
+ "nbconvert_exporter": "python",
122
+ "pygments_lexer": "ipython3",
123
+ "version": "3.10.12"
124
+ }
125
+ },
126
+ "nbformat": 4,
127
+ "nbformat_minor": 2
128
+ }
unit_tests.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from unittest.mock import patch, MagicMock
3
+ from PIL import Image
4
+ import base64
5
+ import numpy as np
6
+ from io import BytesIO
7
+ from handler import EndpointHandler
8
+
9
+ class TestEndpointHandler(unittest.TestCase):
10
+
11
+ @patch('handler.RealESRGANer')
12
+ def setUp(self, mock_RealESRGANer):
13
+ self.handler = EndpointHandler(path=".")
14
+ self.mock_model = mock_RealESRGANer.return_value
15
+
16
+ def create_test_image(self, mode='RGB', size=(100, 100)):
17
+ image = Image.new(mode, size)
18
+ buffered = BytesIO()
19
+ image.save(buffered, format="PNG")
20
+ return base64.b64encode(buffered.getvalue()).decode()
21
+
22
+ def get_svg_image(self):
23
+ test_image = "test_data/834989.svg"
24
+ return test_image
25
+
26
+ def test_float_outscale(self):
27
+ test_image = self.create_test_image()
28
+ input_data = {"inputs": {"image": test_image, "outscale": 2.5}}
29
+
30
+ self.mock_model.enhance.return_value = (np.zeros((250, 250, 3), dtype=np.uint8), None)
31
+ result = self.handler(input_data)
32
+
33
+ self.assertIn("out_image", result)
34
+ self.assertIsNone(result["error"])
35
+
36
+ def test_outscale_too_small(self):
37
+ test_image = self.create_test_image()
38
+ input_data = {"inputs": {"image": test_image, "outscale": 0.5}}
39
+
40
+ result = self.handler(input_data)
41
+
42
+ self.assertIsNone(result["out_image"])
43
+ self.assertIn("Outscale must be between 1 and 10", result["error"])
44
+
45
+ def test_outscale_too_large(self):
46
+ test_image = self.create_test_image()
47
+ input_data = {"inputs": {"image": test_image, "outscale": 11}}
48
+
49
+ result = self.handler(input_data)
50
+
51
+ self.assertIsNone(result["out_image"])
52
+ self.assertIn("Outscale must be between 1 and 10", result["error"])
53
+
54
+ def test_valid_rgb_image(self):
55
+ test_image = self.create_test_image()
56
+ input_data = {"inputs": {"image": test_image, "outscale": 2}}
57
+
58
+ self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)
59
+
60
+ result = self.handler(input_data)
61
+
62
+ self.assertIn("out_image", result)
63
+ self.assertIsNone(result["error"])
64
+ self.mock_model.enhance.assert_called_once()
65
+
66
+ def test_valid_rgba_image(self):
67
+ test_image = self.create_test_image(mode='RGBA')
68
+ input_data = {"inputs": {"image": test_image, "outscale": 2}}
69
+
70
+ self.mock_model.enhance.return_value = (np.zeros((400, 400, 4), dtype=np.uint8), None)
71
+
72
+ result = self.handler(input_data)
73
+
74
+ self.assertIn("out_image", result)
75
+ self.assertIsNone(result["error"])
76
+
77
+ def test_image_too_large(self):
78
+ test_image = self.create_test_image(size=(1500, 1500))
79
+ input_data = {"inputs": {"image": test_image}}
80
+
81
+ result = self.handler(input_data)
82
+
83
+ self.assertIsNone(result["out_image"])
84
+ self.assertIn("Image is too large", result["error"])
85
+
86
+ def test_missing_image_key(self):
87
+ input_data = {"inputs": {}}
88
+
89
+ result = self.handler(input_data)
90
+
91
+ self.assertIsNone(result["out_image"])
92
+ self.assertIn("Missing key", result["error"])
93
+
94
+ if __name__ == '__main__':
95
+ unittest.main()