metric / catalog.py
Elron's picture
Upload folder using huggingface_hub
82055e6 verified
import json
import os
from collections import Counter
from pathlib import Path
from typing import Optional
import requests
from .artifact import (
AbstractCatalog,
Artifact,
ArtifactLink,
Catalogs,
get_catalog_name_and_args,
reset_artifacts_json_cache,
verify_legal_catalog_name,
)
from .logging_utils import get_logger
from .settings_utils import get_constants
from .text_utils import print_dict
from .version import version
# Module-wide logger obtained from the project's logging utilities.
logger = get_logger()
# Project constants; this module reads `default_catalog_path` and
# `catalog_hierarchy_sep` from it.
constants = get_constants()
class Catalog(AbstractCatalog):
    """Base catalog described by a name and a location.

    The printable representation of a catalog is its ``location``.
    """

    # Label of the catalog; the subclasses in this module override it.
    name: str = None
    # Where the catalog lives (a directory or a base URL in the subclasses here).
    location: str = None

    def __repr__(self):
        """Return the catalog location rendered as a string."""
        return str(self.location)
class LocalCatalog(Catalog):
    """Catalog whose artifacts are ``.json`` files under a directory on disk.

    An artifact identifier is split on ``constants.catalog_hierarchy_sep``;
    every part but the last names a sub-directory of ``location`` and the last
    part names the JSON file (without its extension).
    """

    name: str = "local"
    location: str = constants.default_catalog_path
    # Read by get_local_catalogs_paths() to select catalogs for listing.
    is_local: bool = True

    def path(self, artifact_identifier: str):
        """Return the file path that stores ``artifact_identifier``.

        Raises:
            AssertionError: if the identifier is empty or only whitespace.
        """
        assert (
            artifact_identifier.strip()
        ), "artifact_identifier should not be an empty string."
        *folders, leaf = artifact_identifier.split(constants.catalog_hierarchy_sep)
        return os.path.join(self.location, *folders, leaf + ".json")

    def load(self, artifact_identifier: str, overwrite_args=None):
        """Load the artifact stored under ``artifact_identifier``.

        ``overwrite_args`` is forwarded unchanged to ``Artifact.load``.

        Raises:
            AssertionError: if the artifact is not present in this catalog.
        """
        assert (
            artifact_identifier in self
        ), f"Artifact with name {artifact_identifier} does not exist"
        return Artifact.load(
            self.path(artifact_identifier),
            artifact_identifier=artifact_identifier,
            overwrite_args=overwrite_args,
        )

    def __getitem__(self, name) -> Artifact:
        """Support ``catalog[name]`` as a shorthand for ``catalog.load(name)``."""
        return self.load(name)

    def get_with_overwrite(self, name, overwrite_args):
        """Load ``name`` while passing ``overwrite_args`` on to ``load``."""
        return self.load(name, overwrite_args=overwrite_args)

    def __contains__(self, artifact_identifier: str):
        """Tell whether a file for ``artifact_identifier`` exists in the catalog."""
        # A catalog whose root is missing contains nothing.
        if not os.path.exists(self.location):
            return False
        candidate = self.path(artifact_identifier)
        # isfile() is already False for paths that do not exist.
        return candidate is not None and os.path.isfile(candidate)

    def save_artifact(
        self,
        artifact: Artifact,
        artifact_identifier: str,
        overwrite: bool = False,
        verbose: bool = True,
    ):
        """Write ``artifact`` to the file that corresponds to its identifier.

        Args:
            artifact: the artifact to store; must be an ``Artifact`` instance.
            artifact_identifier: catalog name under which it is stored.
            overwrite: when False, an existing entry of that name is an error.
            verbose: when True, log the destination path after saving.

        Raises:
            AssertionError: if ``artifact`` is not an ``Artifact``, or if the
                identifier already exists and ``overwrite`` is False.
        """
        assert isinstance(
            artifact, Artifact
        ), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
        # The membership test only runs when overwriting is not allowed.
        assert (
            overwrite or artifact_identifier not in self
        ), f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
        destination = self.path(artifact_identifier)
        parent_dir = Path(destination).parent.absolute()
        os.makedirs(parent_dir, exist_ok=True)
        artifact.save(destination)
        if verbose:
            logger.info(f"Artifact {artifact_identifier} saved to {destination}")
class EnvironmentLocalCatalog(LocalCatalog):
    """A ``LocalCatalog`` subclass that adds no behaviour of its own.

    NOTE(review): presumably a marker type for catalogs configured through the
    environment - confirm against the code that instantiates it.
    """
# Seconds to wait on GitHub before a request is abandoned; without a timeout
# `requests` may block forever on an unresponsive connection.
_GITHUB_REQUEST_TIMEOUT = 30


class GithubCatalog(LocalCatalog):
    """Read-only view of a catalog hosted in a GitHub repository.

    Artifacts are fetched over HTTP from ``raw.githubusercontent.com`` at the
    tag equal to the installed package ``version``.
    """

    name = "community"
    repo = "unitxt"
    repo_dir = "src/unitxt/catalog"
    user = "IBM"
    # Keeps this catalog out of get_local_catalogs_paths(), which only collects
    # catalogs whose ``is_local`` is true.
    is_local: bool = False

    def prepare(self):
        """Point ``location`` at the raw-content URL of the catalog directory."""
        tag = version
        self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"

    def path(self, artifact_identifier: str):
        """Return the URL of the JSON file that stores ``artifact_identifier``.

        The inherited implementation joins with ``os.path.join``, which inserts
        backslashes on Windows and so produces invalid URLs; URLs always use
        forward slashes, hence the POSIX flavour of ``join`` is used here.

        Raises:
            AssertionError: if the identifier is empty or only whitespace.
        """
        import posixpath

        assert (
            artifact_identifier.strip()
        ), "artifact_identifier should not be an empty string."
        parts = artifact_identifier.split(constants.catalog_hierarchy_sep)
        parts[-1] = parts[-1] + ".json"
        return posixpath.join(self.location, *parts)

    def load(self, artifact_identifier: str, overwrite_args=None):
        """Download the artifact JSON and build the artifact from it.

        Args:
            artifact_identifier: catalog name of the artifact to fetch.
            overwrite_args: forwarded unchanged to ``Artifact.from_dict``.

        Raises:
            requests.exceptions.RequestException: on network failure, timeout,
                a non-success HTTP status, or a body that is not valid JSON.
        """
        url = self.path(artifact_identifier)
        response = requests.get(url, timeout=_GITHUB_REQUEST_TIMEOUT)
        # Fail with the HTTP error itself (e.g. 404) instead of a misleading
        # JSON decoding error raised on the error page body.
        response.raise_for_status()
        data = response.json()
        new_artifact = Artifact.from_dict(data, overwrite_args=overwrite_args)
        new_artifact.__id__ = artifact_identifier
        return new_artifact

    def __contains__(self, artifact_identifier: str):
        """Tell whether a HEAD request for the artifact URL answers 200."""
        url = self.path(artifact_identifier)
        response = requests.head(url, timeout=_GITHUB_REQUEST_TIMEOUT)
        return response.status_code == 200
def add_to_catalog(
    artifact: Artifact,
    name: str,
    catalog: Catalog = None,
    overwrite: bool = False,
    catalog_path: Optional[str] = None,
    verbose=True,
):
    """Store ``artifact`` under ``name`` in a catalog.

    Args:
        artifact: the artifact to save.
        name: catalog name for the artifact; checked with
            ``verify_legal_catalog_name`` before saving.
        catalog: destination catalog. When omitted, a ``LocalCatalog`` is
            created at ``catalog_path``.
        overwrite: forwarded to ``save_artifact``.
        catalog_path: directory for the implicit ``LocalCatalog``; falls back
            to ``constants.default_catalog_path``. Ignored when ``catalog`` is
            given.
        verbose: forwarded to ``save_artifact``.
    """
    # The artifacts JSON cache is reset before anything is written.
    reset_artifacts_json_cache()
    if catalog is None:
        location = (
            constants.default_catalog_path if catalog_path is None else catalog_path
        )
        catalog = LocalCatalog(location=location)
    verify_legal_catalog_name(name)
    catalog.save_artifact(artifact, name, overwrite=overwrite, verbose=verbose)
def add_link_to_catalog(
    artifact_linked_to: str,
    name: str,
    deprecate: bool = False,
    catalog: Catalog = None,
    overwrite: bool = False,
    catalog_path: Optional[str] = None,
    verbose=True,
):
    """Save, under ``name``, an ``ArtifactLink`` pointing at another artifact.

    Args:
        artifact_linked_to: catalog name of the artifact the link points to.
        name: catalog name under which the link itself is stored.
        deprecate: when True, the link carries a deprecation message telling
            users to reference ``artifact_linked_to`` directly.
        catalog, overwrite, catalog_path, verbose: forwarded unchanged to
            ``add_to_catalog``.
    """
    deprecated_msg = None
    if deprecate:
        deprecated_msg = (
            f"Artifact '{name}' is deprecated. Artifact '{artifact_linked_to}' will be instantiated instead. "
            f"In future uses, please reference artifact '{artifact_linked_to}' directly."
        )
    link = ArtifactLink(to=artifact_linked_to, __deprecated_msg__=deprecated_msg)
    add_to_catalog(
        artifact=link,
        name=name,
        catalog=catalog,
        overwrite=overwrite,
        catalog_path=catalog_path,
        verbose=verbose,
    )
def get_from_catalog(
    name: str,
    catalog: Catalog = None,
    catalog_path: Optional[str] = None,
):
    """Fetch the artifact called ``name`` from a catalog.

    Args:
        name: artifact name as understood by ``get_catalog_name_and_args``,
            which returns the catalog, the bare name and the overwrite args.
        catalog: catalog to search; replaced by a ``LocalCatalog`` at
            ``catalog_path`` whenever ``catalog_path`` is given.
        catalog_path: directory of a local catalog to search instead.

    Returns:
        Whatever ``get_with_overwrite`` of the resolved catalog returns.
    """
    # An explicit path takes precedence over an explicitly passed catalog.
    if catalog_path is not None:
        catalog = LocalCatalog(location=catalog_path)
    # None lets get_catalog_name_and_args choose which catalogs to search.
    search_in = None if catalog is None else [catalog]
    resolved_catalog, resolved_name, overwrite_args = get_catalog_name_and_args(
        name, catalogs=search_in
    )
    return resolved_catalog.get_with_overwrite(
        name=resolved_name,
        overwrite_args=overwrite_args,
    )
def get_local_catalogs_paths():
    """Return the locations of the catalogs in ``Catalogs()`` that are local.

    A catalog qualifies when it is a ``LocalCatalog`` instance and its
    ``is_local`` attribute is true. Iteration order is preserved and duplicates
    are not removed.
    """
    return [
        entry.location
        for entry in Catalogs()
        if isinstance(entry, LocalCatalog) and entry.is_local
    ]
def count_files_recursively(folder):
    """Return how many files live under ``folder``, at any depth.

    Directories themselves are not counted. A path that ``os.walk`` cannot
    traverse (for example one that does not exist) yields 0.
    """
    return sum(len(filenames) for _, _, filenames in os.walk(folder))
def local_catalog_summary(catalog_path):
    """Map each immediate sub-directory of ``catalog_path`` to its file count.

    Plain files located directly in ``catalog_path`` are ignored; the count of
    a sub-directory covers all files nested below it.
    """
    summary_by_dir = {}
    for entry in os.listdir(catalog_path):
        entry_path = os.path.join(catalog_path, entry)
        if not os.path.isdir(entry_path):
            continue
        summary_by_dir[entry] = count_files_recursively(entry_path)
    return summary_by_dir
def summary():
    """Print and return file counts per top-level directory of local catalogs.

    Counts of equally named directories in different catalogs are added up,
    and a catalog path that appears more than once is only counted once.
    """
    totals = Counter()
    visited = set()
    for catalog_dir in get_local_catalogs_paths():
        if catalog_dir in visited:
            continue
        visited.add(catalog_dir)
        # Counter's in-place addition drops entries whose total is not
        # positive, so directories without files do not show up.
        totals += Counter(local_catalog_summary(catalog_dir))
    print_dict(totals)
    return totals
def _get_tags_from_file(file_path):
    """Count the ``key:value`` tags declared in one JSON file.

    The file is expected to hold a JSON object with an optional ``__tags__``
    object. Each tag whose value is a list contributes one ``key:item`` entry
    per list item; any other value contributes a single ``key:value`` entry.

    Args:
        file_path: path of the JSON file to inspect.

    Returns:
        A ``Counter`` of ``"key:value"`` strings; empty when the top-level JSON
        value is not an object or has no ``__tags__`` object.

    Raises:
        OSError: if the file cannot be opened or read.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(file_path, encoding="utf-8") as f:
        data = json.load(f)
    result = Counter()
    # A non-object top level (e.g. the list ["__tags__"]) would pass an `in`
    # test and then fail on indexing, so check the type first.
    tags = data.get("__tags__") if isinstance(data, dict) else None
    if not isinstance(tags, dict):
        return result
    for key, value in tags.items():
        items = value if isinstance(value, list) else [value]
        for item in items:
            result[f"{key}:{item}"] += 1
    return result
def count_tags():
    """Print and return tag counts gathered from all local catalogs.

    Every ``.json`` file below each distinct local catalog path is scanned with
    ``_get_tags_from_file``. Files that cannot be read or decoded are logged
    and skipped, so one broken file does not abort the scan.
    """
    tag_counts = Counter()
    visited = set()
    for catalog_dir in get_local_catalogs_paths():
        if catalog_dir in visited:
            continue
        for folder, _, filenames in os.walk(catalog_dir):
            for filename in filenames:
                if not filename.endswith(".json"):
                    continue
                json_path = os.path.join(folder, filename)
                try:
                    tag_counts += _get_tags_from_file(json_path)
                except json.JSONDecodeError:
                    logger.info(f"Error decoding JSON in file: {json_path}")
                except OSError:
                    logger.info(f"Error reading file: {json_path}")
        visited.add(catalog_dir)
    print_dict(tag_counts)
    return tag_counts
def ls(to_file=None):
    """List the identifiers of all artifacts stored in the local catalogs.

    Each ``.json`` file below a local catalog path becomes one identifier: its
    path relative to the catalog root, without the ``.json`` extension, with
    the path components joined by ``constants.catalog_hierarchy_sep`` (the
    separator ``LocalCatalog.path`` splits identifiers on).

    Args:
        to_file: optional path of a file to write the listing to, one
            identifier per line. When falsy, the listing is logged instead.

    Returns:
        The list of identifiers, in directory-walk order.
    """
    suffix = ".json"
    visited = set()
    result = []
    for local_catalog_path in get_local_catalogs_paths():
        # Record each path so that a catalog registered twice is listed once.
        if local_catalog_path in visited:
            continue
        visited.add(local_catalog_path)
        for root, _, files in os.walk(local_catalog_path):
            for file in files:
                # Match the extension only, not ".json" anywhere in the name.
                if not file.endswith(suffix):
                    continue
                relative_path = os.path.relpath(
                    os.path.join(root, file), local_catalog_path
                )
                # Strip the trailing extension only; directory names that
                # happen to contain ".json" must stay intact.
                parts = relative_path[: -len(suffix)].split(os.path.sep)
                result.append(constants.catalog_hierarchy_sep.join(parts))
    listing = "\n".join(result)
    if to_file:
        with open(to_file, "w+") as f:
            f.write(listing)
    else:
        logger.info(listing)
    return result