import importlib
import inspect
import os
from pathlib import Path

from .artifact import Artifact, Catalogs
from .catalog import EnvironmentLocalCatalog, GithubCatalog, LocalCatalog
from .error_utils import Documentation, UnitxtError, UnitxtWarning
from .settings_utils import get_constants, get_settings
from .utils import Singleton

constants = get_constants()
settings = get_settings()


def _register_catalog(catalog: LocalCatalog) -> None:
    """Register ``catalog`` with the ``Catalogs`` registry."""
    Catalogs().register(catalog)


def _unregister_catalog(catalog: LocalCatalog) -> None:
    """Remove ``catalog`` from the ``Catalogs`` registry."""
    Catalogs().unregister(catalog)


def _matching_local_catalogs(catalog_path: str) -> list:
    """Return the registered ``LocalCatalog`` instances located at ``catalog_path``.

    Two locations match when both are existing directories and resolve to the
    same path. The result is empty when ``catalog_path`` is not a directory.
    """
    if not os.path.isdir(catalog_path):
        return []
    # Resolve the requested path once; it does not change between iterations.
    target = Path(catalog_path).resolve()
    return [
        catalog
        for catalog in _catalogs_list()
        if isinstance(catalog, LocalCatalog)
        and os.path.isdir(catalog.location)
        and Path(catalog.location).resolve() == target
    ]


def is_local_catalog_registered(catalog_path: str) -> bool:
    """Tell whether a ``LocalCatalog`` for directory ``catalog_path`` is registered.

    Returns:
        True if some registered ``LocalCatalog`` resolves to the same directory
        as ``catalog_path``; False otherwise, including when ``catalog_path``
        is not an existing directory.
    """
    return bool(_matching_local_catalogs(catalog_path))


def register_local_catalog(catalog_path: str) -> None:
    """Register a ``LocalCatalog`` for ``catalog_path`` unless one already exists.

    Raises:
        AssertionError: if ``catalog_path`` does not exist or is not a directory.
    """
    # NOTE(review): these asserts are stripped under ``python -O``. They are kept
    # as asserts so callers that catch AssertionError keep working.
    assert os.path.exists(catalog_path), f"Catalog path {catalog_path} does not exist."
    assert os.path.isdir(
        catalog_path
    ), f"Catalog path {catalog_path} is not a directory."
    if not is_local_catalog_registered(catalog_path=catalog_path):
        _register_catalog(LocalCatalog(location=catalog_path))


def unregister_local_catalog(catalog_path: str) -> None:
    """Unregister every ``LocalCatalog`` whose location resolves to ``catalog_path``.

    Does nothing when no such catalog is registered.
    """
    for catalog in _matching_local_catalogs(catalog_path):
        _unregister_catalog(catalog)


def _catalogs_list() -> list:
    """Return a list copy of the registered catalogs, safe to iterate while mutating."""
    return list(Catalogs())


def _register_all_catalogs() -> None:
    """Register a ``GithubCatalog``, a default ``LocalCatalog``, then the settings-driven ones."""
    _register_catalog(GithubCatalog())
    _register_catalog(LocalCatalog())
    _reset_env_local_catalogs()


def _reset_env_local_catalogs() -> None:
    """Replace all registered ``EnvironmentLocalCatalog`` entries from current settings.

    Existing ``EnvironmentLocalCatalog`` instances are unregistered first. Then
    one is registered per path found in ``settings.catalogs`` or, for the
    deprecated option, ``settings.artifactories``. Paths are split on
    ``constants.env_local_catalogs_paths_sep``.

    Raises:
        UnitxtError: if both ``settings.catalogs`` and ``settings.artifactories``
            are set.
    """
    for catalog in _catalogs_list():
        if isinstance(catalog, EnvironmentLocalCatalog):
            _unregister_catalog(catalog)

    if settings.catalogs and settings.artifactories:
        raise UnitxtError(
            "Both UNITXT_CATALOGS and UNITXT_ARTIFACTORIES are set. "
            "Use only UNITXT_CATALOGS. UNITXT_ARTIFACTORIES is deprecated.\n"
            f"UNITXT_CATALOGS: {settings.catalogs}\n"
            f"UNITXT_ARTIFACTORIES: {settings.artifactories}\n",
            Documentation.CATALOG,
        )

    if settings.artifactories:
        # The instance is discarded: presumably UnitxtWarning emits the warning
        # on construction -- confirm against error_utils.
        UnitxtWarning(
            "UNITXT_ARTIFACTORIES is set but is deprecated, use UNITXT_CATALOGS instead.",
            Documentation.CATALOG,
        )

    separator = constants.env_local_catalogs_paths_sep
    # At most one of the two settings is non-empty here (both set raises above).
    for paths in (settings.catalogs, settings.artifactories):
        if paths:
            for path in paths.split(separator):
                _register_catalog(EnvironmentLocalCatalog(location=path))


def _register_all_artifacts() -> None:
    """Import every sibling module and register the ``Artifact`` subclasses found.

    Each ``.py`` file in this file's directory is imported as a submodule of
    the current package, except this file itself and the names listed in
    ``constants.non_registered_files``. Every class in an imported module that
    is a strict subclass of ``Artifact`` is passed to ``Artifact.register_class``.
    """
    package_dir = os.path.dirname(__file__)
    this_file = os.path.basename(__file__)
    for file in os.listdir(package_dir):
        if (
            not file.endswith(".py")
            or file in constants.non_registered_files
            or file == this_file
        ):
            continue
        # Strip only the trailing extension; str.replace would also remove any
        # ".py" occurring earlier in the file name.
        module_name = os.path.splitext(file)[0]
        module = importlib.import_module("." + module_name, __package__)
        for _name, obj in inspect.getmembers(module, inspect.isclass):
            # Register subclasses of Artifact, but not Artifact itself.
            if issubclass(obj, Artifact) and obj is not Artifact:
                Artifact.register_class(obj)


class ProjectArtifactRegisterer(metaclass=Singleton):
    """Runs catalog and artifact registration, guarded by a ``_registered`` flag.

    NOTE(review): presumably the ``Singleton`` metaclass hands back one shared
    instance; the flag keeps registration from re-running if ``__init__`` is
    invoked again on it -- confirm against ``utils.Singleton``.
    """

    def __init__(self):
        if not hasattr(self, "_registered"):
            self._registered = False

        if not self._registered:
            _register_all_catalogs()
            _register_all_artifacts()
            self._registered = True


def register_all_artifacts() -> None:
    """Public entry point: trigger registration via ``ProjectArtifactRegisterer``."""
    ProjectArtifactRegisterer()