import logging
import os
import pickle
import random
import zipfile
from typing import Any

import numpy as np
import psutil
import torch

# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)


def set_seed(seed: int = 1234) -> None:
    """Seed every random number generator used by the project.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), and exports ``PYTHONHASHSEED`` for processes
    started afterwards (the running interpreter's hash seed is fixed at
    startup). cuDNN is intentionally left non-deterministic with benchmark
    auto-tuning enabled, trading bit-exact GPU reproducibility for speed.

    Args:
        seed: seed value applied to every generator
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True


def set_environment(cfg):
    """Validate environment-dependent settings and adjust ``cfg`` in place.

    A metric whose name contains ``"GPT"`` needs an OpenAI API key. When the
    ``OPENAI_API_KEY`` environment variable is missing or empty, the metric is
    replaced by ``"BLEU"`` and a warning is logged.

    Args:
        cfg: config object exposing ``cfg.prediction.metric``

    Returns:
        The same ``cfg`` object (mutated in place when the fallback applies).
    """
    if "GPT" not in cfg.prediction.metric:
        return cfg

    if not os.getenv("OPENAI_API_KEY"):
        logger.warning("No OpenAI API Key set. Setting metric to BLEU. ")
        cfg.prediction.metric = "BLEU"
    return cfg


def kill_child_processes(parent_pid: int) -> bool:
    """Kill a process and all of its descendants.

    Descendants are killed first, then the process itself. A descendant that
    has already exited by the time it is signalled is skipped, so a child
    finishing on its own does not abort the cleanup of the others.

    Args:
        parent_pid: process id of the process to kill

    Returns:
        True if the process was signalled; False if it does not exist or is
        already a zombie.
    """

    logger.debug(f"Killing process id: {parent_pid}")

    try:
        parent = psutil.Process(parent_pid)
        if parent.status() == psutil.STATUS_ZOMBIE:
            return False
        for child in parent.children(recursive=True):
            try:
                child.kill()
            except psutil.NoSuchProcess:
                # The child exited between enumeration and kill; previously this
                # escaped to the outer handler and left the parent running.
                pass
        parent.kill()
        return True
    except psutil.NoSuchProcess:
        logger.warning(f"Cannot kill process id: {parent_pid}. No such process.")
        return False


def kill_ddp_processes() -> None:
    """
    Killing all DDP processes from a single process.
    Firstly kills all children of a single DDP process (dataloader workers)
    Then kills all other DDP processes
    Then kills main parent DDP process
    Finally kills the calling process itself.

    Processes that exit on their own while the sweep is running are ignored,
    so that a vanished worker cannot abort the remaining kills (in particular
    the final kill of the calling process).
    """

    def _kill_if_alive(proc):
        # A process may exit between being listed and being signalled.
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            pass

    pid = os.getpid()
    parent_pid = os.getppid()

    current_process = psutil.Process(pid)
    for child in current_process.children(recursive=True):
        _kill_if_alive(child)

    parent_process = psutil.Process(parent_pid)
    # Walk psutil's descendant list back-to-front, as the original code did;
    # presumably so deeper descendants go before their ancestors — verify.
    for child in reversed(parent_process.children(recursive=True)):
        if child.pid == pid:
            continue
        _kill_if_alive(child)
    _kill_if_alive(parent_process)
    current_process.kill()


def add_file_to_zip(zf: zipfile.ZipFile, path: str) -> None:
    """Best-effort: store the file at ``path`` in ``zf`` under its base name.

    Any error raised while adding the file (for example a missing or
    unreadable file) is not propagated; a warning is logged instead.

    Args:
        zf: open, writable zipfile object to add to
        path: path to the file to add; only its base name is used inside the archive
    """

    try:
        archive_name = os.path.basename(path)
        zf.write(path, arcname=archive_name)
    except Exception:
        logger.warning(f"File {path} could not be added to zip.")


def save_pickle(path: str, obj: Any, protocol: int = 4) -> None:
    """Serialize ``obj`` to ``path`` with :mod:`pickle`, overwriting the file.

    Args:
        path: destination file path
        obj: any picklable object
        protocol: pickle protocol version to write with (default 4)
    """

    with open(path, mode="wb") as handle:
        pickle.dump(obj, handle, protocol)


class DisableLogger:
    """Context manager that temporarily disables logging at or below ``level``.

    On entry ``logging.disable(level)`` is applied process-wide. On exit the
    disable threshold that was in effect before entry is restored, rather than
    unconditionally re-enabling all logging, so nested uses and an earlier
    ``logging.disable`` call are respected.

    Args:
        level: messages of this severity and below are suppressed (default INFO)
    """

    def __init__(self, level: int = logging.INFO):
        self.level = level
        self._previous_level = logging.NOTSET

    def __enter__(self):
        # Snapshot the current global threshold so __exit__ can restore it.
        self._previous_level = logging.root.manager.disable
        logging.disable(self.level)
        return self

    def __exit__(self, exit_type, exit_value, exit_traceback):
        logging.disable(self._previous_level)


class PatchedAttribute:
    """
    Patches an attribute of an object for the duration of this context manager.
    Similar to unittest.mock.patch,
    but works also for properties that are not present in the original class

    The original value is captured when the context is entered (and also at
    construction, for backward compatibility), so changes made to the
    attribute between construction and entry are restored correctly.

    >>> class MyObj:
    ...     attr = 'original'
    >>> my_obj = MyObj()
    >>> with PatchedAttribute(my_obj, 'attr', 'patched'):
    ...     print(my_obj.attr)
    patched
    >>> print(my_obj.attr)
    original
    >>> with PatchedAttribute(my_obj, 'new_attr', 'new_patched'):
    ...     print(my_obj.new_attr)
    new_patched
    >>> assert not hasattr(my_obj, 'new_attr')
    """

    def __init__(self, obj, attribute, new_value):
        self.obj = obj
        self.attribute = attribute
        self.new_value = new_value
        self._capture_original()

    def _capture_original(self):
        """Record whether the attribute currently exists and, if so, its value."""
        self.original_exists = hasattr(self.obj, self.attribute)
        if self.original_exists:
            self.original_value = getattr(self.obj, self.attribute)

    def __enter__(self):
        # Re-snapshot: the attribute may have changed since __init__ ran.
        self._capture_original()
        setattr(self.obj, self.attribute, self.new_value)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # NOTE: restoring uses setattr on obj, so a value that originally came
        # from the class is written back as an attribute on obj itself.
        if self.original_exists:
            setattr(self.obj, self.attribute, self.original_value)
        else:
            delattr(self.obj, self.attribute)