import torch def create_mask_sequence(mask_cfg, seq_len): type_name = mask_cfg['type'] if type_name == 'raster order': num_tokens = mask_cfg['num_tokens'] idx_list = [] all_idx = torch.arange(seq_len) for i in range(0, seq_len, num_tokens): idx_list.append(all_idx[i: i + num_tokens]) return idx_list elif type_name == 'random order': num_tokens = mask_cfg['num_tokens'] idx_list = [] all_idx = torch.randperm(seq_len) for i in range(0, seq_len, num_tokens): idx_list.append(all_idx[i: i + num_tokens]) return idx_list elif type_name == 'single': idx_list = [torch.arange(seq_len)] return idx_list else: raise NotImplementedError()