from typing import List, Dict, Literal, Union, Tuple
import os
import string
import logging

import torch
import numpy as np
from einops import rearrange, repeat

logger = logging.getLogger(__name__)


def generate_tasks_of_dir(
    path: str,
    output_dir: str,
    exts: Tuple[str, ...],
    same_dir_name: bool = False,
    **kwargs,
) -> List[Dict]:
    """Convert a directory tree of video files into a list of task dicts.

    Args:
        path (str): root directory to walk for video files.
        output_dir (str): directory where the output ``.h5py`` files go.
        exts (Tuple[str, ...]): lower-case file extensions to accept
            (passed to ``str.endswith``, so a tuple works directly).
        same_dir_name (bool, optional): whether to keep the same parent dir
            name as the source video inside ``output_dir``. Defaults to False.
        **kwargs: extra key/value pairs merged into every task dict.

    Returns:
        List[Dict]: one dict per matched video with keys
            ``video_path``/``output_path``/``output_dir``/``filename``/``ext``
            plus whatever ``kwargs`` supplied.
    """
    tasks = []
    for rootdir, dirs, files in os.walk(path):
        for basename in files:
            if basename.lower().endswith(exts):
                video_path = os.path.join(rootdir, basename)
                # rsplit keeps dots inside the stem intact;
                # plain split(".") raised ValueError on names like "a.b.mp4".
                filename, ext = basename.rsplit(".", 1)
                rootdir_name = os.path.basename(rootdir)
                if same_dir_name:
                    save_path = os.path.join(
                        output_dir, rootdir_name, f"{filename}.h5py"
                    )
                    save_dir = os.path.join(output_dir, rootdir_name)
                else:
                    save_path = os.path.join(output_dir, f"{filename}.h5py")
                    save_dir = output_dir
                task = {
                    "video_path": video_path,
                    "output_path": save_path,
                    "output_dir": save_dir,
                    "filename": filename,
                    "ext": ext,
                }
                task.update(kwargs)
                tasks.append(task)
    return tasks


def sample_by_idx(
    T: int,
    n_sample: int,
    sample_rate: int,
    sample_start_idx: int = None,
    change_sample_rate: bool = False,
    seed: int = None,
    whether_random: bool = True,
    n_independent: int = 0,
) -> Tuple[List[int], int, np.ndarray]:
    """Sample ``n_sample`` indices with stride ``sample_rate`` from ``range(T)``.

    Args:
        T (int): size of the candidate index range.
        n_sample (int): target number of samples.
        sample_rate (int): sampling interval; pick one index per
            ``sample_rate`` positions.
        sample_start_idx (int, optional): start position of the sampling
            window. When None it is chosen randomly (if ``whether_random``)
            or set to 0. Defaults to None.
        change_sample_rate (bool, optional): whether ``sample_rate`` may be
            decreased until the sampling fits. Defaults to False.
        seed (int, optional): numpy random seed for the start-index draw.
        whether_random (bool, optional): whether to randomly choose the
            sample start position. Defaults to True.
        n_independent (int, optional): number of extra indices sampled
            outside the strided window. Defaults to 0.

    Raises:
        ValueError: when ``T < n_sample`` or the stride cannot fit.

    Returns:
        Tuple[List[int], int, np.ndarray]: sampled indices, the (possibly
            reduced) ``sample_rate``, and the independent indices (or None).
    """
    if T < n_sample:
        raise ValueError(f"T({T}) < n_sample({n_sample})")
    else:
        if T / sample_rate < n_sample:
            if not change_sample_rate:
                raise ValueError(
                    f"T({T}) / sample_rate({sample_rate}) < n_sample({n_sample})"
                )
            else:
                # shrink the stride until n_sample strided picks fit in T
                while T / sample_rate < n_sample:
                    sample_rate -= 1
                    logger.error(
                        f"sample_rate{sample_rate+1} is too large, decrease to {sample_rate}"
                    )
                    if sample_rate == 0:
                        raise ValueError("T / sample_rate < n_sample")

    if sample_start_idx is None:
        if whether_random:
            # +1 makes the largest valid start (T - n_sample*sample_rate)
            # a candidate; the original exclusive bound produced an empty
            # candidate array (and a crash) when T == n_sample*sample_rate.
            sample_start_idx_candidates = np.arange(T - n_sample * sample_rate + 1)
            if seed is not None:
                np.random.seed(seed)
            sample_start_idx = np.random.choice(sample_start_idx_candidates, 1)[0]
        else:
            sample_start_idx = 0
    sample_end_idx = sample_start_idx + sample_rate * n_sample
    sample = list(range(sample_start_idx, sample_end_idx, sample_rate))

    if n_independent == 0:
        n_independent_sample = None
    else:
        # use the left-over space at both ends of the strided window
        left_candidate = np.array(
            list(range(0, sample_start_idx)) + list(range(sample_end_idx, T))
        )
        if len(left_candidate) >= n_independent:
            n_independent_sample = np.random.choice(left_candidate, n_independent)
        else:
            # not enough space at the ends: use any frame not already sampled.
            # Bug fix: original computed `set(range(T) - set(sample))`, which
            # is a TypeError (range minus set); difference must be set - set.
            left_candidate = np.array(list(set(range(T)) - set(sample)))
            n_independent_sample = np.random.choice(left_candidate, n_independent)

    return sample, sample_rate, n_independent_sample


def sample_tensor_by_idx(
    tensor: Union[torch.Tensor, np.ndarray],
    n_sample: int,
    sample_rate: int,
    sample_start_idx: int = 0,
    change_sample_rate: bool = False,
    seed: int = None,
    dim: int = 0,
    return_type: Literal["numpy", "torch"] = "torch",
    whether_random: bool = True,
    n_independent: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]:
    """Sample a sub-tensor along ``dim`` using :func:`sample_by_idx`.

    Args:
        tensor (Union[torch.Tensor, np.ndarray]): input tensor; ndarrays are
            converted to torch tensors first.
        n_sample (int): target number of samples.
        sample_rate (int): sampling interval.
        sample_start_idx (int, optional): start position. Defaults to 0.
            NOTE(review): with the default 0 the random-start branch of
            sample_by_idx never triggers (it only fires on None) — confirm
            whether callers rely on that.
        change_sample_rate (bool, optional): allow reducing the stride.
            Defaults to False.
        seed (int, optional): numpy seed. Defaults to None.
        dim (int, optional): dimension to sample along. Defaults to 0.
        return_type (Literal["numpy", "torch"], optional): type of the main
            sample. Defaults to "torch". (Independent samples stay torch.)
        whether_random (bool, optional): random start flag. Defaults to True.
        n_independent (int, optional): extra samples independent of the
            strided window. Defaults to 0.

    Returns:
        Tuple: (sample, sample_idx, sample_rate, independent_sample,
            independent_sample_idx); the last two are None when
            ``n_independent == 0``.
    """
    if isinstance(tensor, np.ndarray):
        tensor = torch.from_numpy(tensor)
    T = tensor.shape[dim]
    sample_idx, sample_rate, independent_sample_idx = sample_by_idx(
        T,
        n_sample,
        sample_rate,
        sample_start_idx,
        change_sample_rate,
        seed,
        whether_random=whether_random,
        n_independent=n_independent,
    )
    sample_idx = torch.LongTensor(sample_idx)
    sample = torch.index_select(tensor, dim, sample_idx)
    if independent_sample_idx is not None:
        independent_sample_idx = torch.LongTensor(independent_sample_idx)
        independent_sample = torch.index_select(tensor, dim, independent_sample_idx)
    else:
        independent_sample = None
        independent_sample_idx = None

    if return_type == "numpy":
        sample = sample.cpu().numpy()
    return sample, sample_idx, sample_rate, independent_sample, independent_sample_idx


def concat_two_tensor(
    data1: torch.Tensor,
    data2: torch.Tensor,
    dim: int,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index"
    ] = "first_in_first_out",
    data1_index: torch.LongTensor = None,
    data2_index: torch.LongTensor = None,
    return_index: bool = False,
):
    """Concatenate two tensors along ``dim`` with the given ordering method.

    Args:
        data1 (torch.Tensor): first-in data.
        data2 (torch.Tensor): last-in data.
        dim (int): concatenation dimension.
        method: ordering strategy. "index" scatters both tensors to the
            positions given by ``data1_index``/``data2_index``.
        data1_index (torch.LongTensor, optional): positions of data1 for
            method "index".
        data2_index (torch.LongTensor, optional): positions of data2 for
            method "index".
        return_index (bool, optional): also return both index sequences.

    Raises:
        NotImplementedError: for method "intertwine".
        ValueError: for any unsupported method.

    Returns:
        torch.Tensor, or (tensor, data1_index, data2_index) when
        ``return_index`` is True.
    """
    len_data1 = data1.shape[dim]
    len_data2 = data2.shape[dim]

    if method == "first_in_first_out":
        res = torch.concat([data1, data2], dim=dim)
        data1_index = range(len_data1)
        data2_index = [len_data1 + x for x in range(len_data2)]
    elif method == "first_in_last_out":
        res = torch.concat([data2, data1], dim=dim)
        data2_index = range(len_data2)
        data1_index = [len_data2 + x for x in range(len_data1)]
    elif method == "intertwine":
        raise NotImplementedError("intertwine")
    elif method == "index":
        res = concat_two_tensor_with_index(
            data1=data1,
            data1_index=data1_index,
            data2=data2,
            data2_index=data2_index,
            dim=dim,
        )
    else:
        raise ValueError(
            "only support first_in_first_out, first_in_last_out, intertwine, index"
        )
    if return_index:
        return res, data1_index, data2_index
    else:
        return res


def concat_two_tensor_with_index(
    data1: torch.Tensor,
    data1_index: torch.LongTensor,
    data2: torch.Tensor,
    data2_index: torch.LongTensor,
    dim: int,
) -> torch.Tensor:
    """Scatter two tensors into one along ``dim`` at the given positions.

    Args:
        data1 (torch.Tensor): b1*c1*h1*w1*...
        data1_index (torch.LongTensor): N positions; if dim=1, N=c1.
        data2 (torch.Tensor): b2*c2*h2*w2*...
        data2_index (torch.LongTensor): M positions; if dim=1, M=c2.
        dim (int): target dimension.

    Returns:
        torch.Tensor: shape equals data1's with
            ``shape[dim] = data1.shape[dim] + data2.shape[dim]``.
    """
    shape1 = list(data1.shape)
    shape2 = list(data2.shape)
    target_shape = list(shape1)
    target_shape[dim] = shape1[dim] + shape2[dim]
    target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)
    target = batch_index_copy(target, dim=dim, index=data1_index, source=data1)
    target = batch_index_copy(target, dim=dim, index=data2_index, source=data2)
    return target


def repeat_index_to_target_size(
    index: torch.LongTensor, target_size: int
) -> torch.LongTensor:
    """Broadcast/repeat an index tensor to ``target_size`` rows.

    A 1-D index is repeated into ``(target_size, n)``; a 2-D index is tiled,
    requiring ``target_size`` to be a multiple of its first dimension.
    """
    if len(index.shape) == 1:
        index = repeat(index, "n -> b n", b=target_size)
    if len(index.shape) == 2:
        remainder = target_size % index.shape[0]
        assert (
            remainder == 0
        ), f"target_size % index.shape[0] must be zero, but give {target_size % index.shape[0]}"
        index = repeat(index, "b n -> (b c) n", c=int(target_size / index.shape[0]))
    return index


def batch_concat_two_tensor_with_index(
    data1: torch.Tensor,
    data1_index: torch.LongTensor,
    data2: torch.Tensor,
    data2_index: torch.LongTensor,
    dim: int,
) -> torch.Tensor:
    """Alias of :func:`concat_two_tensor_with_index` (batched indices are
    handled inside batch_index_copy)."""
    return concat_two_tensor_with_index(data1, data1_index, data2, data2_index, dim)


def interwine_two_tensor(
    data1: torch.Tensor,
    data2: torch.Tensor,
    dim: int,
    return_index: bool = False,
) -> torch.Tensor:
    """Interleave two tensors along ``dim``: data1 at even positions
    0,2,4,..., data2 at the remaining positions.

    NOTE(review): the even-position layout assumes
    ``data2.shape[dim] >= data1.shape[dim] - 1`` (i.e. roughly equal-sized
    inputs) — confirm against callers.
    """
    shape1 = list(data1.shape)
    shape2 = list(data2.shape)
    target_shape = list(shape1)
    target_shape[dim] = shape1[dim] + shape2[dim]
    target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)

    # move the interleave dim to the front so plain fancy-indexing works
    data1_reshape = torch.swapaxes(data1, 0, dim)
    data2_reshape = torch.swapaxes(data2, 0, dim)
    target = torch.swapaxes(target, 0, dim)

    total_index = set(range(target_shape[dim]))
    data1_index = range(0, 2 * shape1[dim], 2)
    data2_index = sorted(list(set(total_index) - set(data1_index)))
    data1_index = torch.LongTensor(data1_index)
    data2_index = torch.LongTensor(data2_index)
    target[data1_index, ...] = data1_reshape
    target[data2_index, ...] = data2_reshape
    target = torch.swapaxes(target, 0, dim)
    if return_index:
        return target, data1_index, data2_index
    else:
        return target


def split_index(
    indexs: torch.Tensor,
    n_first: int = None,
    n_last: int = None,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
    ] = "first_in_first_out",
):
    """Split an index tensor into two parts of sizes n_first and n_last.

    Args:
        indexs (torch.Tensor): indices to split; length must equal
            ``n_first + n_last``.
        n_first (int, optional): size of the first part; inferred from
            ``n_last`` when None.
        n_last (int, optional): size of the last part; inferred from
            ``n_first`` when None.
        method: split strategy.

    Raises:
        ValueError: when neither size is given or the method is unsupported.
        NotImplementedError: for method "intertwine".

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (first_index, last_index).
    """
    if n_first is None and n_last is None:
        # original fell through to a TypeError here; fail with a clear message
        raise ValueError("must assign one value for n_first or n_last")
    n_total = len(indexs)
    if n_first is None:
        n_first = n_total - n_last
    if n_last is None:
        n_last = n_total - n_first
    assert len(indexs) == n_first + n_last

    if method == "first_in_first_out":
        first_index = indexs[:n_first]
        last_index = indexs[n_first:]
    elif method == "first_in_last_out":
        first_index = indexs[n_last:]
        last_index = indexs[:n_last]
    elif method == "intertwine":
        raise NotImplementedError
    elif method == "random":
        idx_ = torch.randperm(len(indexs))
        first_index = indexs[idx_[:n_first]]
        last_index = indexs[idx_[n_first:]]
    else:
        # original left first_index/last_index unbound for "index" and
        # unknown methods (UnboundLocalError); raise explicitly instead
        raise ValueError(f"unsupported method: {method}")
    return first_index, last_index


def split_tensor(
    tensor: torch.Tensor,
    dim: int,
    n_first=None,
    n_last=None,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
    ] = "first_in_first_out",
    need_return_index: bool = False,
):
    """Split a tensor along ``dim`` into two parts via :func:`split_index`.

    Returns (first_tensor, last_tensor), plus the two index tensors when
    ``need_return_index`` is True.
    """
    device = tensor.device
    total = tensor.shape[dim]
    if n_first is None:
        n_first = total - n_last
    if n_last is None:
        n_last = total - n_first
    indexs = torch.arange(
        total,
        dtype=torch.long,
        device=device,
    )
    (
        first_index,
        last_index,
    ) = split_index(
        indexs=indexs,
        n_first=n_first,
        method=method,
    )
    first_tensor = torch.index_select(tensor, dim=dim, index=first_index)
    last_tensor = torch.index_select(tensor, dim=dim, index=last_index)
    if need_return_index:
        return (
            first_tensor,
            last_tensor,
            first_index,
            last_index,
        )
    else:
        return (first_tensor, last_tensor)


# TODO: 待确定batch_index_select的优化 / optimization of batch_index_select TBD
def batch_index_select(
    tensor: torch.Tensor, index: torch.LongTensor, dim: int
) -> torch.Tensor:
    """Per-batch index_select.

    Args:
        tensor (torch.Tensor): D1*D2*D3*D4...
        index (torch.LongTensor): D1*N or N, N <= tensor.shape[dim].
            A 1-D index selects the same positions for every batch element;
            a 2-D index selects per-batch (dim must be >= 1 in that case).
        dim (int): dimension to select along.

    Returns:
        torch.Tensor: D1*...*N*...
    """
    # TODO: now only supports the same N for every batch element
    if len(index.shape) == 1:
        return torch.index_select(tensor, dim=dim, index=index)
    else:
        index = repeat_index_to_target_size(index, tensor.shape[0])
        out = []
        for i in torch.arange(tensor.shape[0]):
            sub_tensor = tensor[i]
            sub_index = index[i]
            d = torch.index_select(sub_tensor, dim=dim - 1, index=sub_index)
            out.append(d)
        return torch.stack(out).to(dtype=tensor.dtype)


def batch_index_copy(
    tensor: torch.Tensor, dim: int, index: torch.LongTensor, source: torch.Tensor
) -> torch.Tensor:
    """Per-batch in-place index_copy_.

    Args:
        tensor (torch.Tensor): b*c*h (modified in place and returned).
        dim (int): copy dimension.
        index (torch.LongTensor): b*d or d.
        source (torch.Tensor): b*d*h*... if dim=1; b*c*d*... if dim=2.

    Returns:
        torch.Tensor: the (mutated) input tensor.
    """
    if len(index.shape) == 1:
        tensor.index_copy_(dim=dim, index=index, source=source)
    else:
        index = repeat_index_to_target_size(index, tensor.shape[0])
        batch_size = tensor.shape[0]
        for b in torch.arange(batch_size):
            sub_index = index[b]
            sub_source = source[b]
            sub_tensor = tensor[b]
            sub_tensor.index_copy_(dim=dim - 1, index=sub_index, source=sub_source)
            tensor[b] = sub_tensor
    return tensor


def batch_index_fill(
    tensor: torch.Tensor,
    dim: int,
    index: torch.LongTensor,
    value: Union[torch.Tensor, float],
) -> torch.Tensor:
    """Per-batch index_fill_.

    Args:
        tensor (torch.Tensor): b*c*h (modified in place and returned).
        dim (int): fill dimension.
        index (torch.LongTensor): b*d or d.
        value (Union[torch.Tensor, float]): per-batch values (tensor of
            length b) or a single scalar. (Annotation fixed: the original
            used ``Literal[...]``, which is not valid for types.)

    Returns:
        torch.Tensor: the (mutated) input tensor.
    """
    index = repeat_index_to_target_size(index, tensor.shape[0])
    batch_size = tensor.shape[0]
    for b in torch.arange(batch_size):
        sub_index = index[b]
        sub_value = value[b] if isinstance(value, torch.Tensor) else value
        sub_tensor = tensor[b]
        sub_tensor.index_fill_(dim - 1, sub_index, sub_value)
        tensor[b] = sub_tensor
    return tensor


def adaptive_instance_normalization(
    src: torch.Tensor,
    dst: torch.Tensor,
    eps: float = 1e-6,
):
    """AdaIN: re-normalize ``src`` to the per-channel statistics of ``dst``.

    Args:
        src (torch.Tensor): b c t h w (also supports b c h w and b c t).
        dst (torch.Tensor): same layout as src; its batch dim is repeated to
            match src's via align_repeat_tensor_single_dim.
        eps (float, optional): variance floor for numerical stability.

    Raises:
        ValueError: for unsupported ndim.

    Returns:
        torch.Tensor: src normalized to dst's mean/std.
    """
    ndim = src.ndim
    if ndim == 5:
        dim = (2, 3, 4)
    elif ndim == 4:
        dim = (2, 3)
    elif ndim == 3:
        dim = 2
    else:
        # bug fix: message was missing the f-prefix, so "{ndim}" was literal
        raise ValueError(f"only support ndim in [3,4,5], but given {ndim}")
    var, mean = torch.var_mean(src, dim=dim, keepdim=True, correction=0)
    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
    dst = align_repeat_tensor_single_dim(dst, src.shape[0], dim=0)
    # bug fix: torch.var_mean returns (var, mean); the original unpacked
    # them as (mean_acc, var_acc), swapping the statistics
    var_acc, mean_acc = torch.var_mean(dst, dim=dim, keepdim=True, correction=0)
    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
    src = (((src - mean) / std) * std_acc) + mean_acc
    return src


def adaptive_instance_normalization_with_ref(
    src: torch.Tensor,
    dst: torch.Tensor,
    style_fidelity: float = 0.5,
    do_classifier_free_guidance: bool = True,
):
    """AdaIN with classifier-free-guidance awareness.

    Assumes the batch is [uncond | cond] halves; the uncond half of the
    style-faithful branch keeps the raw src, then the two branches are
    blended by ``style_fidelity``.
    """
    batch_size = src.shape[0] // 2
    # first half of the batch is the unconditional part
    uc_mask = torch.Tensor([1] * batch_size + [0] * batch_size).type_as(src).bool()
    src_uc = adaptive_instance_normalization(src, dst)
    src_c = src_uc.clone()
    # TODO: this branch implicitly assumes
    # do_classifier_free_guidance and style_fidelity > 0
    if do_classifier_free_guidance and style_fidelity > 0:
        src_c[uc_mask] = src[uc_mask]
    src = style_fidelity * src_c + (1.0 - style_fidelity) * src_uc
    return src


def batch_adain_conditioned_tensor(
    tensor: torch.Tensor,
    src_index: torch.LongTensor,
    dst_index: torch.LongTensor,
    keep_dim: bool = True,
    num_frames: int = None,
    dim: int = 2,
    style_fidelity: float = 0.5,
    do_classifier_free_guidance: bool = True,
    need_style_fidelity: bool = False,
):
    """Apply AdaIN to the src frames of ``tensor`` conditioned on dst frames.

    Args:
        tensor (torch.Tensor): b c t h w (or (b t) c h w with ``num_frames``).
        src_index (torch.LongTensor): frame indices to normalize.
        dst_index (torch.LongTensor): frame indices providing the statistics.
        keep_dim (bool, optional): when True, scatter the normalized src back
            together with dst to restore the full frame dimension.
        num_frames (int, optional): frame count for (b t) c h w inputs.
        dim (int, optional): frame dimension. Defaults to 2.
        style_fidelity (float, optional): blend factor for the ref variant.
        do_classifier_free_guidance (bool, optional): CFG batch layout flag.
        need_style_fidelity (bool, optional): use the CFG-aware AdaIN.

    Returns:
        torch.Tensor: tensor with src frames re-normalized.
    """
    ndim = tensor.ndim
    dtype = tensor.dtype
    if ndim == 4 and num_frames is not None:
        tensor = rearrange(tensor, "(b t) c h w-> b c t h w ", t=num_frames)
    src = batch_index_select(tensor, dim=dim, index=src_index).contiguous()
    dst = batch_index_select(tensor, dim=dim, index=dst_index).contiguous()
    if need_style_fidelity:
        # bug fix: the original also passed need_style_fidelity=..., which
        # adaptive_instance_normalization_with_ref does not accept (TypeError)
        src = adaptive_instance_normalization_with_ref(
            src=src,
            dst=dst,
            style_fidelity=style_fidelity,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
    else:
        src = adaptive_instance_normalization(
            src=src,
            dst=dst,
        )
    if keep_dim:
        src = batch_concat_two_tensor_with_index(
            src.to(dtype=dtype),
            src_index,
            dst.to(dtype=dtype),
            dst_index,
            dim=dim,
        )

    if ndim == 4 and num_frames is not None:
        # bug fix: original rearranged `tensor` (the unprocessed input),
        # discarding the normalization result
        src = rearrange(src, "b c t h w ->(b t) c h w")
    return src


def align_repeat_tensor_single_dim(
    src: torch.Tensor,
    target_length: int,
    dim: int = 0,
    n_src_base_length: int = 1,
    src_base_index: List[int] = None,
) -> torch.Tensor:
    """Align the length of ``src`` along ``dim`` to ``target_length``.

    When src is shorter than target_length, repeat it (or its first
    ``n_src_base_length`` slices / the slices at ``src_base_index``) up to
    target_length; when longer, truncate.

    Args:
        src (torch.Tensor): input tensor.
        target_length (int): target length along ``dim``.
        dim (int, optional): dimension to process. Defaults to 0.
        n_src_base_length (int, optional): basic unit length of src used when
            target_length is not a multiple of src's length. Defaults to 1.
        src_base_index (List[int], optional): explicit base indices;
            overrides n_src_base_length when given.

    Returns:
        torch.Tensor: tensor whose ``dim`` length is target_length.
            NOTE(review): the non-divisible branch uses integer division, so
            the result only reaches target_length when
            target_length % len(src_base_index) == 0 — confirm callers.
    """
    src_dim_length = src.shape[dim]
    if target_length > src_dim_length:
        if target_length % src_dim_length == 0:
            new = src.repeat_interleave(
                repeats=target_length // src_dim_length, dim=dim
            )
        else:
            if src_base_index is None and n_src_base_length is not None:
                src_base_index = torch.arange(n_src_base_length)

            new = src.index_select(
                dim=dim,
                index=torch.LongTensor(src_base_index).to(device=src.device),
            )
            new = new.repeat_interleave(
                repeats=target_length // len(src_base_index),
                dim=dim,
            )
    elif target_length < src_dim_length:
        new = src.index_select(
            dim=dim,
            index=torch.arange(target_length, dtype=torch.long, device=src.device),
        )
    else:
        new = src
    return new


def fuse_part_tensor(
    src: torch.Tensor,
    dst: torch.Tensor,
    overlap: int,
    weight: float = 0.5,
    skip_step: int = 0,
) -> torch.Tensor:
    """Fuse the overlapping frames of ``src`` into ``dst`` with a weight.

    out = src_fused_part * weight + dst * (1 - weight) for the overlap.

    Args:
        src (torch.Tensor): b c t h w; its last ``overlap`` frames are used.
        dst (torch.Tensor): b c t h w (modified in place and returned).
        overlap (int): number of overlapping frames; 0 is a no-op.
        weight (float, optional): weight of the src part. Defaults to 0.5.
        skip_step (int, optional): offset into dst's frame dim. Defaults to 0.

    Returns:
        torch.Tensor: dst with the overlap region fused.
    """
    if overlap == 0:
        return dst
    else:
        dst[:, :, skip_step : skip_step + overlap] = (
            weight * src[:, :, -overlap:]
            + (1 - weight) * dst[:, :, skip_step : skip_step + overlap]
        )
        return dst