import json
import os
from collections import Counter
from functools import lru_cache
from pathlib import Path
from typing import Optional

import requests

from .artifact import (
    AbstractCatalog,
    Artifact,
    ArtifactLink,
    Catalogs,
    get_catalog_name_and_args,
    reset_artifacts_json_cache,
    verify_legal_catalog_name,
)
from .logging_utils import get_logger
from .settings_utils import get_constants
from .text_utils import print_dict
from .version import version

logger = get_logger()
constants = get_constants()


class Catalog(AbstractCatalog):
    """Base catalog: a named collection of artifacts stored at ``location``."""

    # Short catalog name; used in messages such as the "already exists" error.
    name: str = None
    # Where the catalog lives: a directory path or, for remote catalogs, a URL.
    location: str = None

    def __repr__(self):
        return f"{self.location}"


class LocalCatalog(Catalog):
    """Catalog whose artifacts are JSON files under a directory tree."""

    name: str = "local"
    location: str = constants.default_catalog_path
    # Consulted by get_local_catalogs_paths(); subclasses that are not backed
    # by the local filesystem set this to False.
    is_local: bool = True

    def path(self, artifact_identifier: str):
        """Return the JSON file path for ``artifact_identifier``.

        The identifier is split on ``constants.catalog_hierarchy_sep``; every
        part becomes a path component under ``self.location`` and ``.json`` is
        appended to the last one.

        Raises:
            AssertionError: if the identifier is empty or only whitespace.
        """
        assert (
            artifact_identifier.strip()
        ), "artifact_identifier should not be an empty string."
        parts = artifact_identifier.split(constants.catalog_hierarchy_sep)
        parts[-1] = parts[-1] + ".json"
        return os.path.join(self.location, *parts)

    def load(self, artifact_identifier: str, overwrite_args=None):
        """Load the artifact stored under ``artifact_identifier``.

        Args:
            artifact_identifier: catalog name of the artifact.
            overwrite_args: passed through unchanged to ``Artifact.load``.

        Raises:
            AssertionError: if the artifact is not in this catalog.
        """
        assert (
            artifact_identifier in self
        ), f"Artifact with name {artifact_identifier} does not exist"
        path = self.path(artifact_identifier)
        return Artifact.load(
            path,
            artifact_identifier=artifact_identifier,
            overwrite_args=overwrite_args,
        )

    def __getitem__(self, name) -> Artifact:
        return self.load(name)

    def get_with_overwrite(self, name, overwrite_args):
        """Load ``name``, forwarding ``overwrite_args`` to :meth:`load`."""
        return self.load(name, overwrite_args=overwrite_args)

    def __contains__(self, artifact_identifier: str):
        """Return True when a regular file exists for ``artifact_identifier``."""
        if not os.path.exists(self.location):
            return False
        path = self.path(artifact_identifier)
        if path is None:
            return False
        return os.path.exists(path) and os.path.isfile(path)

    def save_artifact(
        self,
        artifact: Artifact,
        artifact_identifier: str,
        overwrite: bool = False,
        verbose: bool = True,
    ):
        """Write ``artifact`` to this catalog under ``artifact_identifier``.

        Args:
            artifact: the artifact to store; must be an ``Artifact`` instance.
            artifact_identifier: catalog name to store it under.
            overwrite: when False, refuse to replace an existing entry.
            verbose: when True, log the destination path after saving.

        Raises:
            AssertionError: if ``artifact`` is not an ``Artifact``, or if the
                entry already exists and ``overwrite`` is False.
        """
        assert isinstance(
            artifact, Artifact
        ), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
        if not overwrite:
            assert (
                artifact_identifier not in self
            ), f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
        path = self.path(artifact_identifier)
        # Create intermediate directories for hierarchical identifiers.
        os.makedirs(Path(path).parent.absolute(), exist_ok=True)
        artifact.save(path)
        if verbose:
            logger.info(f"Artifact {artifact_identifier} saved to {path}")


class EnvironmentLocalCatalog(LocalCatalog):
    pass


class GithubCatalog(LocalCatalog):
    """Read-only view of the catalog directory of a GitHub repository.

    Artifacts are fetched over HTTP from raw.githubusercontent.com, at the git
    tag equal to the installed package ``version``.
    """

    name = "community"
    repo = "unitxt"
    repo_dir = "src/unitxt/catalog"
    user = "IBM"
    is_local: bool = False

    def prepare(self):
        """Point ``location`` at the raw-content URL for the current version."""
        tag = version
        self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"

    def load(self, artifact_identifier: str, overwrite_args=None):
        """Download the artifact's JSON and build an ``Artifact`` from it.

        Args:
            artifact_identifier: catalog name of the artifact.
            overwrite_args: passed through unchanged to ``Artifact.from_dict``.
        """
        # NOTE(review): path() is inherited and joins with os.path.join, which
        # yields "/"-separated URLs on POSIX; confirm behaviour on Windows.
        url = self.path(artifact_identifier)
        # NOTE(review): no status check or timeout here; a non-JSON error page
        # surfaces as a decode error from response.json().
        response = requests.get(url)
        data = response.json()
        new_artifact = Artifact.from_dict(data, overwrite_args=overwrite_args)
        new_artifact.__id__ = artifact_identifier
        return new_artifact

    def __contains__(self, artifact_identifier: str):
        """Return True when a HEAD request for the artifact URL returns 200."""
        url = self.path(artifact_identifier)
        response = requests.head(url)
        return response.status_code == 200


def add_to_catalog(
    artifact: Artifact,
    name: str,
    catalog: Catalog = None,
    overwrite: bool = False,
    catalog_path: Optional[str] = None,
    verbose=True,
):
    """Save ``artifact`` under ``name`` in a catalog.

    Args:
        artifact: the artifact to store.
        name: catalog name; validated with ``verify_legal_catalog_name``.
        catalog: target catalog. When None, a ``LocalCatalog`` is created at
            ``catalog_path``.
        overwrite: forwarded to ``save_artifact``.
        catalog_path: directory for the implicit ``LocalCatalog``; defaults to
            ``constants.default_catalog_path``. Ignored when ``catalog`` is
            given.
        verbose: forwarded to ``save_artifact``.
    """
    # Reset the artifacts JSON cache before the catalog contents change.
    reset_artifacts_json_cache()
    if catalog is None:
        if catalog_path is None:
            catalog_path = constants.default_catalog_path
        catalog = LocalCatalog(location=catalog_path)
    verify_legal_catalog_name(name)
    catalog.save_artifact(artifact, name, overwrite=overwrite, verbose=verbose)


def add_link_to_catalog(
    artifact_linked_to: str,
    name: str,
    deprecate: bool = False,
    catalog: Catalog = None,
    overwrite: bool = False,
    catalog_path: Optional[str] = None,
    verbose=True,
):
    """Store an ``ArtifactLink`` named ``name`` pointing at another artifact.

    Args:
        artifact_linked_to: catalog name of the artifact the link refers to.
        name: catalog name under which the link itself is saved.
        deprecate: when True, the link carries a deprecation message telling
            users to reference ``artifact_linked_to`` directly.
        catalog, overwrite, catalog_path, verbose: forwarded to
            ``add_to_catalog``.
    """
    if deprecate:
        deprecated_msg = f"Artifact '{name}' is deprecated. Artifact '{artifact_linked_to}' will be instantiated instead. "
        deprecated_msg += f"In future uses, please reference artifact '{artifact_linked_to}' directly."
    else:
        deprecated_msg = None
    artifact_link = ArtifactLink(
        artifact_linked_to=artifact_linked_to, __deprecated_msg__=deprecated_msg
    )
    add_to_catalog(
        artifact=artifact_link,
        name=name,
        catalog=catalog,
        overwrite=overwrite,
        catalog_path=catalog_path,
        verbose=verbose,
    )


@lru_cache(maxsize=None)
def get_from_catalog(
    name: str,
    catalog: Catalog = None,
    catalog_path: Optional[str] = None,
):
    """Fetch the artifact called ``name`` from a catalog.

    Results are memoized without bound by ``lru_cache``, so all arguments must
    be hashable and repeated calls return the same object.

    Args:
        name: artifact name; parsed by ``get_catalog_name_and_args``, which
            also returns the overwrite arguments to apply.
        catalog: catalog to search. When None (and no ``catalog_path``), the
            choice of catalogs is left to ``get_catalog_name_and_args``.
        catalog_path: when given, a ``LocalCatalog`` at this path is used and
            ``catalog`` is ignored.
    """
    if catalog_path is not None:
        catalog = LocalCatalog(location=catalog_path)

    if catalog is None:
        catalogs = None
    else:
        catalogs = [catalog]

    catalog, name, args = get_catalog_name_and_args(name, catalogs=catalogs)

    return catalog.get_with_overwrite(
        name=name,
        overwrite_args=args,
    )


def get_local_catalogs_paths():
    """Return the locations of all registered catalogs that are local.

    A catalog counts when it is a ``LocalCatalog`` instance whose ``is_local``
    flag is true. Duplicate locations are not removed here.
    """
    return [
        catalog.location
        for catalog in Catalogs()
        if isinstance(catalog, LocalCatalog) and catalog.is_local
    ]


def count_files_recursively(folder):
    """Return the number of files under ``folder``, including subdirectories."""
    return sum(len(files) for _, _, files in os.walk(folder))


def local_catalog_summary(catalog_path):
    """Map each top-level directory of ``catalog_path`` to its file count.

    Entries of ``catalog_path`` that are not directories are ignored.
    """
    result = {}
    for entry in os.listdir(catalog_path):
        entry_path = os.path.join(catalog_path, entry)
        if os.path.isdir(entry_path):
            result[entry] = count_files_recursively(entry_path)
    return result


def summary():
    """Print and return per-directory file counts over all local catalogs.

    Each distinct catalog path is counted once.
    """
    result = Counter()
    done = set()
    for local_catalog_path in get_local_catalogs_paths():
        if local_catalog_path not in done:
            result += Counter(local_catalog_summary(local_catalog_path))
        done.add(local_catalog_path)
    print_dict(result)
    return result


def _get_tags_from_file(file_path):
    """Count the ``key:value`` tags found in one artifact JSON file.

    Only a dict-valued top-level ``"__tags__"`` entry is considered. A list
    value contributes one ``key:item`` count per item; any other value
    contributes a single ``key:value`` count. Files whose top-level JSON value
    is not an object yield an empty counter instead of raising.

    Raises:
        json.JSONDecodeError: if the file is not valid JSON.
        OSError: if the file cannot be opened or read.
    """
    result = Counter()

    with open(file_path) as f:
        data = json.load(f)

    # Guard against non-object JSON (list, scalar) before looking up the key.
    tags = data.get("__tags__") if isinstance(data, dict) else None
    if isinstance(tags, dict):
        for key, value in tags.items():
            values = value if isinstance(value, list) else [value]
            for item in values:
                result[f"{key}:{item}"] += 1

    return result


def count_tags():
    """Print and return tag counts over every ``.json`` file in local catalogs.

    Each distinct catalog path is scanned once. Files that cannot be read or
    decoded are logged and skipped.
    """
    result = Counter()
    done = set()

    for local_catalog_path in get_local_catalogs_paths():
        if local_catalog_path not in done:
            for root, _, files in os.walk(local_catalog_path):
                for file in files:
                    if file.endswith(".json"):
                        file_path = os.path.join(root, file)
                        try:
                            result += _get_tags_from_file(file_path)
                        except json.JSONDecodeError:
                            logger.info(f"Error decoding JSON in file: {file_path}")
                        except OSError:
                            logger.info(f"Error reading file: {file_path}")

            done.add(local_catalog_path)

    print_dict(result)
    return result


def ls(to_file=None):
    """List the identifiers of all artifacts in the local catalogs.

    An identifier is the file's path relative to its catalog root, without the
    ``.json`` suffix and with path separators replaced by ``"."``. Each
    distinct catalog path is scanned once.

    Args:
        to_file: when truthy, the identifiers are written to this path, one
            per line; otherwise they are logged.

    Returns:
        The list of identifiers, in directory-walk order.
    """
    suffix = ".json"
    done = set()
    result = []
    for local_catalog_path in get_local_catalogs_paths():
        if local_catalog_path in done:
            continue
        # Record the path so a catalog registered twice is listed only once.
        done.add(local_catalog_path)
        for root, _, files in os.walk(local_catalog_path):
            for file in files:
                if not file.endswith(suffix):
                    continue
                file_path = os.path.relpath(
                    os.path.join(root, file), local_catalog_path
                )
                # Strip only the trailing suffix, not other ".json" occurrences.
                file_id = ".".join(file_path[: -len(suffix)].split(os.path.sep))
                result.append(file_id)
    if to_file:
        with open(to_file, "w+") as f:
            f.write("\n".join(result))
    else:
        logger.info("\n".join(result))
    return result