File size: 5,174 Bytes
9842c28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import argparse
import cv2
import numpy as np
import os
import sys
from basicsr.utils import scandir
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm
def main(args):
    """Translate parsed command-line arguments into an ``opt`` dict and run the extraction.

    The resulting configuration dict is handed to ``extract_subimages`` and
    carries the following keys:

        n_thread (int): Size of the worker pool (taken from ``args.n_thread``).
        compression_level (int): Value used for cv2.IMWRITE_PNG_COMPRESSION,
            from 0 to 9. A higher value means a smaller file and a longer
            compression time; 0 gives the fastest CPU decompression.
        input_folder (str): Folder holding the source images (``args.input``).
        save_folder (str): Folder that receives the sub-images (``args.output``).
        crop_size (int): Side length of each cropped sub-image.
        step (int): Step of the overlapped sliding window.
        thresh_size (int): Threshold size. Patches whose size is lower than
            thresh_size will be dropped.

    Usage:
        Run the script once per folder. For the DIV2K dataset there are
        typically a GT folder and an LQ folder to process; afterwards each
        output folder should contain the same number of sub-images.

    Args:
        args (argparse.Namespace): Parsed arguments providing ``n_thread``,
            ``compression_level``, ``input``, ``output``, ``crop_size``,
            ``step`` and ``thresh_size``.
    """
    opt = {
        "n_thread": args.n_thread,
        "compression_level": args.compression_level,
        "input_folder": args.input,
        "save_folder": args.output,
        "crop_size": args.crop_size,
        "step": args.step,
        "thresh_size": args.thresh_size,
    }
    extract_subimages(opt)
def extract_subimages(opt):
    """Crop every image of a folder into sub-images using a pool of worker processes.

    The save folder must not exist yet: it is created here, and the script
    exits with status 1 if it is already present so that a previous run is
    never overwritten or mixed with new output.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder.
            n_thread (int): Number of worker processes in the pool.
            The whole dict is also forwarded unchanged to ``worker``, which
            reads its own keys from it.
    """
    input_folder = opt["input_folder"]
    save_folder = opt["save_folder"]

    # Create-then-handle instead of check-then-create, so a folder that
    # appears between the check and the creation cannot slip through.
    try:
        os.makedirs(save_folder)
    except FileExistsError:
        print(f"Folder {save_folder} already exists. Exit.")
        sys.exit(1)
    print(f"mkdir {save_folder} ...")

    # scan all images
    img_list = list(scandir(input_folder, full_path=True))

    failures = []
    pbar = tqdm(total=len(img_list), unit="image", desc="Extract")

    def on_success(_process_info):
        pbar.update(1)

    def on_error(img_path, exc):
        # Without an error callback, apply_async drops worker exceptions
        # silently; record them and still advance the bar so it completes.
        failures.append((img_path, exc))
        pbar.update(1)

    # Pool.__exit__ calls terminate(), so close() and join() must run inside
    # the ``with`` block to let the queued tasks finish first.
    with Pool(opt["n_thread"]) as pool:
        for path in img_list:
            pool.apply_async(
                worker,
                args=(path, opt),
                callback=on_success,
                # Bind the current path as a default to avoid the
                # late-binding closure pitfall inside the loop.
                error_callback=lambda exc, img_path=path: on_error(img_path, exc),
            )
        pool.close()
        pool.join()
    pbar.close()

    for img_path, exc in failures:
        print(f"Failed to process {img_path}: {exc!r}", file=sys.stderr)
    print("All processes done.")
def _window_starts(length, crop_size, step, thresh_size):
    """Return the start offsets of the sliding windows along one image axis.

    Windows start at 0, step, 2 * step, ... for as long as a full
    ``crop_size`` window fits inside ``length``. If more than ``thresh_size``
    pixels are left uncovered after the last regular window, one extra window
    flush with the far edge (starting at ``length - crop_size``) is appended.

    The caller must guarantee ``length >= crop_size`` so that at least one
    window exists.
    """
    starts = np.arange(0, length - crop_size + 1, step)
    if length - (starts[-1] + crop_size) > thresh_size:
        starts = np.append(starts, length - crop_size)
    return starts


def worker(path, opt):
    """Worker for each process: crop one image into overlapping sub-images.

    Sub-images are written to ``opt["save_folder"]`` as
    ``{img_name}_s{index:03d}{extension}``, where ``index`` starts at 1 and
    runs row by row over the sliding-window grid, and ``extension`` is the
    extension of the input file.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            crop_size (int): Crop size.
            step (int): Step for overlapped sliding window.
            thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
            save_folder (str): Path to save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        process_info (str): Process information displayed in progress bar.

    Raises:
        ValueError: If the image cannot be read, or if it is smaller than
            ``crop_size`` along either axis so that no sub-image fits.
    """
    crop_size = opt["crop_size"]
    step = opt["step"]
    thresh_size = opt["thresh_size"]

    img_name, extension = osp.splitext(osp.basename(path))
    # remove the x2, x3, x4 and x8 in the filename for DIV2K
    # NOTE(review): this strips the tags anywhere in the name, not only as a
    # trailing scale suffix -- confirm no other filenames contain them.
    for scale_tag in ("x2", "x3", "x4", "x8"):
        img_name = img_name.replace(scale_tag, "")

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Failed to read image: {path}")

    h, w = img.shape[0:2]
    if h < crop_size or w < crop_size:
        # An empty window grid would otherwise surface as an opaque IndexError.
        raise ValueError(
            f"Image {path} ({h}x{w}) is smaller than crop_size {crop_size}; "
            "no sub-image can be extracted."
        )

    h_space = _window_starts(h, crop_size, step, thresh_size)
    w_space = _window_starts(w, crop_size, step, thresh_size)

    imwrite_params = [cv2.IMWRITE_PNG_COMPRESSION, opt["compression_level"]]
    index = 0
    for top in h_space:
        for left in w_space:
            index += 1
            # The slice is a strided view; make it contiguous before writing.
            patch = np.ascontiguousarray(
                img[top : top + crop_size, left : left + crop_size, ...]
            )
            cv2.imwrite(
                osp.join(opt["save_folder"], f"{img_name}_s{index:03d}{extension}"),
                patch,
                imwrite_params,
            )

    process_info = f"Processing {img_name} ..."
    return process_info
if __name__ == "__main__":
    # Command-line options as (flag, type, default, help) rows; each row is
    # registered on the parser below.
    cli_options = (
        ("--input", str, "datasets/DF2K/DF2K_HR", "Input folder"),
        ("--output", str, "datasets/DF2K/DF2K_HR_sub", "Output folder"),
        ("--crop_size", int, 480, "Crop size"),
        ("--step", int, 240, "Step for overlapped sliding window"),
        (
            "--thresh_size",
            int,
            0,
            "Threshold size. Patches whose size is lower than thresh_size will be dropped.",
        ),
        ("--n_thread", int, 20, "Thread number."),
        ("--compression_level", int, 3, "Compression level"),
    )

    parser = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    args = parser.parse_args()
    main(args)
|