from pathlib import Path

from datasets import disable_caching, enable_caching, is_caching_enabled
from datasets.utils.py_utils import get_imports

from .file_utils import get_all_files_in_dir

# Module-level flag, initialised to False. Nothing in the visible part of this
# file reads it.
# NOTE(review): presumably the project-wide default for whether Hugging Face
# `datasets` caching should be on -- confirm against the modules that import it.
HF_CACHING_ENABLED = False


class HFCachingContextManager:
    """Temporarily switch ``datasets`` caching on or off for a ``with`` block.

    Entering records the current global caching state and, if it differs
    from the requested one, flips it. Leaving puts the recorded state back
    if it has changed in the meantime. Exceptions raised inside the block
    are not suppressed.
    """

    def __init__(self, should_cache):
        """Remember the requested caching state.

        Args:
            should_cache: Truthy to have caching enabled inside the block,
                falsy to have it disabled.
        """
        self.should_cache = should_cache
        # Filled in by __enter__; None until the context has been entered.
        self.was_caching_enabled = None

    def __enter__(self):
        """Record the current caching state and apply the requested one."""
        self.was_caching_enabled = is_caching_enabled()
        # Only touch the global switch when the state actually has to change.
        if self.should_cache and not self.was_caching_enabled:
            enable_caching()
        elif not self.should_cache and self.was_caching_enabled:
            disable_caching()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the caching state that was recorded on entry."""
        if is_caching_enabled() == self.was_caching_enabled:
            # Already in the original state: nothing to undo.
            return
        restore = enable_caching if self.was_caching_enabled else disable_caching
        restore()


def get_missing_imports(file, exclude=None):
    """List the modules next to this file that ``file`` does not import.

    The candidate modules are the stems (file name without directory or
    extension) of the ``.py`` files that ``get_all_files_in_dir`` reports for
    the directory containing this module. A candidate counts as imported
    when ``get_imports(file)`` yields an entry whose first element is
    ``"internal"`` and whose second element equals the stem.

    Args:
        file: Passed unchanged to ``get_imports``.
            NOTE(review): presumably the path of the Python file to inspect --
            confirm against ``datasets.utils.py_utils.get_imports``.
        exclude: Optional container of module names to leave out of the
            result. ``None`` (the default) excludes nothing.

    Returns:
        The candidate module names that are neither imported nor excluded,
        in the order ``get_all_files_in_dir`` returned their files.
    """
    # A ``None`` sentinel replaces the former ``[]`` default, which was a
    # single list object shared by every call.
    excluded = () if exclude is None else exclude
    src_dir = Path(__file__).parent
    python_files = get_all_files_in_dir(src_dir, file_extension=".py")
    # Keep only the bare module name: no directory part, no extension.
    required_modules = [Path(p).stem for p in python_files]
    # A set gives constant-time membership tests in the filter below.
    imported_modules = {entry[1] for entry in get_imports(file) if entry[0] == "internal"}
    return [
        module
        for module in required_modules
        if module not in imported_modules and module not in excluded
    ]