import logging
import os
import torch
import torch.distributed as dist
import yaml

from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
from fvcore.nn import flop_count_str


logger = logging.getLogger(__name__)


NORM_MODULES = [
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    # NaiveSyncBatchNorm inherits from BatchNorm2d
    torch.nn.GroupNorm,
    torch.nn.InstanceNorm1d,
    torch.nn.InstanceNorm2d,
    torch.nn.InstanceNorm3d,
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]


def register_norm_module(cls):
    """Register an additional normalization module class in NORM_MODULES."""
    NORM_MODULES.append(cls)

    return cls
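
# Illustrative usage (the class below is hypothetical): register a custom
# normalization layer so that any logic built on NORM_MODULES (e.g. skipping
# weight decay for norm parameters) also recognizes it.
#
#   @register_norm_module
#   class LayerNorm2d(torch.nn.LayerNorm):
#       # LayerNorm applied over the channel dimension of an NCHW tensor
#       def forward(self, x):
#           x = x.permute(0, 2, 3, 1)
#           x = super().forward(x)
#           return x.permute(0, 3, 1, 2)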


def is_main_process():
    # Rank 0 is treated as the main process. Prefer the OpenMPI environment
    # variable when launched via mpirun; otherwise fall back to
    # torch.distributed if a process group has been initialized.
    rank = 0
    if 'OMPI_COMM_WORLD_RANK' in os.environ:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    elif dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()

    return rank == 0
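
# Illustrative usage: restrict one-off side effects (logging, checkpoint
# writes) to rank 0, e.g. when the job is launched with
# `mpirun -np 8 python train.py`:
#
#   if is_main_process():
#       logger.info("printed only on rank 0")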


@torch.no_grad()
def analysis_model(model, dump_input, verbose=False):
    """Count FLOPs and parameters of `model` on the sample input `dump_input`,
    log a summary, and return the totals together with the fvcore reports."""
    model.eval()
    flops = FlopCountAnalysis(model, dump_input)
    total = flops.total()
    model.train()
    params_total = sum(p.numel() for p in model.parameters())
    params_learned = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    # Build the fvcore reports once and reuse them for logging and the return value.
    flop_table = flop_count_table(flops)
    flop_str = flop_count_str(flops)
    logger.info(f"flop count table:\n {flop_table}")
    if verbose:
        logger.info(f"flop count str:\n {flop_str}")
    logger.info(f"  Total flops: {total/1000/1000:.3f}M,")
    logger.info(f"  Total params: {params_total/1000/1000:.3f}M,")
    logger.info(f"  Learned params: {params_learned/1000/1000:.3f}M")

    return total, flop_table, flop_str
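
# Illustrative usage (input shape is an assumption): profile a model on a
# single 224x224 RGB image before training starts.
#
#   dummy_input = torch.randn(1, 3, 224, 224)
#   total_flops, table, detail = analysis_model(model, dummy_input)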


def load_config_dict_to_opt(opt, config_dict, splitter='.'):
    """
    Load the key/value pairs from config_dict into opt, overriding existing
    values in opt if any. Keys containing `splitter` are treated as paths into
    nested dictionaries.
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")
    for k, v in config_dict.items():
        k_parts = k.split(splitter)
        pointer = opt
        for k_part in k_parts[:-1]:
            if k_part not in pointer:
                pointer[k_part] = {}
            pointer = pointer[k_part]
            assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
        ori_value = pointer.get(k_parts[-1])
        pointer[k_parts[-1]] = v
        if ori_value is not None:
            logger.warning(f"Overrode {k} from {ori_value} to {pointer[k_parts[-1]]}")


def load_opt_from_config_file(conf_file):
    """
    Load opt from the config file.

    Args:
        conf_file: config file path

    Returns:
        dict: a dictionary of opt settings
    """
    opt = {}
    with open(conf_file, encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)
        load_config_dict_to_opt(opt, config_dict)

    return opt
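
# Illustrative usage (the path and keys are hypothetical): a YAML file such as
#
#   SOLVER.BASE_LR: 0.0001
#   MODEL.NAME: demo
#
# loaded via `opt = load_opt_from_config_file('configs/base.yaml')` yields the
# nested dict {'SOLVER': {'BASE_LR': 0.0001}, 'MODEL': {'NAME': 'demo'}}.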

def cast_batch_to_dtype(batch, dtype):
    """
    Cast the floating-point tensors in a batch to a specified torch dtype.
    It should be called before feeding the batch to an FP16 DeepSpeed model.

    Args:
        batch (torch.Tensor or container of torch.Tensor): input batch
        dtype (torch.dtype): target floating-point dtype
    Returns:
        return_batch: same type as the input batch, with internal floating-point tensors cast to the specified dtype.
    """
    if torch.is_tensor(batch):
        if torch.is_floating_point(batch):
            return_batch = batch.to(dtype)
        else:
            return_batch = batch
    elif isinstance(batch, list):
        return_batch = [cast_batch_to_dtype(t, dtype) for t in batch]
    elif isinstance(batch, tuple):
        return_batch = tuple(cast_batch_to_dtype(t, dtype) for t in batch)
    elif isinstance(batch, dict):
        return_batch = {}
        for k in batch:
            return_batch[k] = cast_batch_to_dtype(batch[k], dtype)
    else:
        logger.debug(f"Can not cast type {type(batch)} to {dtype}. Skipping it in the batch.")
        return_batch = batch

    return return_batch
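
# Illustrative usage: nested containers are handled recursively, and
# non-floating-point tensors (e.g. integer labels) are left untouched.
#
#   batch = {'image': torch.rand(2, 3, 224, 224),
#            'labels': torch.zeros(2, dtype=torch.long)}
#   half_batch = cast_batch_to_dtype(batch, torch.float16)
#   # half_batch['image'].dtype == torch.float16
#   # half_batch['labels'].dtype == torch.long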


def cast_batch_to_half(batch):
    """
    Cast the floating-point tensors in a batch to float16.
    It should be called before feeding the batch to an FP16 DeepSpeed model.

    Args:
        batch (torch.Tensor or container of torch.Tensor): input batch
    Returns:
        return_batch: same type as the input batch, with internal floating-point tensors cast to float16
    """
    return cast_batch_to_dtype(batch, torch.float16)