|
import json |
|
import os |
|
from collections import Counter |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import requests |
|
|
|
from .artifact import ( |
|
AbstractCatalog, |
|
Artifact, |
|
ArtifactLink, |
|
Catalogs, |
|
get_catalog_name_and_args, |
|
reset_artifacts_json_cache, |
|
verify_legal_catalog_name, |
|
) |
|
from .logging_utils import get_logger |
|
from .settings_utils import get_constants |
|
from .text_utils import print_dict |
|
from .version import version |
|
|
|
# Module-level logger obtained from the project's logging utilities.
logger = get_logger()

# Project constants object; this module reads ``default_catalog_path`` and
# ``catalog_hierarchy_sep`` from it.
constants = get_constants()
|
|
|
|
|
class Catalog(AbstractCatalog):
    """Base catalog described by a ``name`` and a ``location``.

    Both attributes default to ``None`` here; the subclasses defined in this
    module replace them with concrete values.
    """

    name: str = None
    location: str = None

    def __repr__(self):
        """Return the catalog's ``location`` rendered as a string."""
        return format(self.location)
|
|
|
|
|
class LocalCatalog(Catalog):
    """Catalog whose artifacts are JSON files below a local directory.

    An artifact identifier is split on ``constants.catalog_hierarchy_sep``;
    every piece becomes one path component under ``location`` and the last
    piece receives a ``.json`` suffix.
    """

    name: str = "local"
    location: str = constants.default_catalog_path
    is_local: bool = True

    def path(self, artifact_identifier: str):
        """Translate an artifact identifier into the path of its JSON file.

        Raises:
            AssertionError: if the identifier is empty or whitespace only.
        """
        assert (
            artifact_identifier.strip()
        ), "artifact_identifier should not be an empty string."
        *folders, leaf = artifact_identifier.split(constants.catalog_hierarchy_sep)
        return os.path.join(self.location, *folders, leaf + ".json")

    def load(self, artifact_identifier: str, overwrite_args=None):
        """Load the artifact stored under ``artifact_identifier``.

        ``overwrite_args`` is handed to ``Artifact.load`` unchanged.

        Raises:
            AssertionError: if the identifier is not present in this catalog.
        """
        assert (
            artifact_identifier in self
        ), f"Artifact with name {artifact_identifier} does not exist"
        return Artifact.load(
            self.path(artifact_identifier),
            artifact_identifier=artifact_identifier,
            overwrite_args=overwrite_args,
        )

    def __getitem__(self, name) -> Artifact:
        """Support ``catalog[name]`` as a shorthand for ``load(name)``."""
        return self.load(name)

    def get_with_overwrite(self, name, overwrite_args):
        """Load ``name`` while forwarding ``overwrite_args`` to ``load``."""
        return self.load(name, overwrite_args=overwrite_args)

    def __contains__(self, artifact_identifier: str):
        """Tell whether a JSON file exists for ``artifact_identifier``.

        Returns ``False`` straight away when the catalog directory itself is
        missing.
        """
        if not os.path.exists(self.location):
            return False
        candidate = self.path(artifact_identifier)
        return (
            candidate is not None
            and os.path.exists(candidate)
            and os.path.isfile(candidate)
        )

    def save_artifact(
        self,
        artifact: Artifact,
        artifact_identifier: str,
        overwrite: bool = False,
        verbose: bool = True,
    ):
        """Write ``artifact`` to the file that ``artifact_identifier`` maps to.

        Missing parent directories are created. Unless ``overwrite`` is set,
        an identifier already present in the catalog is rejected. With
        ``verbose`` a log line reports where the artifact was written.

        Raises:
            AssertionError: if ``artifact`` is not an ``Artifact``, or if the
                identifier already exists and ``overwrite`` is false.
        """
        assert isinstance(
            artifact, Artifact
        ), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
        # Short-circuit keeps the membership test from running when overwriting.
        assert (
            overwrite or artifact_identifier not in self
        ), f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
        path = self.path(artifact_identifier)
        parent_dir = Path(path).parent.absolute()
        os.makedirs(parent_dir, exist_ok=True)
        artifact.save(path)
        if verbose:
            logger.info(f"Artifact {artifact_identifier} saved to {path}")
|
|
|
|
|
class EnvironmentLocalCatalog(LocalCatalog):
    """A ``LocalCatalog`` subclass that adds no attributes or methods of its own.

    NOTE(review): the name suggests it tags catalogs configured through the
    environment; that usage is not visible in this file -- confirm with callers.
    """

    pass
|
|
|
|
|
class GithubCatalog(LocalCatalog):
    """Catalog whose artifacts are fetched over HTTP from raw GitHub files.

    ``prepare`` points ``location`` at
    ``https://raw.githubusercontent.com/<user>/<repo>/<version>/<repo_dir>``;
    artifacts are then resolved relative to that URL.
    """

    name = "community"
    repo = "unitxt"
    repo_dir = "src/unitxt/catalog"
    user = "IBM"
    is_local: bool = False
    # Seconds to wait for the remote host. Without a timeout ``requests``
    # may block indefinitely on an unresponsive connection.
    request_timeout = 30

    def prepare(self):
        """Build the base URL from user, repo, package version and repo dir."""
        tag = version
        self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"

    def path(self, artifact_identifier: str):
        """Translate an artifact identifier into the URL of its JSON file.

        The inherited implementation joins with ``os.path.join``, which uses
        backslashes on Windows and therefore yields invalid URLs there.
        ``posixpath.join`` always joins with ``/`` and is identical to
        ``os.path.join`` on POSIX systems.

        Raises:
            AssertionError: if the identifier is empty or whitespace only.
        """
        import posixpath

        assert (
            artifact_identifier.strip()
        ), "artifact_identifier should not be an empty string."
        parts = artifact_identifier.split(constants.catalog_hierarchy_sep)
        parts[-1] = parts[-1] + ".json"
        return posixpath.join(self.location, *parts)

    def load(self, artifact_identifier: str, overwrite_args=None):
        """Download the artifact JSON and build an ``Artifact`` from it.

        ``overwrite_args`` is forwarded to ``Artifact.from_dict``; the
        resulting artifact's ``__id__`` is set to ``artifact_identifier``.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            requests.Timeout: if no answer arrives within ``request_timeout``.
        """
        url = self.path(artifact_identifier)
        response = requests.get(url, timeout=self.request_timeout)
        # Fail with the HTTP status instead of a confusing JSON decode error
        # raised while parsing an error page.
        response.raise_for_status()
        data = response.json()
        new_artifact = Artifact.from_dict(data, overwrite_args=overwrite_args)
        new_artifact.__id__ = artifact_identifier
        return new_artifact

    def __contains__(self, artifact_identifier: str):
        """Return ``True`` when a HEAD request for the artifact URL gives 200."""
        url = self.path(artifact_identifier)
        response = requests.head(url, timeout=self.request_timeout)
        return response.status_code == 200
|
|
|
|
|
def add_to_catalog(
    artifact: Artifact,
    name: str,
    catalog: Catalog = None,
    overwrite: bool = False,
    catalog_path: Optional[str] = None,
    verbose=True,
):
    """Save ``artifact`` under ``name`` in a catalog.

    ``reset_artifacts_json_cache()`` is called first. When no ``catalog`` is
    given, a ``LocalCatalog`` is created at ``catalog_path``, falling back to
    ``constants.default_catalog_path``. The name is checked with
    ``verify_legal_catalog_name`` before the artifact is saved.
    """
    reset_artifacts_json_cache()
    if catalog is None:
        location = (
            constants.default_catalog_path if catalog_path is None else catalog_path
        )
        catalog = LocalCatalog(location=location)
    verify_legal_catalog_name(name)
    catalog.save_artifact(artifact, name, overwrite=overwrite, verbose=verbose)
|
|
|
|
|
def add_link_to_catalog(
    artifact_linked_to: str,
    name: str,
    deprecate: bool = False,
    catalog: Catalog = None,
    overwrite: bool = False,
    catalog_path: Optional[str] = None,
    verbose=True,
):
    """Store an ``ArtifactLink`` named ``name`` that points at another artifact.

    With ``deprecate`` set, the link carries a message telling users that
    ``name`` is deprecated and that ``artifact_linked_to`` should be used.
    The remaining arguments are passed on to ``add_to_catalog``.
    """
    deprecated_msg = None
    if deprecate:
        deprecated_msg = (
            f"Artifact '{name}' is deprecated. Artifact '{artifact_linked_to}' will be instantiated instead. "
            f"In future uses, please reference artifact '{artifact_linked_to}' directly."
        )

    add_to_catalog(
        artifact=ArtifactLink(
            to=artifact_linked_to, __deprecated_msg__=deprecated_msg
        ),
        name=name,
        catalog=catalog,
        overwrite=overwrite,
        catalog_path=catalog_path,
        verbose=verbose,
    )
|
|
|
|
|
def get_from_catalog(
    name: str,
    catalog: Catalog = None,
    catalog_path: Optional[str] = None,
):
    """Fetch the artifact called ``name``.

    A given ``catalog_path`` takes precedence over ``catalog``: a
    ``LocalCatalog`` at that path is used instead. When neither is supplied,
    ``get_catalog_name_and_args`` is called with ``catalogs=None``. The
    catalog, name and arguments it returns are used for the final lookup.
    """
    if catalog_path is not None:
        catalog = LocalCatalog(location=catalog_path)

    search_space = None if catalog is None else [catalog]
    resolved_catalog, resolved_name, overwrite_args = get_catalog_name_and_args(
        name, catalogs=search_space
    )
    return resolved_catalog.get_with_overwrite(
        name=resolved_name,
        overwrite_args=overwrite_args,
    )
|
|
|
|
|
def get_local_catalogs_paths():
    """Return the ``location`` of every registered local catalog.

    A catalog yielded by ``Catalogs()`` qualifies when it is a
    ``LocalCatalog`` instance whose ``is_local`` flag is truthy.
    """
    return [
        catalog.location
        for catalog in Catalogs()
        if isinstance(catalog, LocalCatalog) and catalog.is_local
    ]
|
|
|
|
|
def count_files_recursively(folder):
    """Return the number of files found anywhere below ``folder``.

    The tree is traversed with ``os.walk``; a folder that does not exist
    produces ``0``.
    """
    return sum(len(filenames) for _, _, filenames in os.walk(folder))
|
|
|
|
|
def local_catalog_summary(catalog_path):
    """Map each immediate sub-directory of ``catalog_path`` to its file count.

    Entries of ``catalog_path`` that are not directories are ignored; counts
    come from ``count_files_recursively``.
    """
    entries = (
        (entry, os.path.join(catalog_path, entry))
        for entry in os.listdir(catalog_path)
    )
    return {
        entry: count_files_recursively(entry_path)
        for entry, entry_path in entries
        if os.path.isdir(entry_path)
    }
|
|
|
|
|
def summary():
    """Print and return per-directory file counts summed over local catalogs.

    Each distinct path from ``get_local_catalogs_paths()`` is counted once.
    """
    totals = Counter()
    visited = set()
    for catalog_path in get_local_catalogs_paths():
        if catalog_path in visited:
            continue
        totals += Counter(local_catalog_summary(catalog_path))
        visited.add(catalog_path)
    print_dict(totals)
    return totals
|
|
|
|
|
def _get_tags_from_file(file_path):
    """Count the ``key:value`` tag pairs declared in one JSON file.

    The file is expected to hold a JSON object whose ``"__tags__"`` entry is
    itself an object. Each tag value contributes one ``"key:value"`` count;
    a list value contributes one count per element. Files whose top level is
    not an object, or whose ``"__tags__"`` entry is missing or not an object,
    yield an empty counter.

    Args:
        file_path: path of the JSON file to inspect.

    Returns:
        A ``Counter`` keyed by ``"key:value"`` strings.

    Raises:
        json.JSONDecodeError: if the file is not valid JSON.
        OSError: if the file cannot be opened or read.
    """
    result = Counter()

    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(file_path, encoding="utf-8") as f:
        data = json.load(f)

    # The top level may legally be a list or scalar, which carries no tags
    # (and would otherwise fail on the membership test or the key lookup).
    tags = data.get("__tags__") if isinstance(data, dict) else None
    if not isinstance(tags, dict):
        return result

    for key, value in tags.items():
        items = value if isinstance(value, list) else [value]
        for item in items:
            result[f"{key}:{item}"] += 1

    return result
|
|
|
|
|
def count_tags():
    """Print and return tag counts gathered from all local catalogs.

    Every ``.json`` file below each distinct local catalog path is passed to
    ``_get_tags_from_file``. Files that cannot be read or decoded are logged
    and skipped.
    """
    totals = Counter()
    visited = set()

    for catalog_root in get_local_catalogs_paths():
        if catalog_root in visited:
            continue
        for folder, _, filenames in os.walk(catalog_root):
            for filename in filenames:
                if not filename.endswith(".json"):
                    continue
                file_path = os.path.join(folder, filename)
                try:
                    totals += _get_tags_from_file(file_path)
                except json.JSONDecodeError:
                    logger.info(f"Error decoding JSON in file: {file_path}")
                except OSError:
                    logger.info(f"Error reading file: {file_path}")
        visited.add(catalog_root)

    print_dict(totals)
    return totals
|
|
|
|
|
def ls(to_file=None):
    """List the identifiers of all artifacts stored in local catalogs.

    Every file ending in ``.json`` below each distinct local catalog path is
    reported as its path relative to the catalog root, with the ``.json``
    suffix removed and path separators replaced by ``"."``.

    Args:
        to_file: optional path; when given, the identifiers are written
            there one per line, otherwise they are logged.

    Returns:
        The list of identifiers, in traversal order.
    """
    suffix = ".json"
    done = set()
    result = []
    for local_catalog_path in get_local_catalogs_paths():
        if local_catalog_path in done:
            continue
        # Remember the path so a catalog registered twice is listed once;
        # without this the ``done`` check above could never trigger.
        done.add(local_catalog_path)
        for root, _, files in os.walk(local_catalog_path):
            for file in files:
                # Match on the extension (as ``count_tags`` does), not on a
                # substring, so names like ``x.json.bak`` are skipped.
                if not file.endswith(suffix):
                    continue
                file_path = os.path.relpath(
                    os.path.join(root, file), local_catalog_path
                )
                # Drop only the trailing extension; ``.json`` occurring in a
                # directory name must stay part of the identifier.
                file_id = ".".join(file_path[: -len(suffix)].split(os.path.sep))
                result.append(file_id)

    listing = "\n".join(result)
    if to_file:
        with open(to_file, "w+") as f:
            f.write(listing)
    else:
        logger.info(listing)
    return result
|
|