garg-aayush
commited on
Commit
·
a84e2f3
1
Parent(s):
a7f6fc5
Write the output image to S3 bucket and return image_url and filename
Browse files- handler.py +43 -17
- requirements.txt +2 -1
- test_handler.ipynb +13 -15
handler.py
CHANGED
@@ -9,8 +9,8 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
9 |
import numpy as np
|
10 |
import cv2
|
11 |
import PIL
|
12 |
-
|
13 |
-
|
14 |
import torch
|
15 |
import base64
|
16 |
|
@@ -18,24 +18,31 @@ import base64
|
|
18 |
class EndpointHandler:
|
19 |
def __init__(self, path=""):
|
20 |
|
|
|
21 |
self.model = RealESRGANer(
|
22 |
-
scale=4,
|
|
|
23 |
model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
|
24 |
-
#
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
num_grow_ch=32,
|
31 |
scale=4
|
32 |
),
|
33 |
tile=1000,
|
34 |
tile_pad=20,
|
35 |
-
# pre_pad=args.pre_pad,
|
36 |
half=True,
|
37 |
-
# gpu_id=args.gpu_id
|
38 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
def __call__(self, data: Any) -> Dict[str, List[float]]:
|
41 |
|
@@ -89,12 +96,14 @@ class EndpointHandler:
|
|
89 |
img_byte_arr = BytesIO()
|
90 |
output = Image.fromarray(output)
|
91 |
|
92 |
-
# save to BytesIO
|
93 |
-
output.save(img_byte_arr, format='PNG')
|
94 |
-
img_str = base64.b64encode(img_byte_arr.getvalue())
|
95 |
-
img_str = img_str.decode()
|
|
|
96 |
|
97 |
-
return {"
|
|
|
98 |
"error": None
|
99 |
}
|
100 |
|
@@ -114,3 +123,20 @@ class EndpointHandler:
|
|
114 |
except Exception as e:
|
115 |
print(f"Exception: {e}")
|
116 |
return {"out_image": None, "error": "An unexpected error occurred"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
import numpy as np
|
10 |
import cv2
|
11 |
import PIL
|
12 |
+
import boto3
|
13 |
+
import uuid, io
|
14 |
import torch
|
15 |
import base64
|
16 |
|
|
|
18 |
class EndpointHandler:
|
19 |
def __init__(self, path=""):
|
20 |
|
21 |
+
# Initialize the Real-ESRGAN model with specified parameters
|
22 |
self.model = RealESRGANer(
|
23 |
+
scale=4, # Scale factor for the model
|
24 |
+
# Path to the pre-trained model weights
|
25 |
model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
|
26 |
+
# Initialize the RRDBNet model architecture with specified parameters
|
27 |
+
model= RRDBNet(num_in_ch=3,
|
28 |
+
num_out_ch=3,
|
29 |
+
num_feat=64,
|
30 |
+
num_block=23,
|
31 |
+
num_grow_ch=32,
|
|
|
32 |
scale=4
|
33 |
),
|
34 |
tile=1000,
|
35 |
tile_pad=20,
|
|
|
36 |
half=True,
|
|
|
37 |
)
|
38 |
+
|
39 |
+
# Initialize the S3 client with AWS credentials from environment variables
|
40 |
+
self.s3 = boto3.client('s3',
|
41 |
+
aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
|
42 |
+
aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
|
43 |
+
)
|
44 |
+
# Get the S3 bucket name from environment variables
|
45 |
+
self.bucket_name = os.environ["S3_BUCKET_NAME"]
|
46 |
|
47 |
def __call__(self, data: Any) -> Dict[str, List[float]]:
|
48 |
|
|
|
96 |
img_byte_arr = BytesIO()
|
97 |
output = Image.fromarray(output)
|
98 |
|
99 |
+
# # save to BytesIO
|
100 |
+
# output.save(img_byte_arr, format='PNG')
|
101 |
+
# img_str = base64.b64encode(img_byte_arr.getvalue())
|
102 |
+
# img_str = img_str.decode()
|
103 |
+
image_url, key = self.upload_to_s3(output)
|
104 |
|
105 |
+
return {"image_url": image_url,
|
106 |
+
"image_key": key,
|
107 |
"error": None
|
108 |
}
|
109 |
|
|
|
123 |
except Exception as e:
|
124 |
print(f"Exception: {e}")
|
125 |
return {"out_image": None, "error": "An unexpected error occurred"}
|
126 |
+
|
127 |
+
def upload_to_s3(self, image):
|
128 |
+
"Upload the image to s3 and return the url."
|
129 |
+
|
130 |
+
prefix = str(uuid.uuid4())
|
131 |
+
# Save the image to an in-memory file
|
132 |
+
in_mem_file = io.BytesIO()
|
133 |
+
image.save(in_mem_file, 'PNG')
|
134 |
+
in_mem_file.seek(0)
|
135 |
+
|
136 |
+
# Upload the image to s3
|
137 |
+
key = f"{prefix}.png"
|
138 |
+
self.s3.upload_fileobj(in_mem_file, Bucket=self.bucket_name, Key=key)
|
139 |
+
image_url = "https://{0}.s3.amazonaws.com/{1}".format(self.bucket_name, key)
|
140 |
+
|
141 |
+
# return the url and the key
|
142 |
+
return image_url, key
|
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ Pillow
|
|
7 |
torch>=1.7
|
8 |
torchvision==0.16.2
|
9 |
tqdm
|
10 |
-
realesrgan
|
|
|
|
7 |
torch>=1.7
|
8 |
torchvision==0.16.2
|
9 |
tqdm
|
10 |
+
realesrgan
|
11 |
+
boto3
|
test_handler.ipynb
CHANGED
@@ -20,7 +20,7 @@
|
|
20 |
"from io import BytesIO\n",
|
21 |
"from PIL import Image\n",
|
22 |
"import cv2\n",
|
23 |
-
"import random
|
24 |
]
|
25 |
},
|
26 |
{
|
@@ -48,7 +48,7 @@
|
|
48 |
},
|
49 |
{
|
50 |
"cell_type": "code",
|
51 |
-
"execution_count":
|
52 |
"metadata": {},
|
53 |
"outputs": [
|
54 |
{
|
@@ -62,14 +62,11 @@
|
|
62 |
"name": "stdout",
|
63 |
"output_type": "stream",
|
64 |
"text": [
|
65 |
-
"
|
66 |
-
"
|
67 |
-
"
|
68 |
-
"
|
69 |
-
"
|
70 |
-
"image.size: (1056, 1068), image.mode: L, outscale: 5.49\n",
|
71 |
-
"output.shape: (5863, 5797, 3)\n",
|
72 |
-
"out_image.size: (5797, 5863)\n"
|
73 |
]
|
74 |
}
|
75 |
],
|
@@ -91,9 +88,10 @@
|
|
91 |
"\n",
|
92 |
"\n",
|
93 |
" output_payload = my_handler(payload)\n",
|
94 |
-
" out_image = decode_base64_image(output_payload[\"out_image\"])\n",
|
95 |
-
" print(f\"out_image.size: {out_image.size}\")\n",
|
96 |
-
" out_image.save(f\"test_data/outputs/{img_name.split('.')[0]}_outscale_{outscale}.png\")\n"
|
|
|
97 |
]
|
98 |
},
|
99 |
{
|
@@ -106,7 +104,7 @@
|
|
106 |
],
|
107 |
"metadata": {
|
108 |
"kernelspec": {
|
109 |
-
"display_name": "Python 3",
|
110 |
"language": "python",
|
111 |
"name": "python3"
|
112 |
},
|
@@ -124,5 +122,5 @@
|
|
124 |
}
|
125 |
},
|
126 |
"nbformat": 4,
|
127 |
-
"nbformat_minor":
|
128 |
}
|
|
|
20 |
"from io import BytesIO\n",
|
21 |
"from PIL import Image\n",
|
22 |
"import cv2\n",
|
23 |
+
"import random"
|
24 |
]
|
25 |
},
|
26 |
{
|
|
|
48 |
},
|
49 |
{
|
50 |
"cell_type": "code",
|
51 |
+
"execution_count": 5,
|
52 |
"metadata": {},
|
53 |
"outputs": [
|
54 |
{
|
|
|
62 |
"name": "stdout",
|
63 |
"output_type": "stream",
|
64 |
"text": [
|
65 |
+
"\tTile 1/2\n",
|
66 |
+
"\tTile 2/2\n",
|
67 |
+
"\tTile 1/2\n",
|
68 |
+
"\tTile 2/2\n",
|
69 |
+
"output.shape: (5170, 12000, 4)\n"
|
|
|
|
|
|
|
70 |
]
|
71 |
}
|
72 |
],
|
|
|
88 |
"\n",
|
89 |
"\n",
|
90 |
" output_payload = my_handler(payload)\n",
|
91 |
+
" # out_image = decode_base64_image(output_payload[\"out_image\"])\n",
|
92 |
+
" # print(f\"out_image.size: {out_image.size}\")\n",
|
93 |
+
" # out_image.save(f\"test_data/outputs/{img_name.split('.')[0]}_outscale_{outscale}.png\")\n",
|
94 |
+
" break"
|
95 |
]
|
96 |
},
|
97 |
{
|
|
|
104 |
],
|
105 |
"metadata": {
|
106 |
"kernelspec": {
|
107 |
+
"display_name": "Python 3 (ipykernel)",
|
108 |
"language": "python",
|
109 |
"name": "python3"
|
110 |
},
|
|
|
122 |
}
|
123 |
},
|
124 |
"nbformat": 4,
|
125 |
+
"nbformat_minor": 4
|
126 |
}
|