Jagrut Thakare commited on
Commit
94a0f74
·
1 Parent(s): b0bdcb8
Files changed (9) hide show
  1. .gitignore +124 -0
  2. README.md +2 -2
  3. app.py +59 -0
  4. assets/big-lama.pt +3 -0
  5. requirements.txt +22 -0
  6. src/__init__.py +0 -0
  7. src/core.py +463 -0
  8. src/helper.py +87 -0
  9. src/st_style.py +42 -0
.gitignore ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+
53
+ # Translations
54
+ *.mo
55
+ *.pot
56
+
57
+ # Django stuff:
58
+ *.log
59
+ local_settings.py
60
+ db.sqlite3
61
+ db.sqlite3-journal
62
+
63
+ # Flask stuff:
64
+ instance/
65
+ .webassets-cache
66
+
67
+ # Scrapy stuff:
68
+ .scrapy
69
+
70
+ # Sphinx documentation
71
+ docs/_build/
72
+
73
+ # PyBuilder
74
+ target/
75
+
76
+ # Jupyter Notebook
77
+ .ipynb_checkpoints
78
+
79
+ # IPython
80
+ profile_default/
81
+ ipython_config.py
82
+
83
+ # pyenv
84
+ .python-version
85
+
86
+ # pipenv
87
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
89
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
90
+ # install all needed dependencies.
91
+ #Pipfile.lock
92
+
93
+ # celery beat schedule file
94
+ celerybeat-schedule
95
+
96
+ # SageMath parsed files
97
+ *.sage.py
98
+
99
+ # Environments
100
+ .env
101
+ .venv
102
+ env/
103
+ venv/
104
+ ENV/
105
+ env.bak/
106
+ venv.bak/
107
+
108
+ # Spyder project settings
109
+ .spyderproject
110
+ .spyproject
111
+
112
+ # Rope project settings
113
+ .ropeproject
114
+
115
+ # mkdocs documentation
116
+ /site
117
+
118
+ # mypy
119
+ .mypy_cache/
120
+ .dmypy.json
121
+ dmypy.json
122
+
123
+ # Pyre type checker
124
+ .pyre/
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Object Remover
3
- emoji: 🦀
4
  colorFrom: gray
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.23.1
8
  app_file: app.py
 
1
  ---
2
  title: Object Remover
3
+ emoji:
4
  colorFrom: gray
5
+ colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.23.1
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ from copy import deepcopy
6
+ from src.core import process_inpaint
7
+ from huggingface_hub import login
8
+ import os
9
+ login(os.getenv("HF_TOKEN"))
10
+
11
+ def process_image(img_input, mask_input, brush_size):
12
+ img_input = Image.open(BytesIO(img_input)).convert("RGBA")
13
+ mask_input = Image.open(BytesIO(mask_input)).convert("RGBA")
14
+
15
+ max_size = 2000
16
+ img_width, img_height = img_input.size
17
+ if img_width > max_size or img_height > max_size:
18
+ if img_width > img_height:
19
+ new_width = max_size
20
+ new_height = int((max_size / img_width) * img_height)
21
+ else:
22
+ new_height = max_size
23
+ new_width = int((max_size / img_height) * img_width)
24
+ img_input = img_input.resize((new_width, new_height))
25
+
26
+ im = np.array(mask_input.resize(img_input.size))
27
+ background = np.where(
28
+ (im[:, :, 0] == 0) &
29
+ (im[:, :, 1] == 0) &
30
+ (im[:, :, 2] == 0)
31
+ )
32
+ drawing = np.where(
33
+ (im[:, :, 0] == 255) &
34
+ (im[:, :, 1] == 0) &
35
+ (im[:, :, 2] == 255)
36
+ )
37
+ im[background] = [0, 0, 0, 255]
38
+ im[drawing] = [0, 0, 0, 0]
39
+
40
+ output = process_inpaint(np.array(img_input), np.array(im))
41
+ img_output = Image.fromarray(output).convert("RGB")
42
+
43
+ output_buffer = BytesIO()
44
+ img_output.save(output_buffer, format="PNG")
45
+ return output_buffer.getvalue()
46
+
47
+ demo = gr.Interface(
48
+ fn=process_image,
49
+ inputs=[
50
+ gr.Image(type="bytes", label="Upload Image"),
51
+ gr.Image(type="bytes", tool="sketch", label="Draw Mask"),
52
+ gr.Slider(1, 100, value=50, label="Brush Size")
53
+ ],
54
+ outputs=gr.Image(type="file", label="Output Image"),
55
+ title="Object Remover",
56
+ )
57
+
58
+ if __name__ == "__main__":
59
+ demo.launch()
assets/big-lama.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:344c77bbcb158f17dd143070d1e789f38a66c04202311ae3a258ef66667a9ea9
3
+ size 205669692
requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ opencv-python-headless
5
+ matplotlib
6
+ pyyaml
7
+ tqdm
8
+ easydict
9
+ scikit-image
10
+ scikit-learn
11
+ tensorflow
12
+ joblib
13
+ pandas
14
+ albumentations
15
+ hydra-core
16
+ pytorch-lightning
17
+ tabulate
18
+ kornia
19
+ webdataset
20
+ packaging
21
+ wldhx.yadisk-direct
22
+ altair
src/__init__.py ADDED
File without changes
src/core.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import os
4
+ import re
5
+ import time
6
+ import uuid
7
+ from io import BytesIO
8
+ from pathlib import Path
9
+ import cv2
10
+
11
+ # For inpainting
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+ import streamlit as st
16
+ from PIL import Image
17
+ import argparse
18
+ import io
19
+ import multiprocessing
20
+ from typing import Union
21
+
22
+ import torch
23
+
24
+ try:
25
+ torch._C._jit_override_can_fuse_on_cpu(False)
26
+ torch._C._jit_override_can_fuse_on_gpu(False)
27
+ torch._C._jit_set_texpr_fuser_enabled(False)
28
+ torch._C._jit_set_nvfuser_enabled(False)
29
+ except:
30
+ pass
31
+
32
+ from src.helper import (
33
+ download_model,
34
+ load_img,
35
+ norm_img,
36
+ numpy_to_bytes,
37
+ pad_img_to_modulo,
38
+ resize_max_size,
39
+ )
40
+
41
+ NUM_THREADS = str(multiprocessing.cpu_count())
42
+
43
+ os.environ["OMP_NUM_THREADS"] = NUM_THREADS
44
+ os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
45
+ os.environ["MKL_NUM_THREADS"] = NUM_THREADS
46
+ os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
47
+ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
48
+ if os.environ.get("CACHE_DIR"):
49
+ os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
50
+
51
+ #BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
52
+
53
+ # For Seam-carving
54
+
55
+ from scipy import ndimage as ndi
56
+
57
+ SEAM_COLOR = np.array([255, 200, 200]) # seam visualization color (BGR)
58
+ SHOULD_DOWNSIZE = True # if True, downsize image for faster carving
59
+ DOWNSIZE_WIDTH = 500 # resized image width if SHOULD_DOWNSIZE is True
60
+ ENERGY_MASK_CONST = 100000.0 # large energy value for protective masking
61
+ MASK_THRESHOLD = 10 # minimum pixel intensity for binary mask
62
+ USE_FORWARD_ENERGY = True # if True, use forward energy algorithm
63
+
64
+ device = torch.device("cpu")
65
+ model_path = "./assets/big-lama.pt"
66
+ model = torch.jit.load(model_path, map_location="cpu")
67
+ model = model.to(device)
68
+ model.eval()
69
+
70
+
71
+ ########################################
72
+ # UTILITY CODE
73
+ ########################################
74
+
75
+
76
+ def visualize(im, boolmask=None, rotate=False):
77
+ vis = im.astype(np.uint8)
78
+ if boolmask is not None:
79
+ vis[np.where(boolmask == False)] = SEAM_COLOR
80
+ if rotate:
81
+ vis = rotate_image(vis, False)
82
+ cv2.imshow("visualization", vis)
83
+ cv2.waitKey(1)
84
+ return vis
85
+
86
+ def resize(image, width):
87
+ dim = None
88
+ h, w = image.shape[:2]
89
+ dim = (width, int(h * width / float(w)))
90
+ image = image.astype('float32')
91
+ return cv2.resize(image, dim)
92
+
93
+ def rotate_image(image, clockwise):
94
+ k = 1 if clockwise else 3
95
+ return np.rot90(image, k)
96
+
97
+
98
+ ########################################
99
+ # ENERGY FUNCTIONS
100
+ ########################################
101
+
102
+ def backward_energy(im):
103
+ """
104
+ Simple gradient magnitude energy map.
105
+ """
106
+ xgrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=1, mode='wrap')
107
+ ygrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=0, mode='wrap')
108
+
109
+ grad_mag = np.sqrt(np.sum(xgrad**2, axis=2) + np.sum(ygrad**2, axis=2))
110
+
111
+ # vis = visualize(grad_mag)
112
+ # cv2.imwrite("backward_energy_demo.jpg", vis)
113
+
114
+ return grad_mag
115
+
116
+ def forward_energy(im):
117
+ """
118
+ Forward energy algorithm as described in "Improved Seam Carving for Video Retargeting"
119
+ by Rubinstein, Shamir, Avidan.
120
+ Vectorized code adapted from
121
+ https://github.com/axu2/improved-seam-carving.
122
+ """
123
+ h, w = im.shape[:2]
124
+ im = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)
125
+
126
+ energy = np.zeros((h, w))
127
+ m = np.zeros((h, w))
128
+
129
+ U = np.roll(im, 1, axis=0)
130
+ L = np.roll(im, 1, axis=1)
131
+ R = np.roll(im, -1, axis=1)
132
+
133
+ cU = np.abs(R - L)
134
+ cL = np.abs(U - L) + cU
135
+ cR = np.abs(U - R) + cU
136
+
137
+ for i in range(1, h):
138
+ mU = m[i-1]
139
+ mL = np.roll(mU, 1)
140
+ mR = np.roll(mU, -1)
141
+
142
+ mULR = np.array([mU, mL, mR])
143
+ cULR = np.array([cU[i], cL[i], cR[i]])
144
+ mULR += cULR
145
+
146
+ argmins = np.argmin(mULR, axis=0)
147
+ m[i] = np.choose(argmins, mULR)
148
+ energy[i] = np.choose(argmins, cULR)
149
+
150
+ # vis = visualize(energy)
151
+ # cv2.imwrite("forward_energy_demo.jpg", vis)
152
+
153
+ return energy
154
+
155
+ ########################################
156
+ # SEAM HELPER FUNCTIONS
157
+ ########################################
158
+
159
+ def add_seam(im, seam_idx):
160
+ """
161
+ Add a vertical seam to a 3-channel color image at the indices provided
162
+ by averaging the pixels values to the left and right of the seam.
163
+ Code adapted from https://github.com/vivianhylee/seam-carving.
164
+ """
165
+ h, w = im.shape[:2]
166
+ output = np.zeros((h, w + 1, 3))
167
+ for row in range(h):
168
+ col = seam_idx[row]
169
+ for ch in range(3):
170
+ if col == 0:
171
+ p = np.mean(im[row, col: col + 2, ch])
172
+ output[row, col, ch] = im[row, col, ch]
173
+ output[row, col + 1, ch] = p
174
+ output[row, col + 1:, ch] = im[row, col:, ch]
175
+ else:
176
+ p = np.mean(im[row, col - 1: col + 1, ch])
177
+ output[row, : col, ch] = im[row, : col, ch]
178
+ output[row, col, ch] = p
179
+ output[row, col + 1:, ch] = im[row, col:, ch]
180
+
181
+ return output
182
+
183
+ def add_seam_grayscale(im, seam_idx):
184
+ """
185
+ Add a vertical seam to a grayscale image at the indices provided
186
+ by averaging the pixels values to the left and right of the seam.
187
+ """
188
+ h, w = im.shape[:2]
189
+ output = np.zeros((h, w + 1))
190
+ for row in range(h):
191
+ col = seam_idx[row]
192
+ if col == 0:
193
+ p = np.mean(im[row, col: col + 2])
194
+ output[row, col] = im[row, col]
195
+ output[row, col + 1] = p
196
+ output[row, col + 1:] = im[row, col:]
197
+ else:
198
+ p = np.mean(im[row, col - 1: col + 1])
199
+ output[row, : col] = im[row, : col]
200
+ output[row, col] = p
201
+ output[row, col + 1:] = im[row, col:]
202
+
203
+ return output
204
+
205
+ def remove_seam(im, boolmask):
206
+ h, w = im.shape[:2]
207
+ boolmask3c = np.stack([boolmask] * 3, axis=2)
208
+ return im[boolmask3c].reshape((h, w - 1, 3))
209
+
210
+ def remove_seam_grayscale(im, boolmask):
211
+ h, w = im.shape[:2]
212
+ return im[boolmask].reshape((h, w - 1))
213
+
214
+ def get_minimum_seam(im, mask=None, remove_mask=None):
215
+ """
216
+ DP algorithm for finding the seam of minimum energy. Code adapted from
217
+ https://karthikkaranth.me/blog/implementing-seam-carving-with-python/
218
+ """
219
+ h, w = im.shape[:2]
220
+ energyfn = forward_energy if USE_FORWARD_ENERGY else backward_energy
221
+ M = energyfn(im)
222
+
223
+ if mask is not None:
224
+ M[np.where(mask > MASK_THRESHOLD)] = ENERGY_MASK_CONST
225
+
226
+ # give removal mask priority over protective mask by using larger negative value
227
+ if remove_mask is not None:
228
+ M[np.where(remove_mask > MASK_THRESHOLD)] = -ENERGY_MASK_CONST * 100
229
+
230
+ seam_idx, boolmask = compute_shortest_path(M, im, h, w)
231
+
232
+ return np.array(seam_idx), boolmask
233
+
234
+ def compute_shortest_path(M, im, h, w):
235
+ backtrack = np.zeros_like(M, dtype=np.int_)
236
+
237
+
238
+ # populate DP matrix
239
+ for i in range(1, h):
240
+ for j in range(0, w):
241
+ if j == 0:
242
+ idx = np.argmin(M[i - 1, j:j + 2])
243
+ backtrack[i, j] = idx + j
244
+ min_energy = M[i-1, idx + j]
245
+ else:
246
+ idx = np.argmin(M[i - 1, j - 1:j + 2])
247
+ backtrack[i, j] = idx + j - 1
248
+ min_energy = M[i - 1, idx + j - 1]
249
+
250
+ M[i, j] += min_energy
251
+
252
+ # backtrack to find path
253
+ seam_idx = []
254
+ boolmask = np.ones((h, w), dtype=np.bool_)
255
+ j = np.argmin(M[-1])
256
+ for i in range(h-1, -1, -1):
257
+ boolmask[i, j] = False
258
+ seam_idx.append(j)
259
+ j = backtrack[i, j]
260
+
261
+ seam_idx.reverse()
262
+ return seam_idx, boolmask
263
+
264
+ ########################################
265
+ # MAIN ALGORITHM
266
+ ########################################
267
+
268
+ def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
269
+ for _ in range(num_remove):
270
+ seam_idx, boolmask = get_minimum_seam(im, mask)
271
+ if vis:
272
+ visualize(im, boolmask, rotate=rot)
273
+ im = remove_seam(im, boolmask)
274
+ if mask is not None:
275
+ mask = remove_seam_grayscale(mask, boolmask)
276
+ return im, mask
277
+
278
+
279
+ def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
280
+ seams_record = []
281
+ temp_im = im.copy()
282
+ temp_mask = mask.copy() if mask is not None else None
283
+
284
+ for _ in range(num_add):
285
+ seam_idx, boolmask = get_minimum_seam(temp_im, temp_mask)
286
+ if vis:
287
+ visualize(temp_im, boolmask, rotate=rot)
288
+
289
+ seams_record.append(seam_idx)
290
+ temp_im = remove_seam(temp_im, boolmask)
291
+ if temp_mask is not None:
292
+ temp_mask = remove_seam_grayscale(temp_mask, boolmask)
293
+
294
+ seams_record.reverse()
295
+
296
+ for _ in range(num_add):
297
+ seam = seams_record.pop()
298
+ im = add_seam(im, seam)
299
+ if vis:
300
+ visualize(im, rotate=rot)
301
+ if mask is not None:
302
+ mask = add_seam_grayscale(mask, seam)
303
+
304
+ # update the remaining seam indices
305
+ for remaining_seam in seams_record:
306
+ remaining_seam[np.where(remaining_seam >= seam)] += 2
307
+
308
+ return im, mask
309
+
310
+ ########################################
311
+ # MAIN DRIVER FUNCTIONS
312
+ ########################################
313
+
314
+ def seam_carve(im, dy, dx, mask=None, vis=False):
315
+ im = im.astype(np.float64)
316
+ h, w = im.shape[:2]
317
+ assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w
318
+
319
+ if mask is not None:
320
+ mask = mask.astype(np.float64)
321
+
322
+ output = im
323
+
324
+ if dx < 0:
325
+ output, mask = seams_removal(output, -dx, mask, vis)
326
+
327
+ elif dx > 0:
328
+ output, mask = seams_insertion(output, dx, mask, vis)
329
+
330
+ if dy < 0:
331
+ output = rotate_image(output, True)
332
+ if mask is not None:
333
+ mask = rotate_image(mask, True)
334
+ output, mask = seams_removal(output, -dy, mask, vis, rot=True)
335
+ output = rotate_image(output, False)
336
+
337
+ elif dy > 0:
338
+ output = rotate_image(output, True)
339
+ if mask is not None:
340
+ mask = rotate_image(mask, True)
341
+ output, mask = seams_insertion(output, dy, mask, vis, rot=True)
342
+ output = rotate_image(output, False)
343
+
344
+ return output
345
+
346
+
347
+ def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
348
+ im = im.astype(np.float64)
349
+ rmask = rmask.astype(np.float64)
350
+ if mask is not None:
351
+ mask = mask.astype(np.float64)
352
+ output = im
353
+
354
+ h, w = im.shape[:2]
355
+
356
+ if horizontal_removal:
357
+ output = rotate_image(output, True)
358
+ rmask = rotate_image(rmask, True)
359
+ if mask is not None:
360
+ mask = rotate_image(mask, True)
361
+
362
+ while len(np.where(rmask > MASK_THRESHOLD)[0]) > 0:
363
+ seam_idx, boolmask = get_minimum_seam(output, mask, rmask)
364
+ if vis:
365
+ visualize(output, boolmask, rotate=horizontal_removal)
366
+ output = remove_seam(output, boolmask)
367
+ rmask = remove_seam_grayscale(rmask, boolmask)
368
+ if mask is not None:
369
+ mask = remove_seam_grayscale(mask, boolmask)
370
+
371
+ num_add = (h if horizontal_removal else w) - output.shape[1]
372
+ output, mask = seams_insertion(output, num_add, mask, vis, rot=horizontal_removal)
373
+ if horizontal_removal:
374
+ output = rotate_image(output, False)
375
+
376
+ return output
377
+
378
+
379
+
380
+ def s_image(im,mask,vs,hs,mode="resize"):
381
+ im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
382
+ mask = 255-mask[:,:,3]
383
+ h, w = im.shape[:2]
384
+ if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
385
+ im = resize(im, width=DOWNSIZE_WIDTH)
386
+ if mask is not None:
387
+ mask = resize(mask, width=DOWNSIZE_WIDTH)
388
+
389
+ # image resize mode
390
+ if mode=="resize":
391
+ dy = hs#reverse
392
+ dx = vs#reverse
393
+ assert dy is not None and dx is not None
394
+ output = seam_carve(im, dy, dx, mask, False)
395
+
396
+
397
+ # object removal mode
398
+ elif mode=="remove":
399
+ assert mask is not None
400
+ output = object_removal(im, mask, None, False, True)
401
+
402
+ return output
403
+
404
+
405
+ ##### Inpainting helper code
406
+
407
+ def run(image, mask):
408
+ """
409
+ image: [C, H, W]
410
+ mask: [1, H, W]
411
+ return: BGR IMAGE
412
+ """
413
+ origin_height, origin_width = image.shape[1:]
414
+ image = pad_img_to_modulo(image, mod=8)
415
+ mask = pad_img_to_modulo(mask, mod=8)
416
+
417
+ mask = (mask > 0) * 1
418
+ image = torch.from_numpy(image).unsqueeze(0).to(device)
419
+ mask = torch.from_numpy(mask).unsqueeze(0).to(device)
420
+
421
+ start = time.time()
422
+ with torch.no_grad():
423
+ inpainted_image = model(image, mask)
424
+
425
+ print(f"process time: {(time.time() - start)*1000}ms")
426
+ cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
427
+ cur_res = cur_res[0:origin_height, 0:origin_width, :]
428
+ cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
429
+ cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
430
+ return cur_res
431
+
432
+
433
+ def get_args_parser():
434
+ parser = argparse.ArgumentParser()
435
+ parser.add_argument("--port", default=8080, type=int)
436
+ parser.add_argument("--device", default="cuda", type=str)
437
+ parser.add_argument("--debug", action="store_true")
438
+ return parser.parse_args()
439
+
440
+
441
+ def process_inpaint(image, mask):
442
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
443
+ original_shape = image.shape
444
+ interpolation = cv2.INTER_CUBIC
445
+
446
+ #size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
447
+ #if size_limit == "Original":
448
+ size_limit = max(image.shape)
449
+ #else:
450
+ # size_limit = int(size_limit)
451
+
452
+ print(f"Origin image shape: {original_shape}")
453
+ image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
454
+ print(f"Resized image shape: {image.shape}")
455
+ image = norm_img(image)
456
+
457
+ mask = 255-mask[:,:,3]
458
+ mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
459
+ mask = norm_img(mask)
460
+
461
+ res_np_img = run(image, mask)
462
+
463
+ return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)
src/helper.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from urllib.parse import urlparse
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from torch.hub import download_url_to_file, get_dir
9
+
10
+ LAMA_MODEL_URL = os.environ.get(
11
+ "LAMA_MODEL_URL",
12
+ "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
13
+ )
14
+
15
+
16
+ def download_model(url=LAMA_MODEL_URL):
17
+ parts = urlparse(url)
18
+ hub_dir = get_dir()
19
+ model_dir = os.path.join(hub_dir, "checkpoints")
20
+ if not os.path.isdir(model_dir):
21
+ os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
22
+ filename = os.path.basename(parts.path)
23
+ cached_file = os.path.join(model_dir, filename)
24
+ if not os.path.exists(cached_file):
25
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
26
+ hash_prefix = None
27
+ download_url_to_file(url, cached_file, hash_prefix, progress=True)
28
+ return cached_file
29
+
30
+
31
+ def ceil_modulo(x, mod):
32
+ if x % mod == 0:
33
+ return x
34
+ return (x // mod + 1) * mod
35
+
36
+
37
+ def numpy_to_bytes(image_numpy: np.ndarray) -> bytes:
38
+ data = cv2.imencode(".jpg", image_numpy)[1]
39
+ image_bytes = data.tobytes()
40
+ return image_bytes
41
+
42
+
43
+ def load_img(img_bytes, gray: bool = False):
44
+ nparr = np.frombuffer(img_bytes, np.uint8)
45
+ if gray:
46
+ np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
47
+ else:
48
+ np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
49
+ if len(np_img.shape) == 3 and np_img.shape[2] == 4:
50
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
51
+ else:
52
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
53
+
54
+ return np_img
55
+
56
+
57
+ def norm_img(np_img):
58
+ if len(np_img.shape) == 2:
59
+ np_img = np_img[:, :, np.newaxis]
60
+ np_img = np.transpose(np_img, (2, 0, 1))
61
+ np_img = np_img.astype("float32") / 255
62
+ return np_img
63
+
64
+
65
+ def resize_max_size(
66
+ np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
67
+ ) -> np.ndarray:
68
+ # Resize image's longer size to size_limit if longer size larger than size_limit
69
+ h, w = np_img.shape[:2]
70
+ if max(h, w) > size_limit:
71
+ ratio = size_limit / max(h, w)
72
+ new_w = int(w * ratio + 0.5)
73
+ new_h = int(h * ratio + 0.5)
74
+ return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
75
+ else:
76
+ return np_img
77
+
78
+
79
+ def pad_img_to_modulo(img, mod):
80
+ channels, height, width = img.shape
81
+ out_height = ceil_modulo(height, mod)
82
+ out_width = ceil_modulo(width, mod)
83
+ return np.pad(
84
+ img,
85
+ ((0, 0), (0, out_height - height), (0, out_width - width)),
86
+ mode="symmetric",
87
+ )
src/st_style.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ button_style = """
2
+ <style>
3
+ div.stButton > button:first-child {
4
+ background-color: rgb(255, 75, 75);
5
+ color: rgb(255, 255, 255);
6
+ }
7
+ div.stButton > button:hover {
8
+ background-color: rgb(255, 75, 75);
9
+ color: rgb(255, 255, 255);
10
+ }
11
+ div.stButton > button:active {
12
+ background-color: rgb(255, 75, 75);
13
+ color: rgb(255, 255, 255);
14
+ }
15
+ div.stButton > button:focus {
16
+ background-color: rgb(255, 75, 75);
17
+ color: rgb(255, 255, 255);
18
+ }
19
+ .css-1cpxqw2:focus:not(:active) {
20
+ background-color: rgb(255, 75, 75);
21
+ border-color: rgb(255, 75, 75);
22
+ color: rgb(255, 255, 255);
23
+ }
24
+ """
25
+
26
+ style = """
27
+ <style>
28
+ #MainMenu {
29
+ visibility: hidden;
30
+ }
31
+ footer {
32
+ visibility: hidden;
33
+ }
34
+ header {
35
+ visibility: hidden;
36
+ }
37
+ </style>
38
+ """
39
+
40
+
41
+ def apply_prod_style(st):
42
+ return st.markdown(style, unsafe_allow_html=True)