English
Inference Endpoints
garg-aayush commited on
Commit
a84e2f3
·
1 Parent(s): a7f6fc5

Write the output image to S3 bucket and return image_url and filename

Browse files
Files changed (3) hide show
  1. handler.py +43 -17
  2. requirements.txt +2 -1
  3. test_handler.ipynb +13 -15
handler.py CHANGED
@@ -9,8 +9,8 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
9
  import numpy as np
10
  import cv2
11
  import PIL
12
-
13
-
14
  import torch
15
  import base64
16
 
@@ -18,24 +18,31 @@ import base64
18
  class EndpointHandler:
19
  def __init__(self, path=""):
20
 
 
21
  self.model = RealESRGANer(
22
- scale=4,
 
23
  model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
24
- # model_path=f"/workspace/real-esrgan/weights/Real-ESRGAN-x4plus.pth",
25
- # dni_weight=dni_weight,
26
- model= RRDBNet(num_in_ch=3,
27
- num_out_ch=3,
28
- num_feat=64,
29
- num_block=23,
30
- num_grow_ch=32,
31
  scale=4
32
  ),
33
  tile=1000,
34
  tile_pad=20,
35
- # pre_pad=args.pre_pad,
36
  half=True,
37
- # gpu_id=args.gpu_id
38
  )
 
 
 
 
 
 
 
 
39
 
40
  def __call__(self, data: Any) -> Dict[str, List[float]]:
41
 
@@ -89,12 +96,14 @@ class EndpointHandler:
89
  img_byte_arr = BytesIO()
90
  output = Image.fromarray(output)
91
 
92
- # save to BytesIO
93
- output.save(img_byte_arr, format='PNG')
94
- img_str = base64.b64encode(img_byte_arr.getvalue())
95
- img_str = img_str.decode()
 
96
 
97
- return {"out_image": img_str,
 
98
  "error": None
99
  }
100
 
@@ -114,3 +123,20 @@ class EndpointHandler:
114
  except Exception as e:
115
  print(f"Exception: {e}")
116
  return {"out_image": None, "error": "An unexpected error occurred"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import numpy as np
10
  import cv2
11
  import PIL
12
+ import boto3
13
+ import uuid, io
14
  import torch
15
  import base64
16
 
 
18
  class EndpointHandler:
19
  def __init__(self, path=""):
20
 
21
+ # Initialize the Real-ESRGAN model with specified parameters
22
  self.model = RealESRGANer(
23
+ scale=4, # Scale factor for the model
24
+ # Path to the pre-trained model weights
25
  model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
26
+ # Initialize the RRDBNet model architecture with specified parameters
27
+ model= RRDBNet(num_in_ch=3,
28
+ num_out_ch=3,
29
+ num_feat=64,
30
+ num_block=23,
31
+ num_grow_ch=32,
 
32
  scale=4
33
  ),
34
  tile=1000,
35
  tile_pad=20,
 
36
  half=True,
 
37
  )
38
+
39
+ # Initialize the S3 client with AWS credentials from environment variables
40
+ self.s3 = boto3.client('s3',
41
+ aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
42
+ aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
43
+ )
44
+ # Get the S3 bucket name from environment variables
45
+ self.bucket_name = os.environ["S3_BUCKET_NAME"]
46
 
47
  def __call__(self, data: Any) -> Dict[str, List[float]]:
48
 
 
96
  img_byte_arr = BytesIO()
97
  output = Image.fromarray(output)
98
 
99
+ # # save to BytesIO
100
+ # output.save(img_byte_arr, format='PNG')
101
+ # img_str = base64.b64encode(img_byte_arr.getvalue())
102
+ # img_str = img_str.decode()
103
+ image_url, key = self.upload_to_s3(output)
104
 
105
+ return {"image_url": image_url,
106
+ "image_key": key,
107
  "error": None
108
  }
109
 
 
123
  except Exception as e:
124
  print(f"Exception: {e}")
125
  return {"out_image": None, "error": "An unexpected error occurred"}
126
+
127
+ def upload_to_s3(self, image):
128
+ "Upload the image to s3 and return the url."
129
+
130
+ prefix = str(uuid.uuid4())
131
+ # Save the image to an in-memory file
132
+ in_mem_file = io.BytesIO()
133
+ image.save(in_mem_file, 'PNG')
134
+ in_mem_file.seek(0)
135
+
136
+ # Upload the image to s3
137
+ key = f"{prefix}.png"
138
+ self.s3.upload_fileobj(in_mem_file, Bucket=self.bucket_name, Key=key)
139
+ image_url = "https://{0}.s3.amazonaws.com/{1}".format(self.bucket_name, key)
140
+
141
+ # return the url and the key
142
+ return image_url, key
requirements.txt CHANGED
@@ -7,4 +7,5 @@ Pillow
7
  torch>=1.7
8
  torchvision==0.16.2
9
  tqdm
10
- realesrgan
 
 
7
  torch>=1.7
8
  torchvision==0.16.2
9
  tqdm
10
+ realesrgan
11
+ boto3
test_handler.ipynb CHANGED
@@ -20,7 +20,7 @@
20
  "from io import BytesIO\n",
21
  "from PIL import Image\n",
22
  "import cv2\n",
23
- "import random\n"
24
  ]
25
  },
26
  {
@@ -48,7 +48,7 @@
48
  },
49
  {
50
  "cell_type": "code",
51
- "execution_count": 6,
52
  "metadata": {},
53
  "outputs": [
54
  {
@@ -62,14 +62,11 @@
62
  "name": "stdout",
63
  "output_type": "stream",
64
  "text": [
65
- "output.shape: (5170, 12000, 4)\n",
66
- "out_image.size: (12000, 5170)\n",
67
- "image.size: (1056, 1068), image.mode: RGB, outscale: 3.0\n",
68
- "output.shape: (3204, 3168, 3)\n",
69
- "out_image.size: (3168, 3204)\n",
70
- "image.size: (1056, 1068), image.mode: L, outscale: 5.49\n",
71
- "output.shape: (5863, 5797, 3)\n",
72
- "out_image.size: (5797, 5863)\n"
73
  ]
74
  }
75
  ],
@@ -91,9 +88,10 @@
91
  "\n",
92
  "\n",
93
  " output_payload = my_handler(payload)\n",
94
- " out_image = decode_base64_image(output_payload[\"out_image\"])\n",
95
- " print(f\"out_image.size: {out_image.size}\")\n",
96
- " out_image.save(f\"test_data/outputs/{img_name.split('.')[0]}_outscale_{outscale}.png\")\n"
 
97
  ]
98
  },
99
  {
@@ -106,7 +104,7 @@
106
  ],
107
  "metadata": {
108
  "kernelspec": {
109
- "display_name": "Python 3",
110
  "language": "python",
111
  "name": "python3"
112
  },
@@ -124,5 +122,5 @@
124
  }
125
  },
126
  "nbformat": 4,
127
- "nbformat_minor": 2
128
  }
 
20
  "from io import BytesIO\n",
21
  "from PIL import Image\n",
22
  "import cv2\n",
23
+ "import random"
24
  ]
25
  },
26
  {
 
48
  },
49
  {
50
  "cell_type": "code",
51
+ "execution_count": 5,
52
  "metadata": {},
53
  "outputs": [
54
  {
 
62
  "name": "stdout",
63
  "output_type": "stream",
64
  "text": [
65
+ "\tTile 1/2\n",
66
+ "\tTile 2/2\n",
67
+ "\tTile 1/2\n",
68
+ "\tTile 2/2\n",
69
+ "output.shape: (5170, 12000, 4)\n"
 
 
 
70
  ]
71
  }
72
  ],
 
88
  "\n",
89
  "\n",
90
  " output_payload = my_handler(payload)\n",
91
+ " # out_image = decode_base64_image(output_payload[\"out_image\"])\n",
92
+ " # print(f\"out_image.size: {out_image.size}\")\n",
93
+ " # out_image.save(f\"test_data/outputs/{img_name.split('.')[0]}_outscale_{outscale}.png\")\n",
94
+ " break"
95
  ]
96
  },
97
  {
 
104
  ],
105
  "metadata": {
106
  "kernelspec": {
107
+ "display_name": "Python 3 (ipykernel)",
108
  "language": "python",
109
  "name": "python3"
110
  },
 
122
  }
123
  },
124
  "nbformat": 4,
125
+ "nbformat_minor": 4
126
  }