feng2022's picture
getback
d301934
raw
history blame
2.01 kB
from argparse import Namespace
import os
from os.path import join as pjoin
from typing import Optional
import cv2
import torch
from tools import (
parse_face,
match_histogram,
)
from utils.torch_helpers import make_image
from utils.misc import stem
def match_skin_histogram(
    imgs: torch.Tensor,
    sibling_img: torch.Tensor,
    spectral_sensitivity,
    im_sibling_dir: str,
    mask_dir: str,
    matched_hist_fn: Optional[str] = None,
    normalize=None,  # normalize the range of the tensor
):
    """
    Extract the skin of the input and sibling images. Create a new input image by matching
    its histogram to the sibling.

    The function works through the filesystem: both images are written as PNGs
    into ``im_sibling_dir``, ``parse_face`` is run on that directory with
    ``mask_dir`` as its output directory, and ``match_histogram`` is then run
    on the two PNGs using the masks expected at ``mask_dir/input.png`` and
    ``mask_dir/sibling.png``.

    Args:
        imgs: Input image batch. Only the first element is used (see TODO below).
        sibling_img: Sibling image batch. Only the first element is used.
        spectral_sensitivity: Forwarded unchanged to ``match_histogram`` as
            its ``spectral_sensitivity`` option.
        im_sibling_dir: Directory that receives ``input.png`` and
            ``sibling.png``; created if it does not exist.
        mask_dir: Output directory handed to ``parse_face``; the masks are
            read back from it under the same two file names.
        matched_hist_fn: Output path handed to ``match_histogram``. When
            falsy, ``<im_sibling_dir>/match_histogram.png`` is used.
        normalize: Optional callable applied to the result tensor before it
            is returned.

    Returns:
        A float tensor with a leading batch dimension of 1, built from the
        array returned by ``match_histogram.main`` divided by 255 and passed
        through ``normalize`` when one is given.

    Raises:
        OSError: If OpenCV reports that either PNG could not be written.
    """
    # TODO: Currently only allows imgs of batch size 1
    im_sibling_dir = os.path.abspath(im_sibling_dir)
    mask_dir = os.path.abspath(mask_dir)

    img_np = make_image(imgs)[0]
    # The sibling's last axis is reversed before saving -- presumably an
    # RGB -> BGR swap for OpenCV. NOTE(review): the input image is written
    # without this swap; confirm that is intentional (e.g. grayscale input).
    sibling_np = make_image(sibling_img)[0][..., ::-1]

    # save img, sibling
    os.makedirs(im_sibling_dir, exist_ok=True)
    im_name, sibling_name = 'input.png', 'sibling.png'
    im_path = pjoin(im_sibling_dir, im_name)
    sibling_path = pjoin(im_sibling_dir, sibling_name)
    for path, image in ((im_path, img_np), (sibling_path, sibling_np)):
        # cv2.imwrite signals failure through its return value; without this
        # check the later steps would read a missing or stale file.
        if not cv2.imwrite(path, image):
            raise OSError(f"cv2.imwrite failed to write image to {path!r}")

    # face parsing
    parse_face.main(
        Namespace(in_dir=im_sibling_dir, out_dir=mask_dir, include_hair=False)
    )

    # match_histogram
    # Positional arguments go through the parser; the remaining options are
    # pre-seeded on the namespace that parse_args fills in.
    mh_args = match_histogram.parse_args(
        args=[im_path, sibling_path],
        namespace=Namespace(
            out=matched_hist_fn or pjoin(im_sibling_dir, "match_histogram.png"),
            src_mask=pjoin(mask_dir, im_name),
            ref_mask=pjoin(mask_dir, sibling_name),
            spectral_sensitivity=spectral_sensitivity,
        ),
    )
    matched_np = match_histogram.main(mh_args) / 255.0  # [0, 1]
    # Move the last axis to the front and add a batch axis (HWC -> BCHW).
    matched = torch.FloatTensor(matched_np).permute(2, 0, 1)[None, ...]  # BCHW

    if normalize is not None:
        matched = normalize(matched)

    return matched