import base64
import json
import os
import re
import time
import uuid
from io import BytesIO
from pathlib import Path
import cv2

# For inpainting

import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image

import argparse
import io
import multiprocessing
from typing import Union

import torch

# Disable the TorchScript fusers; these private hooks are not present in
# every torch release, so a failure here is ignored.
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass

from src.helper import (
    download_model,
    load_img,
    norm_img,
    numpy_to_bytes,
    pad_img_to_modulo,
    resize_max_size,
)

NUM_THREADS = str(multiprocessing.cpu_count())

os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]

#BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")

# For Seam-carving

from scipy import ndimage as ndi

SEAM_COLOR = np.array([255, 200, 200])    # seam visualization color (BGR)
SHOULD_DOWNSIZE = True                    # if True, downsize image for faster carving
DOWNSIZE_WIDTH = 500                      # resized image width if SHOULD_DOWNSIZE is True
ENERGY_MASK_CONST = 100000.0              # large energy value for protective masking
MASK_THRESHOLD = 10                       # minimum pixel intensity for binary mask
USE_FORWARD_ENERGY = True                 # if True, use forward energy algorithm

device = torch.device("cpu")
model_path = "./assets/big-lama.pt"
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
model.eval()


########################################
# UTILITY CODE
########################################


def visualize(im, boolmask=None, rotate=False):
    vis = im.astype(np.uint8)
    if boolmask is not None:
        # seam pixels are the False entries of boolmask
        vis[np.where(boolmask == False)] = SEAM_COLOR
    if rotate:
        vis = rotate_image(vis, False)
    cv2.imshow("visualization", vis)
    cv2.waitKey(1)
    return vis


def resize(image, width):
    # scale to the requested width, preserving the aspect ratio
    h, w = image.shape[:2]
    dim = (width, int(h * width / float(w)))
    image = image.astype('float32')
    return cv2.resize(image, dim)


def rotate_image(image, clockwise):
    # np.rot90 rotates counter-clockwise, so k=3 undoes a k=1 rotation
    k = 1 if clockwise else 3
    return np.rot90(image, k)


########################################
# ENERGY FUNCTIONS
########################################


def backward_energy(im):
    """
    Simple gradient magnitude energy map.
    """
    xgrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=1, mode='wrap')
    ygrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=0, mode='wrap')

    grad_mag = np.sqrt(np.sum(xgrad**2, axis=2) + np.sum(ygrad**2, axis=2))

    # vis = visualize(grad_mag)
    # cv2.imwrite("backward_energy_demo.jpg", vis)

    return grad_mag


def forward_energy(im):
    """
    Forward energy algorithm as described in "Improved Seam Carving for Video Retargeting"
    by Rubinstein, Shamir, Avidan.
    Vectorized code adapted from
    https://github.com/axu2/improved-seam-carving.
    """
    h, w = im.shape[:2]
    im = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)

    energy = np.zeros((h, w))
    m = np.zeros((h, w))

    # neighbours of every pixel: up, left and right (wrapping at the borders)
    U = np.roll(im, 1, axis=0)
    L = np.roll(im, 1, axis=1)
    R = np.roll(im, -1, axis=1)

    # cost of the new edges created by removing a seam that arrives from
    # above (cU), from the upper left (cL) or from the upper right (cR)
    cU = np.abs(R - L)
    cL = np.abs(U - L) + cU
    cR = np.abs(U - R) + cU

    for i in range(1, h):
        mU = m[i-1]
        mL = np.roll(mU, 1)
        mR = np.roll(mU, -1)

        mULR = np.array([mU, mL, mR])
        cULR = np.array([cU[i], cL[i], cR[i]])
        mULR += cULR

        argmins = np.argmin(mULR, axis=0)
        m[i] = np.choose(argmins, mULR)
        energy[i] = np.choose(argmins, cULR)

    # vis = visualize(energy)
    # cv2.imwrite("forward_energy_demo.jpg", vis)

    return energy


########################################
# SEAM HELPER FUNCTIONS
########################################


def add_seam(im, seam_idx):
    """
    Add a vertical seam to a 3-channel color image at the indices provided
    by averaging the pixel values to the left and right of the seam.
    Code adapted from https://github.com/vivianhylee/seam-carving.
    """
    h, w = im.shape[:2]
    output = np.zeros((h, w + 1, 3))
    for row in range(h):
        col = seam_idx[row]
        for ch in range(3):
            if col == 0:
                p = np.mean(im[row, col: col + 2, ch])
                output[row, col, ch] = im[row, col, ch]
                output[row, col + 1, ch] = p
                # copy the rest after the new pixel; the previous slice
                # (col + 1: <- col:) overwrote the averaged value p
                output[row, col + 2:, ch] = im[row, col + 1:, ch]
            else:
                p = np.mean(im[row, col - 1: col + 1, ch])
                output[row, : col, ch] = im[row, : col, ch]
                output[row, col, ch] = p
                output[row, col + 1:, ch] = im[row, col:, ch]

    return output


def add_seam_grayscale(im, seam_idx):
    """
    Add a vertical seam to a grayscale image at the indices provided
    by averaging the pixel values to the left and right of the seam.
    """
    h, w = im.shape[:2]
    output = np.zeros((h, w + 1))
    for row in range(h):
        col = seam_idx[row]
        if col == 0:
            p = np.mean(im[row, col: col + 2])
            output[row, col] = im[row, col]
            output[row, col + 1] = p
            # same fix as in add_seam: keep p instead of overwriting it
            output[row, col + 2:] = im[row, col + 1:]
        else:
            p = np.mean(im[row, col - 1: col + 1])
            output[row, : col] = im[row, : col]
            output[row, col] = p
            output[row, col + 1:] = im[row, col:]

    return output


def remove_seam(im, boolmask):
    h, w = im.shape[:2]
    boolmask3c = np.stack([boolmask] * 3, axis=2)
    return im[boolmask3c].reshape((h, w - 1, 3))


def remove_seam_grayscale(im, boolmask):
    h, w = im.shape[:2]
    return im[boolmask].reshape((h, w - 1))


def get_minimum_seam(im, mask=None, remove_mask=None):
    """
    DP algorithm for finding the seam of minimum energy.
    Code adapted from
    https://karthikkaranth.me/blog/implementing-seam-carving-with-python/
    """
    h, w = im.shape[:2]
    energyfn = forward_energy if USE_FORWARD_ENERGY else backward_energy
    M = energyfn(im)

    if mask is not None:
        M[np.where(mask > MASK_THRESHOLD)] = ENERGY_MASK_CONST

    # give removal mask priority over protective mask by using larger negative value
    if remove_mask is not None:
        M[np.where(remove_mask > MASK_THRESHOLD)] = -ENERGY_MASK_CONST * 100

    seam_idx, boolmask = compute_shortest_path(M, im, h, w)

    return np.array(seam_idx), boolmask


def compute_shortest_path(M, im, h, w):
    # M is updated in place and ends up holding cumulative seam energies
    backtrack = np.zeros_like(M, dtype=np.int_)

    # populate DP matrix
    for i in range(1, h):
        for j in range(0, w):
            if j == 0:
                idx = np.argmin(M[i - 1, j:j + 2])
                backtrack[i, j] = idx + j
                min_energy = M[i-1, idx + j]
            else:
                idx = np.argmin(M[i - 1, j - 1:j + 2])
                backtrack[i, j] = idx + j - 1
                min_energy = M[i - 1, idx + j - 1]

            M[i, j] += min_energy

    # backtrack to find path
    seam_idx = []
    boolmask = np.ones((h, w), dtype=np.bool_)
    j = np.argmin(M[-1])
    for i in range(h-1, -1, -1):
        boolmask[i, j] = False
        seam_idx.append(j)
        j = backtrack[i, j]

    seam_idx.reverse()
    return seam_idx, boolmask


########################################
# MAIN ALGORITHM
########################################


def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
    for _ in range(num_remove):
        seam_idx, boolmask = get_minimum_seam(im, mask)
        if vis:
            visualize(im, boolmask, rotate=rot)
        im = remove_seam(im, boolmask)
        if mask is not None:
            mask = remove_seam_grayscale(mask, boolmask)
    return im, mask


def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
    seams_record = []
    temp_im = im.copy()
    temp_mask = mask.copy() if mask is not None else None

    # first pass: remove seams from a copy to learn where the cheapest seams are
    for _ in range(num_add):
        seam_idx, boolmask = get_minimum_seam(temp_im, temp_mask)
        if vis:
            visualize(temp_im, boolmask, rotate=rot)

        seams_record.append(seam_idx)
        temp_im = remove_seam(temp_im, boolmask)
        if temp_mask is not None:
            temp_mask = remove_seam_grayscale(temp_mask, boolmask)

    seams_record.reverse()

    # second pass: insert the recorded seams into the original image
    for _ in range(num_add):
        seam = seams_record.pop()
        im = add_seam(im, seam)
        if vis:
            visualize(im, rotate=rot)
        if mask is not None:
            mask = add_seam_grayscale(mask, seam)

        # update the remaining seam indices
        for remaining_seam in seams_record:
            remaining_seam[np.where(remaining_seam >= seam)] += 2

    return im, mask


########################################
# MAIN DRIVER FUNCTIONS
########################################


def seam_carve(im, dy, dx, mask=None, vis=False):
    im = im.astype(np.float64)
    h, w = im.shape[:2]
    assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w

    if mask is not None:
        mask = mask.astype(np.float64)

    output = im

    if dx < 0:
        output, mask = seams_removal(output, -dx, mask, vis)

    elif dx > 0:
        output, mask = seams_insertion(output, dx, mask, vis)

    # vertical resizing reuses the column-seam code on a rotated image
    if dy < 0:
        output = rotate_image(output, True)
        if mask is not None:
            mask = rotate_image(mask, True)
        output, mask = seams_removal(output, -dy, mask, vis, rot=True)
        output = rotate_image(output, False)

    elif dy > 0:
        output = rotate_image(output, True)
        if mask is not None:
            mask = rotate_image(mask, True)
        output, mask = seams_insertion(output, dy, mask, vis, rot=True)
        output = rotate_image(output, False)

    return output


def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
    im = im.astype(np.float64)
    rmask = rmask.astype(np.float64)
    if mask is not None:
        mask = mask.astype(np.float64)
    output = im

    h, w = im.shape[:2]

    if horizontal_removal:
        output = rotate_image(output, True)
        rmask = rotate_image(rmask, True)
        if mask is not None:
            mask = rotate_image(mask, True)

    # remove seams until no pixel of the removal mask is left
    while len(np.where(rmask > MASK_THRESHOLD)[0]) > 0:
        seam_idx, boolmask = get_minimum_seam(output, mask, rmask)
        if vis:
            visualize(output, boolmask, rotate=horizontal_removal)
        output = remove_seam(output, boolmask)
        rmask = remove_seam_grayscale(rmask, boolmask)
        if mask is not None:
            mask = remove_seam_grayscale(mask, boolmask)

    # insert seams again to restore the original size
    num_add = (h if horizontal_removal else w) - output.shape[1]
    output, mask = seams_insertion(output, num_add, mask, vis, rot=horizontal_removal)
    if horizontal_removal:
        output = rotate_image(output, False)

    return output


def s_image(im, mask, vs, hs, mode="resize"):
    im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
    # build the mask from the inverted alpha channel of the RGBA input
    mask = 255 - mask[:, :, 3]
    h, w = im.shape[:2]
    if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
        im = resize(im, width=DOWNSIZE_WIDTH)
        if mask is not None:
            mask = resize(mask, width=DOWNSIZE_WIDTH)

    # image resize mode
    if mode == "resize":
        dy = hs  # reverse
        dx = vs  # reverse
        assert dy is not None and dx is not None
        output = seam_carve(im, dy, dx, mask, False)

    # object removal mode
    elif mode == "remove":
        assert mask is not None
        output = object_removal(im, mask, None, False, True)

    return output


##### Inpainting helper code

def run(image, mask):
    """
    image: [C, H, W]
    mask: [1, H, W]
    return: BGR IMAGE
    """
    origin_height, origin_width = image.shape[1:]
    image = pad_img_to_modulo(image, mod=8)
    mask = pad_img_to_modulo(mask, mod=8)

    mask = (mask > 0) * 1
    image = torch.from_numpy(image).unsqueeze(0).to(device)
    mask = torch.from_numpy(mask).unsqueeze(0).to(device)

    start = time.time()
    with torch.no_grad():
        inpainted_image = model(image, mask)

    print(f"process time: {(time.time() - start)*1000}ms")
    cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
    # crop away the padding added by pad_img_to_modulo
    cur_res = cur_res[0:origin_height, 0:origin_width, :]
    cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
    cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
    return cur_res


def get_args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", default=8080, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--debug", action="store_true")
    return parser.parse_args()


def process_inpaint(image, mask):
    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    #size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
    #if size_limit == "Original":
    size_limit = max(image.shape)
    #else:
    #    size_limit = int(size_limit)

    print(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    print(f"Resized image shape: {image.shape}")
    image = norm_img(image)

    # build the mask from the inverted alpha channel of the RGBA input
    mask = 255 - mask[:, :, 3]
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
    mask = norm_img(mask)

    res_np_img = run(image, mask)

    return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)
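

########################################
# ILLUSTRATIVE SKETCH (not part of the original module)
########################################
# A dependency-free version of the dynamic-programming step used by
# compute_shortest_path, written for a plain list-of-lists energy map.
# The name _demo_minimum_seam is hypothetical and is meant only to show how
# the cumulative-cost table and the backtrack table work together; the NumPy
# implementation above remains the one the pipeline calls.

def _demo_minimum_seam(energy):
    """Return, for each row, the column index of the minimum-energy vertical seam."""
    h, w = len(energy), len(energy[0])
    cost = [list(row) for row in energy]      # cumulative energy (input left untouched)
    backtrack = [[0] * w for _ in range(h)]   # parent column in the previous row

    for i in range(1, h):
        for j in range(w):
            lo, hi = max(j - 1, 0), min(j + 1, w - 1)
            # cheapest of the up-left, up and up-right neighbours (first wins ties)
            best = min(range(lo, hi + 1), key=lambda k: cost[i - 1][k])
            backtrack[i][j] = best
            cost[i][j] += cost[i - 1][best]

    # start at the cheapest cell of the last row and follow the parent links upward
    j = min(range(w), key=lambda k: cost[h - 1][k])
    seam = []
    for i in range(h - 1, -1, -1):
        seam.append(j)
        j = backtrack[i][j]

    seam.reverse()
    return seam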