English
Inference Endpoints
garg-aayush committed on
Commit
926a1ef
·
1 Parent(s): aa93695

Update handler.py and test_handler.ipynb with bug fixes and improved error handling

Browse files
Files changed (2) hide show
  1. handler.py +221 -146
  2. test_handler.ipynb +14 -3
handler.py CHANGED
@@ -1,3 +1,13 @@
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from PIL import Image
3
  from io import BytesIO
@@ -15,194 +25,259 @@ import torch
15
  import base64
16
  import requests
17
  import logging
 
18
 
19
 
20
  class EndpointHandler:
21
  def __init__(self, path=""):
 
 
 
 
 
22
 
23
- self.tiling_size = int(os.environ["TILING_SIZE"])
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Initialize the Real-ESRGAN model with specified parameters
26
- self.model = RealESRGANer(
27
- scale=4, # Scale factor for the model
28
- # Path to the pre-trained model weights
29
- model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
30
- # model_path=f"/workspace/real-esrgan/weights/Real-ESRGAN-x4plus.pth",
31
- # Initialize the RRDBNet model architecture with specified parameters
32
- model= RRDBNet(num_in_ch=3,
 
 
33
  num_out_ch=3,
34
  num_feat=64,
35
  num_block=23,
36
  num_grow_ch=32,
37
  scale=4
38
  ),
39
- tile=self.tiling_size,
40
- tile_pad=0,
41
- half=True,
42
- )
 
 
 
 
43
 
44
- # Initialize the S3 client with AWS credentials from environment variables
45
- self.s3 = boto3.client('s3',
 
 
46
  aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
47
  aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
48
  )
49
- # Get the S3 bucket name from environment variables
50
- self.bucket_name = os.environ["S3_BUCKET_NAME"]
 
 
 
51
 
52
- # get the logging level from environment variables
53
- logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
54
- self.logger = logging.getLogger(__name__)
55
 
56
 
57
  def __call__(self, data: Any) -> Dict[str, List[float]]:
58
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  try:
60
- ############################################################
61
- # get inputs and download image
62
- ############################################################
63
- self.logger.info(">>> 1/7: GETTING INPUTS....")
64
  inputs = data.pop("inputs", data)
65
-
66
- # get outscale
67
  outscale = float(inputs.pop("outscale", 3))
68
  self.logger.info(f"outscale: {outscale}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- # download image
71
- try:
72
- self.logger.info(f"downloading image from URL: {inputs['image_url']}")
73
- image = self.download_image_url(inputs['image_url'])
74
- except Exception as e:
75
- logging.error(f"Error downloading image from URL: {inputs['image_url']}. Exception: {e}")
76
- return {"out_image": None, "error": f"Failed to download image: {e}"}
 
 
 
 
 
 
77
 
78
-
79
- ############################################################
80
- # run assertions
81
- ############################################################
82
- self.logger.info(">>> 2/7: RUNNING ASSERTIONS ON IMAGE....")
83
-
84
- # get image size and mode
85
- in_size, in_mode = image.size, image.mode
86
- self.logger.info(f"image.size: {image.size}, image.mode: {image.mode}")
87
-
88
- # check image size and mode and return dict
89
- try:
90
- assert in_mode in ["RGB", "RGBA", "L"], f"Unsupported image mode: {in_mode}"
91
- if self.tiling_size == 0:
92
- assert in_size[0] * in_size[1] < 1400*1400, f"Image is too large: {in_size}: {in_size[0] * in_size[1]} is greater than {self.tiling_size*self.tiling_size}"
93
- assert outscale > 1 and outscale <= 10, f"Outscale must be between 1 and 10: {outscale}"
94
- except AssertionError as e:
95
- self.logger.error(f"Assertion error: {e}")
96
- return {"out_image": None, "error": str(e)}
97
-
98
-
99
- ############################################################
100
- # Convert RGB to BGR (PIL uses RGB, OpenCV expects BGR)
101
- ############################################################
102
- self.logger.info(f">>> 3/7: CONVERTING IMAGE TO OPENCV BGR/BGRA FORMAT....")
103
- try:
104
- opencv_image = np.array(image)
105
- except Exception as e:
106
- self.logger.error(f"Error converting image to opencv format: {e}")
107
- return {"out_image": None, "error": f"Failed to convert image to opencv format: {e}"}
108
-
109
- # convert image to BGR
110
- if in_mode == "RGB":
111
- self.logger.info(f"converting RGB image to BGR")
112
- opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
113
- elif in_mode == "RGBA":
114
- self.logger.info(f"converting RGBA image to BGRA")
115
- opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGBA2BGRA)
116
- elif in_mode == "L":
117
- self.logger.info(f"converting grayscale image to BGR")
118
- opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_GRAY2RGB)
119
- else:
120
- self.logger.error(f"Unsupported image mode: {in_mode}")
121
- return {"out_image": None, "error": f"Unsupported image mode: {in_mode}"}
122
-
123
-
124
- ############################################################
125
- # upscale image
126
- ############################################################
127
- self.logger.info(f">>> 4/7: UPSCALING IMAGE....")
128
-
129
- try:
130
- output, _ = self.model.enhance(opencv_image, outscale=outscale)
131
- except Exception as e:
132
- self.logger.error(f"Error enhancing image: {e}")
133
- return {"out_image": None, "error": "Image enhancement failed."}
134
- # debug
135
- self.logger.info(f"output.shape: {output.shape}")
136
-
137
-
138
- ############################################################
139
- # convert to RGB/RGBA format
140
- ############################################################
141
- self.logger.info(f">>> 5/7: CONVERTING IMAGE TO RGB/RGBA FORMAT....")
142
- out_shape = output.shape
143
- if len(out_shape) == 3:
144
- if out_shape[2] == 3:
145
- output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
146
- elif out_shape[2] == 4:
147
- output = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
148
- else:
149
- output = cv2.cvtColor(output, cv2.COLOR_GRAY2RGB)
150
-
151
-
152
- ############################################################
153
- # convert to PIL image
154
- ############################################################
155
- self.logger.info(f">>> 6/7: CONVERTING IMAGE TO PIL....")
156
- try:
157
- img_byte_arr = BytesIO()
158
- output = Image.fromarray(output)
159
- except Exception as e:
160
- self.logger.error(f"Error converting upscaled image to PIL: {e}")
161
- return {"out_image": None, "error": f"Failed to convert upscaled image to PIL: {e}"}
162
-
163
-
164
- ############################################################
165
- # upload to s3
166
- ############################################################
167
- self.logger.info(f">>> 7/7: UPLOADING IMAGE TO S3....")
168
- try:
169
- image_url, key = self.upload_to_s3(output)
170
- self.logger.info(f"image uploaded to s3: {image_url}")
171
- except Exception as e:
172
- self.logger.error(f"Error uploading image to s3: {e}")
173
- return {"out_image": None, "error": f"Failed to upload image to s3: {e}"}
174
-
175
- return {"image_url": image_url,
176
- "image_key": key,
177
- "error": None
178
- }
179
 
180
- # handle unexpected errors
 
 
 
 
 
 
 
181
  except Exception as e:
182
- self.logger.error(f"An unexpected error occurred: {e}")
183
- return {"out_image": None, "error": f"An unexpected error occurred: {e}"}
 
 
 
 
 
 
 
184
 
 
 
 
 
 
 
185
 
186
- def upload_to_s3(self, image):
187
- "Upload the image to s3 and return the url."
 
188
 
189
  prefix = str(uuid.uuid4())
190
  # Save the image to an in-memory file
191
  in_mem_file = io.BytesIO()
192
- image.save(in_mem_file, 'PNG')
193
  in_mem_file.seek(0)
194
 
195
- # Upload the image to s3
196
  key = f"{prefix}.png"
197
  self.s3.upload_fileobj(in_mem_file, Bucket=self.bucket_name, Key=key)
198
- image_url = "https://{0}.s3.amazonaws.com/{1}".format(self.bucket_name, key)
199
 
200
- # return the url and the key
201
  return image_url, key
202
 
203
- def download_image_url(self, image_url):
204
- "Download the image from the url and return the image."
205
-
 
 
 
 
 
 
 
206
  response = requests.get(image_url)
207
  image = Image.open(BytesIO(response.content))
208
  return image
 
1
+ """
2
+ This module handles the endpoint for image upscaling using the Real-ESRGAN model.
3
+
4
+ Required Environment Variables:
5
+ - TILING_SIZE: The size of the tiles for processing images. Set to 0 to disable tiling.
6
+ - AWS_ACCESS_KEY_ID: AWS access key for S3 access.
7
+ - AWS_SECRET_ACCESS_KEY: AWS secret key for S3 access.
8
+ - S3_BUCKET_NAME: The name of the S3 bucket where images will be uploaded.
9
+
10
+ """
11
  import torch
12
  from PIL import Image
13
  from io import BytesIO
 
25
  import base64
26
  import requests
27
  import logging
28
+ import time
29
 
30
 
31
  class EndpointHandler:
32
  def __init__(self, path=""):
33
+ """
34
+ Initializes the EndpointHandler class, setting up the Real-ESRGAN model and S3 client.
35
+
36
+ Args:
37
+ path (str): Optional path to the model weights. Defaults to an empty string.
38
 
39
+ This constructor performs the following actions:
40
+ - Configures logging at a fixed INFO level with a "LEVEL - message" format.
41
+ - Retrieves the tiling size from environment variables.
42
+ - Initializes the Real-ESRGAN model with specified parameters, including scale, model path, and architecture.
43
+ - Sets up the S3 client using AWS credentials from environment variables.
44
+ - Logs the initialization process and any errors encountered during setup.
45
+ """
46
+
47
+ # configure logging at a fixed INFO level (not read from environment variables)
48
+ logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
49
+ self.logger = logging.getLogger(__name__)
50
+
51
 
52
+ self.tiling_size = int(os.environ["TILING_SIZE"])
53
+ # self.model_path = f"/repository/weights/Real-ESRGAN-x4plus.pth"
54
+ self.max_image_size = 1400 * 1400
55
+ self.model_path = f"/workspace/real-esrgan/weights/Real-ESRGAN-x4plus.pth"
56
+
57
+
58
+ # log model path and tiling size
59
+ self.logger.info(f"model_path: {self.model_path}")
60
+ if self.tiling_size == 0: self.logger.info("TILING_SIZE is 0, not using tiling")
61
+ else: self.logger.info(f"TILING_SIZE is {self.tiling_size}, using tiling")
62
+
63
+
64
  # Initialize the Real-ESRGAN model with specified parameters
65
+ start_time = time.time()
66
+ self.logger.info(f"initializing model")
67
+ try:
68
+ self.model = RealESRGANer(
69
+ scale=4, # Scale factor for the model
70
+ # Path to the pre-trained model weights
71
+ model_path=self.model_path,
72
+ # Initialize the RRDBNet model architecture with specified parameters
73
+ model= RRDBNet(num_in_ch=3,
74
  num_out_ch=3,
75
  num_feat=64,
76
  num_block=23,
77
  num_grow_ch=32,
78
  scale=4
79
  ),
80
+ tile=self.tiling_size,
81
+ tile_pad=0,
82
+ half=True,
83
+ )
84
+ self.logger.info(f"model initialized in {time.time() - start_time} seconds")
85
+ except Exception as e:
86
+ self.logger.error(f"Error initializing model: {e}")
87
+ raise e
88
 
89
+
90
+ try:
91
+ # Initialize the S3 client with AWS credentials from environment variables
92
+ self.s3 = boto3.client('s3',
93
  aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
94
  aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
95
  )
96
+ # Get the S3 bucket name from environment variables
97
+ self.bucket_name = os.environ["S3_BUCKET_NAME"]
98
+ except Exception as e:
99
+ self.logger.error(f"Error initializing S3 client: {e}")
100
+ raise e
101
 
 
 
 
102
 
103
 
104
  def __call__(self, data: Any) -> Dict[str, List[float]]:
105
+ """
106
+ Processes the input data to upscale an image using the Real-ESRGAN model.
107
+
108
+ Args:
109
+ data (Any): A dictionary containing the input data. It should include:
110
+ - 'inputs': A dictionary with the following keys:
111
+ - 'image_url' (str): The URL of the image to be upscaled.
112
+ - 'outscale' (float): The scaling factor for the upscaling process.
113
+
114
+ Returns:
115
+ Dict[str, List[float]]: A dictionary containing the results of the upscaling process, which includes:
116
+ - 'image_url' (str | None): The URL of the upscaled image or None if an error occurred.
117
+ - 'image_key' (str | None): The key for the uploaded image in S3 or None if an error occurred.
118
+ - 'error' (str | None): An error message if an error occurred, otherwise None.
119
+ """
120
+
121
+ ############################################################
122
+ # get inputs and download image
123
+ ############################################################
124
+ self.logger.info(">>> 1/7: GETTING INPUTS....")
125
  try:
 
 
 
 
126
  inputs = data.pop("inputs", data)
 
 
127
  outscale = float(inputs.pop("outscale", 3))
128
  self.logger.info(f"outscale: {outscale}")
129
+ image_url = inputs["image_url"]
130
+ except Exception as e:
131
+ self.logger.error(f"Error getting inputs: {e}")
132
+ return {"image_url": None, "image_key": None, "error": f"Failed to get inputs: {e}"}
133
+
134
+ # download image
135
+ try:
136
+ self.logger.info(f"downloading image from URL: {image_url}")
137
+ image = self.download_image_url(image_url)
138
+ except Exception as e:
139
+ self.logger.error(f"Error downloading image from URL: {image_url}. Exception: {e}")
140
+ return {"image_url": None, "image_key": None, "error": f"Failed to download image: {e}"}
141
+
142
+
143
+ ############################################################
144
+ # run assertions
145
+ ############################################################
146
+ self.logger.info(">>> 2/7: RUNNING ASSERTIONS ON IMAGE....")
147
+
148
+ # get image size and mode
149
+ in_size, in_mode = image.size, image.mode
150
+ self.logger.info(f"image.size: {image.size}, image.mode: {image.mode}")
151
+
152
+ # check image size and mode and return dict
153
+ try:
154
+ assert in_mode in ["RGB", "RGBA", "L"], f"Unsupported image mode: {in_mode}"
155
+ if self.tiling_size == 0:
156
+ assert in_size[0] * in_size[1] < self.max_image_size, f"Image is too large: {in_size}: {in_size[0] * in_size[1]} is greater than {self.max_image_size}"
157
+ assert outscale > 1 and outscale <= 10, f"Outscale must be between 1 and 10: {outscale}"
158
+ except AssertionError as e:
159
+ self.logger.error(f"Assertion error: {e}")
160
+ return {"image_url": None, "image_key": None, "error": str(e)}
161
+
162
+
163
+ ############################################################
164
+ # Convert RGB to BGR (PIL uses RGB, OpenCV expects BGR)
165
+ ############################################################
166
+ self.logger.info(f">>> 3/7: CONVERTING IMAGE TO OPENCV BGR/BGRA FORMAT....")
167
+ try:
168
+ opencv_image = np.array(image)
169
+ except Exception as e:
170
+ self.logger.error(f"Error converting image to opencv format: {e}")
171
+ return {"image_url": None, "image_key": None, "error": f"Failed to convert image to opencv format: {e}"}
172
+
173
+ # convert image to BGR
174
+ if in_mode == "RGB":
175
+ self.logger.info(f"converting RGB image to BGR")
176
+ opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
177
+ elif in_mode == "RGBA":
178
+ self.logger.info(f"converting RGBA image to BGRA")
179
+ opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGBA2BGRA)
180
+ elif in_mode == "L":
181
+ self.logger.info(f"converting grayscale image to BGR")
182
+ opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_GRAY2RGB)
183
+ else:
184
+ self.logger.error(f"Unsupported image mode: {in_mode}")
185
+ return {"image_url": None, "image_key": None, "error": f"Unsupported image mode: {in_mode}"}
186
 
187
+
188
+ ############################################################
189
+ # upscale image
190
+ ############################################################
191
+ self.logger.info(f">>> 4/7: UPSCALING IMAGE....")
192
+
193
+ try:
194
+ output, _ = self.model.enhance(opencv_image, outscale=outscale)
195
+ except Exception as e:
196
+ self.logger.error(f"Error enhancing image: {e}")
197
+ return {"image_url": None, "image_key": None, "error": "Image enhancement failed."}
198
+ # debug
199
+ self.logger.info(f"output.shape: {output.shape}")
200
 
201
+
202
+ ############################################################
203
+ # convert to RGB/RGBA format
204
+ ############################################################
205
+ self.logger.info(f">>> 5/7: CONVERTING IMAGE TO RGB/RGBA FORMAT....")
206
+ out_shape = output.shape
207
+ if len(out_shape) == 3:
208
+ if out_shape[2] == 3:
209
+ output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
210
+ elif out_shape[2] == 4:
211
+ output = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
212
+ else:
213
+ output = cv2.cvtColor(output, cv2.COLOR_GRAY2RGB)
214
+
215
+
216
+ ############################################################
217
+ # convert to PIL image
218
+ ############################################################
219
+ self.logger.info(f">>> 6/7: CONVERTING IMAGE TO PIL....")
220
+ try:
221
+ img_byte_arr = BytesIO()
222
+ output = Image.fromarray(output)
223
+ except Exception as e:
224
+ self.logger.error(f"Error converting upscaled image to PIL: {e}")
225
+ return {"image_url": None, "image_key": None, "error": f"Failed to convert upscaled image to PIL: {e}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+
228
+ ############################################################
229
+ # upload to s3
230
+ ############################################################
231
+ self.logger.info(f">>> 7/7: UPLOADING IMAGE TO S3....")
232
+ try:
233
+ image_url, key = self.upload_to_s3(output)
234
+ self.logger.info(f"image uploaded to s3: {image_url}")
235
  except Exception as e:
236
+ self.logger.error(f"Error uploading image to s3: {e}")
237
+ return {"image_url": None, "image_key": None, "error": f"Failed to upload image to s3: {e}"}
238
+
239
+
240
+ return {"image_url": image_url,
241
+ "image_key": key,
242
+ "error": None
243
+ }
244
+
245
 
246
+ def upload_to_s3(self, image: Image.Image) -> tuple[str, str]:
247
+ """
248
+ Upload the image to S3 and return the URL and key.
249
+
250
+ Args:
251
+ image (Image.Image): The image to upload.
252
 
253
+ Returns:
254
+ tuple[str, str]: A tuple containing the image URL and the S3 key.
255
+ """
256
 
257
  prefix = str(uuid.uuid4())
258
  # Save the image to an in-memory file
259
  in_mem_file = io.BytesIO()
260
+ image.save(in_mem_file, format='PNG')
261
  in_mem_file.seek(0)
262
 
263
+ # Upload the image to S3
264
  key = f"{prefix}.png"
265
  self.s3.upload_fileobj(in_mem_file, Bucket=self.bucket_name, Key=key)
266
+ image_url = f"https://{self.bucket_name}.s3.amazonaws.com/{key}"
267
 
268
+ # Return the URL and the key
269
  return image_url, key
270
 
271
+ def download_image_url(self, image_url: str) -> Image.Image:
272
+ """
273
+ Downloads an image from the specified URL and returns it as a PIL Image.
274
+
275
+ Args:
276
+ image_url (str): The URL of the image to download.
277
+
278
+ Returns:
279
+ Image.Image: The downloaded image as a PIL Image.
280
+ """
281
  response = requests.get(image_url)
282
  image = Image.open(BytesIO(response.content))
283
  return image
test_handler.ipynb CHANGED
@@ -40,7 +40,18 @@
40
  "cell_type": "code",
41
  "execution_count": 3,
42
  "metadata": {},
43
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
44
  "source": [
45
  "# init handler\n",
46
  "my_handler = EndpointHandler(path=\".\")"
@@ -67,14 +78,14 @@
67
  "INFO - >>> 5/7: CONVERTING IMAGE TO RGB/RGBA FORMAT....\n",
68
  "INFO - >>> 6/7: CONVERTING IMAGE TO PIL....\n",
69
  "INFO - >>> 7/7: UPLOADING IMAGE TO S3....\n",
70
- "INFO - image uploaded to s3: https://upscale-process-results.s3.amazonaws.com/fe21c683-ee2d-4e1d-9fbc-af56823c664c.png\n"
71
  ]
72
  },
73
  {
74
  "name": "stdout",
75
  "output_type": "stream",
76
  "text": [
77
- "https://upscale-process-results.s3.amazonaws.com/fe21c683-ee2d-4e1d-9fbc-af56823c664c.png fe21c683-ee2d-4e1d-9fbc-af56823c664c.png\n"
78
  ]
79
  }
80
  ],
 
40
  "cell_type": "code",
41
  "execution_count": 3,
42
  "metadata": {},
43
+ "outputs": [
44
+ {
45
+ "name": "stderr",
46
+ "output_type": "stream",
47
+ "text": [
48
+ "INFO - model_path: /workspace/real-esrgan/weights/Real-ESRGAN-x4plus.pth\n",
49
+ "INFO - TILING_SIZE is 0, not using tiling\n",
50
+ "INFO - initializing model\n",
51
+ "INFO - model initialized in 1.977891206741333 seconds\n"
52
+ ]
53
+ }
54
+ ],
55
  "source": [
56
  "# init handler\n",
57
  "my_handler = EndpointHandler(path=\".\")"
 
78
  "INFO - >>> 5/7: CONVERTING IMAGE TO RGB/RGBA FORMAT....\n",
79
  "INFO - >>> 6/7: CONVERTING IMAGE TO PIL....\n",
80
  "INFO - >>> 7/7: UPLOADING IMAGE TO S3....\n",
81
+ "INFO - image uploaded to s3: https://jiffy-staging-upscaled-images.s3.amazonaws.com/25b91e15-b785-47ca-81a6-ad5fbdf8b92a.png\n"
82
  ]
83
  },
84
  {
85
  "name": "stdout",
86
  "output_type": "stream",
87
  "text": [
88
+ "https://jiffy-staging-upscaled-images.s3.amazonaws.com/25b91e15-b785-47ca-81a6-ad5fbdf8b92a.png 25b91e15-b785-47ca-81a6-ad5fbdf8b92a.png\n"
89
  ]
90
  }
91
  ],