"""
This module implements an endpoint handler that upscales images with the
Real-ESRGAN model and uploads the results to S3.

Required environment variables:

- TILING_SIZE: Tile size used when processing images. Set to 0 to disable tiling.
- AWS_ACCESS_KEY_ID: AWS access key for S3 access.
- AWS_SECRET_ACCESS_KEY: AWS secret key for S3 access.
- S3_BUCKET_NAME: Name of the S3 bucket that upscaled images are uploaded to.
"""
|
import torch |
|
from PIL import Image |
|
from io import BytesIO |
|
from realesrgan import RealESRGANer |
|
from typing import Dict, List, Any |
|
import os |
|
from pathlib import Path |
|
from basicsr.archs.rrdbnet_arch import RRDBNet |
|
import numpy as np |
|
import cv2 |
|
import PIL |
|
import boto3 |
|
import uuid, io |
|
import torch |
|
import base64 |
|
import requests |
|
import logging |
|
import time |
|
|
|
|
|
class EndpointHandler:
    """
    Endpoint handler that upscales an image with Real-ESRGAN and uploads the result to S3.

    A request names an image by URL. The handler downloads it, upscales it,
    stores the result as a PNG in the configured bucket and returns the
    object's URL and key. Once constructed, failures while handling a request
    are reported in the returned dict (``error`` key) instead of being raised.
    """

    # PIL image modes the conversion pipeline knows how to handle.
    _SUPPORTED_MODES = ("RGB", "RGBA", "L")

    def __init__(self, path=""):
        """
        Set up logging, the Real-ESRGAN model and the S3 client.

        Args:
            path (str): Accepted for compatibility with the handler interface
                but currently unused; the weights are always loaded from
                ``/repository/weights/Real-ESRGAN-x4plus.pth``.

        Environment variables read here:
            TILING_SIZE: tile size passed to RealESRGANer; 0 disables tiling.
            AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY: credentials for the S3 client.
            S3_BUCKET_NAME: bucket that upscaled images are uploaded to.

        Raises:
            KeyError: if TILING_SIZE is not set.
            ValueError: if TILING_SIZE is not an integer.
            Exception: any error raised while building the model or the S3
                client (including a missing AWS/S3 variable) is logged and
                re-raised.
        """
        logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

        self.tiling_size = int(os.environ["TILING_SIZE"])
        self.model_path = "/repository/weights/Real-ESRGAN-x4plus.pth"
        # Maximum width * height accepted when tiling is disabled.
        self.max_image_size = 1400 * 1400

        self.logger.info(f"model_path: {self.model_path}")
        if self.tiling_size == 0:
            self.logger.info("TILING_SIZE is 0, not using tiling")
        else:
            self.logger.info(f"TILING_SIZE is {self.tiling_size}, using tiling")

        start_time = time.time()
        self.logger.info("initializing model")
        try:
            self.model = RealESRGANer(
                scale=4,
                model_path=self.model_path,
                model=RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=23,
                    num_grow_ch=32,
                    scale=4,
                ),
                tile=self.tiling_size,
                # NOTE(review): with tiling enabled, tile_pad=0 may leave visible
                # seams between tiles — confirm this is intended.
                tile_pad=0,
                # Use fp16 only when CUDA is available. Without a GPU the model
                # presumably runs on the CPU, where fp16 inference is unsupported
                # or very slow.
                half=torch.cuda.is_available(),
            )
            self.logger.info(f"model initialized in {time.time() - start_time} seconds")
        except Exception as e:
            self.logger.exception(f"Error initializing model: {e}")
            raise

        try:
            self.s3 = boto3.client(
                's3',
                aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
                aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
            )
            self.bucket_name = os.environ["S3_BUCKET_NAME"]
        except Exception as e:
            self.logger.exception(f"Error initializing S3 client: {e}")
            raise

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Upscale the image referenced by ``data`` and upload the result to S3.

        Args:
            data (Any): Request payload (a dict). Parameters are read from
                ``data["inputs"]`` when that key exists, otherwise from ``data``
                itself:
                - 'image_url' (str): URL of the image to upscale (required).
                - 'outscale' (float): final upscaling factor, must lie in
                  (1, 10]; defaults to 3.
                The payload is not modified.

        Returns:
            Dict[str, Any]: always contains the keys
                - 'image_url' (str | None): URL of the uploaded image, or None on error.
                - 'image_key' (str | None): S3 key of the uploaded image, or None on error.
                - 'error' (str | None): error message, or None on success.
        """
        self.logger.info(">>> 1/7: GETTING INPUTS....")
        try:
            # get() rather than pop() so the caller's payload is left untouched.
            inputs = data.get("inputs", data)
            outscale = float(inputs.get("outscale", 3))
            self.logger.info(f"outscale: {outscale}")
            image_url = inputs["image_url"]
        except Exception as e:
            self.logger.error(f"Error getting inputs: {e}")
            return self._error_response(f"Failed to get inputs: {e}")

        try:
            self.logger.info(f"downloading image from URL: {image_url}")
            image = self.download_image_url(image_url)
        except Exception as e:
            self.logger.error(f"Error downloading image from URL: {image_url}. Exception: {e}")
            return self._error_response(f"Failed to download image: {e}")

        self.logger.info(">>> 2/7: VALIDATING IMAGE AND OUTSCALE....")
        in_size, in_mode = image.size, image.mode
        self.logger.info(f"image.size: {in_size}, image.mode: {in_mode}")
        try:
            self._validate_request(in_size, in_mode, outscale)
        except ValueError as e:
            self.logger.error(f"Validation error: {e}")
            return self._error_response(str(e))

        self.logger.info(">>> 3/7: CONVERTING IMAGE TO OPENCV BGR/BGRA FORMAT....")
        try:
            # np.array forces the downloaded image to be decoded, so a corrupt
            # file is expected to surface here.
            opencv_image = self._to_opencv(np.array(image), in_mode)
        except Exception as e:
            self.logger.error(f"Error converting image to opencv format: {e}")
            return self._error_response(f"Failed to convert image to opencv format: {e}")

        self.logger.info(">>> 4/7: UPSCALING IMAGE....")
        try:
            output, _ = self.model.enhance(opencv_image, outscale=outscale)
        except Exception as e:
            self.logger.exception(f"Error enhancing image: {e}")
            return self._error_response("Image enhancement failed.")
        self.logger.info(f"output.shape: {output.shape}")

        self.logger.info(">>> 5/7: CONVERTING IMAGE TO RGB/RGBA FORMAT....")
        try:
            output = self._to_rgb(output)
        except Exception as e:
            self.logger.error(f"Error converting upscaled image to RGB: {e}")
            return self._error_response(f"Failed to convert upscaled image to RGB: {e}")

        self.logger.info(">>> 6/7: CONVERTING IMAGE TO PIL....")
        try:
            output = Image.fromarray(output)
        except Exception as e:
            self.logger.error(f"Error converting upscaled image to PIL: {e}")
            return self._error_response(f"Failed to convert upscaled image to PIL: {e}")

        self.logger.info(">>> 7/7: UPLOADING IMAGE TO S3....")
        try:
            image_url, key = self.upload_to_s3(output)
            self.logger.info(f"image uploaded to s3: {image_url}")
        except Exception as e:
            self.logger.error(f"Error uploading image to s3: {e}")
            return self._error_response(f"Failed to upload image to s3: {e}")

        return {"image_url": image_url, "image_key": key, "error": None}

    @staticmethod
    def _error_response(message: str) -> Dict[str, Any]:
        """Build the failure response returned by __call__."""
        return {"image_url": None, "image_key": None, "error": message}

    def _validate_request(self, in_size, in_mode: str, outscale: float) -> None:
        """
        Check that the downloaded image and requested outscale can be processed.

        Explicit raises (not ``assert``) so the checks still run under ``python -O``.

        Args:
            in_size: (width, height) of the input image.
            in_mode: PIL mode of the input image.
            outscale: requested upscaling factor.

        Raises:
            ValueError: if the mode is unsupported, the image exceeds
                ``self.max_image_size`` pixels while tiling is disabled, or
                ``outscale`` is not in (1, 10] (NaN is rejected too).
        """
        if in_mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unsupported image mode: {in_mode}")
        num_pixels = in_size[0] * in_size[1]
        if self.tiling_size == 0 and num_pixels > self.max_image_size:
            raise ValueError(
                f"Image is too large: {in_size}: {num_pixels} is greater than {self.max_image_size}"
            )
        if not 1 < outscale <= 10:
            raise ValueError(f"Outscale must be between 1 and 10: {outscale}")

    def _to_opencv(self, array: "np.ndarray", mode: str) -> "np.ndarray":
        """Convert a PIL-ordered RGB/RGBA/L array to OpenCV BGR/BGRA channel order."""
        if mode == "RGB":
            self.logger.info("converting RGB image to BGR")
            return cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
        if mode == "RGBA":
            self.logger.info("converting RGBA image to BGRA")
            return cv2.cvtColor(array, cv2.COLOR_RGBA2BGRA)
        if mode == "L":
            # Grayscale is expanded to three identical channels.
            self.logger.info("converting grayscale image to BGR")
            return cv2.cvtColor(array, cv2.COLOR_GRAY2BGR)
        raise ValueError(f"Unsupported image mode: {mode}")

    @staticmethod
    def _to_rgb(output: "np.ndarray") -> "np.ndarray":
        """
        Convert the model output from OpenCV channel order back to RGB/RGBA for PIL.

        Non-3-D arrays are treated as grayscale and expanded to RGB. 3-D arrays
        with a channel count other than 3 or 4 are returned unchanged.
        """
        if output.ndim != 3:
            return cv2.cvtColor(output, cv2.COLOR_GRAY2RGB)
        channels = output.shape[2]
        if channels == 3:
            return cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        if channels == 4:
            return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
        return output

    def upload_to_s3(self, image: "Image.Image") -> tuple[str, str]:
        """
        Upload the image to S3 as a PNG under a random UUID key.

        Args:
            image (Image.Image): The image to upload.

        Returns:
            tuple[str, str]: The object's virtual-hosted-style URL and its S3 key.
        """
        key = f"{uuid.uuid4()}.png"

        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        buffer.seek(0)

        # Set the content type explicitly so clients treat the object as a PNG.
        self.s3.upload_fileobj(
            buffer,
            Bucket=self.bucket_name,
            Key=key,
            ExtraArgs={"ContentType": "image/png"},
        )
        image_url = f"https://{self.bucket_name}.s3.amazonaws.com/{key}"
        return image_url, key

    def download_image_url(self, image_url: str, timeout: float = 30.0) -> "Image.Image":
        """
        Download an image from ``image_url`` and open it with PIL.

        Args:
            image_url (str): The URL of the image to download.
            timeout (float): Seconds to wait for the server before giving up,
                so a stalled host cannot block the handler indefinitely.

        Returns:
            Image.Image: The downloaded image.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
            requests.RequestException: on connection errors or timeouts.
        """
        response = requests.get(image_url, timeout=timeout)
        # Fail on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))