bullerwins's picture
Upload folder using huggingface_hub
eb00867 verified
raw
history blame
15 kB
import os
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
import requests
from tqdm.auto import tqdm as base_tqdm
from tqdm.contrib.concurrent import thread_map
from . import constants
from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
from .utils import tqdm as hf_tqdm
# Module-level logger, created through the package's own `logging` helper imported from `.utils`.
logger = logging.get_logger(__name__)
@validate_hf_hub_args
def snapshot_download(
    repo_id: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    proxies: Optional[Dict] = None,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    force_download: bool = False,
    token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: int = 8,
    tqdm_class: Optional[base_tqdm] = None,
    headers: Optional[Dict[str, str]] = None,
    endpoint: Optional[str] = None,
    # Deprecated args
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    resume_download: Optional[bool] = None,
) -> str:
    """Download repo files.

    Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
    a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder. You can also filter which files to download using
    `allow_patterns` and `ignore_patterns`.

    If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
    option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir`
    to store some metadata related to the downloaded files. While this mechanism is not as robust as the main
    cache-system, it's optimized for regularly pulling the latest version of a repository.

    An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly
    configured. It is also not possible to filter which files to download when cloning a repository using git.

    If the Hub cannot be queried (offline mode, connection error, HTTP error other than a missing revision), the
    function falls back to what is already on disk: the cached snapshot folder when `local_dir` is not set, or
    `local_dir` itself when it is set and is a non-empty directory.

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        local_dir (`str` or `Path`, *optional*):
            If provided, the downloaded files will be placed under this directory.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        user_agent (`str`, `dict`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        etag_timeout (`float`, *optional*, defaults to `10`):
            When fetching ETag, how many seconds to wait for the server to send
            data before giving up which is passed to `requests.request`.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether the file should be downloaded even if it already exists in the local cache.
        token (`str`, `bool`, *optional*):
            A token to be used for the download.
                - If `True`, the token is read from the HuggingFace config
                  folder.
                - If a string, it's used as the authentication token.
        headers (`dict`, *optional*):
            Additional headers to include in the request. Those headers take precedence over the others.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are downloaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not downloaded.
        max_workers (`int`, *optional*):
            Number of concurrent threads to download files (1 thread = 1 file download).
            Defaults to 8.
        tqdm_class (`tqdm`, *optional*):
            If provided, overwrites the default behavior for the progress bar. Passed
            argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
            Note that the `tqdm_class` is not passed to each individual download.
            Defaults to the custom HF progress bar that can be disabled by setting
            `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
        endpoint (`str`, *optional*):
            Forwarded as-is to `HfApi` and to each `hf_hub_download` call
            (presumably the base URL of the Hub to query; confirm against those APIs).
        local_dir_use_symlinks (`bool` or `"auto"`, *optional*, defaults to `"auto"`):
            Deprecated. Forwarded unchanged to `hf_hub_download`.
        resume_download (`bool`, *optional*):
            Deprecated. Forwarded unchanged to `hf_hub_download`.

    Returns:
        `str`: folder path of the repo snapshot.

    Raises:
        [`~utils.RepositoryNotFoundError`]
            If the repository to download from cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
        [`~utils.RevisionNotFoundError`]
            If the revision to download from cannot be found.
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If `token=True` and the token cannot be found.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
            ETag cannot be determined.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            if some parameter value is invalid.
    """
    if cache_dir is None:
        cache_dir = constants.HF_HUB_CACHE
    if revision is None:
        revision = constants.DEFAULT_REVISION
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if repo_type is None:
        repo_type = "model"
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")

    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

    repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
    api_call_error: Optional[Exception] = None
    if not local_files_only:
        # try/except logic to handle different errors => taken from `hf_hub_download`
        try:
            # if we have internet connection we want to list files to download
            api = HfApi(
                library_name=library_name,
                library_version=library_version,
                user_agent=user_agent,
                endpoint=endpoint,
                headers=headers,
            )
            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
            # Actually raise for those subclasses of ConnectionError
            raise
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
            OfflineModeIsEnabled,
        ) as error:
            # Internet connection is down
            # => will try to use local files only
            api_call_error = error
        except RevisionNotFoundError:
            # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
            raise
        except requests.HTTPError as error:
            # Multiple reasons for an http error:
            # - Repository is private and invalid/missing token sent
            # - Repository is gated and invalid/missing token sent
            # - Hub is down (error 500 or 504)
            # => let's switch to 'local_files_only=True' to check if the files are already cached.
            #    (if it's not the case, the error will be re-raised)
            api_call_error = error

    # At this stage, if `repo_info` is None it means either:
    # - internet connection is down
    # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
    # - repo is private/gated and invalid/missing token sent
    # - Hub is down
    # => let's look if we can find the appropriate folder in the cache:
    #    - if the specified revision is a commit hash, look inside "snapshots".
    #    - if the specified revision is a branch or tag, look inside "refs".
    # => if local_dir is not None, we will return the path to the local folder if it exists.
    if repo_info is None:
        # Try to get which commit hash corresponds to the specified revision
        commit_hash = None
        if REGEX_COMMIT_HASH.match(revision):
            commit_hash = revision
        else:
            ref_path = os.path.join(storage_folder, "refs", revision)
            if os.path.exists(ref_path):
                # retrieve commit_hash from refs file
                # (strip so that a stray trailing newline/whitespace does not corrupt the snapshot path)
                with open(ref_path) as f:
                    commit_hash = f.read().strip()

        # Try to locate snapshot folder for this commit hash.
        # Only done when no `local_dir` was requested: with `local_dir` the cache is documented as unused,
        # so returning a cache path here would hand the caller a folder they did not ask for.
        if commit_hash is not None and local_dir is None:
            snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
            if os.path.exists(snapshot_folder):
                # Snapshot folder exists => let's return it
                # (but we can't check if all the files are actually there)
                return snapshot_folder

        # If local_dir is not None, return it if it exists and is not empty
        if local_dir is not None:
            local_dir = Path(local_dir)
            if local_dir.is_dir() and any(local_dir.iterdir()):
                logger.warning(
                    f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})."
                )
                return str(local_dir.resolve())

        # If we couldn't find the appropriate folder on disk, raise an error.
        if local_files_only:
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
                "'local_files_only=False' as input."
            )
        elif isinstance(api_call_error, OfflineModeIsEnabled):
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
                "'HF_HUB_OFFLINE=0' as environment variable."
            ) from api_call_error
        elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)):
            # Repo not found => let's raise the actual error
            raise api_call_error
        else:
            # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
            raise LocalEntryNotFoundError(
                "An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
                " snapshot folder for the specified revision on the local disk. Please check your internet connection"
                " and try again."
            ) from api_call_error

    # At this stage, internet connection is up and running
    # => let's download the files!
    assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
    assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
    filtered_repo_files = list(
        filter_repo_objects(
            items=[f.rfilename for f in repo_info.siblings],
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )
    )
    commit_hash = repo_info.sha
    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
    # if passed revision is not identical to commit_hash
    # then revision has to be a branch name or tag name.
    # In that case store a ref.
    if revision != commit_hash:
        ref_path = os.path.join(storage_folder, "refs", revision)
        try:
            os.makedirs(os.path.dirname(ref_path), exist_ok=True)
            with open(ref_path, "w") as f:
                f.write(commit_hash)
        except OSError as e:
            # Best effort: failing to record the ref must not abort the download itself.
            logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")

    # we pass the commit_hash to hf_hub_download
    # so no network call happens if we already
    # have the file locally.
    def _inner_hf_hub_download(repo_file: str):
        return hf_hub_download(
            repo_id,
            filename=repo_file,
            repo_type=repo_type,
            revision=commit_hash,
            endpoint=endpoint,
            cache_dir=cache_dir,
            local_dir=local_dir,
            local_dir_use_symlinks=local_dir_use_symlinks,
            library_name=library_name,
            library_version=library_version,
            user_agent=user_agent,
            proxies=proxies,
            etag_timeout=etag_timeout,
            resume_download=resume_download,
            force_download=force_download,
            token=token,
            headers=headers,
        )

    if constants.HF_HUB_ENABLE_HF_TRANSFER:
        # when using hf_transfer we don't want extra parallelism
        # from the one hf_transfer provides
        for file in filtered_repo_files:
            _inner_hf_hub_download(file)
    else:
        thread_map(
            _inner_hf_hub_download,
            filtered_repo_files,
            desc=f"Fetching {len(filtered_repo_files)} files",
            max_workers=max_workers,
            # User can use its own tqdm class or the default one from `huggingface_hub.utils`
            tqdm_class=tqdm_class or hf_tqdm,
        )

    if local_dir is not None:
        return str(os.path.realpath(local_dir))
    return snapshot_folder