Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2022 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
Import utilities: Utilities related to imports and our lazy inits. | |
""" | |
import importlib.metadata | |
import importlib.util | |
import json | |
import os | |
import shutil | |
import subprocess | |
import sys | |
import warnings | |
from collections import OrderedDict | |
from functools import lru_cache, wraps | |
from itertools import chain | |
from types import ModuleType | |
from typing import Any, Tuple, Union | |
from packaging import version | |
from transformers import logging | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.). Talk to Sylvain to see how to handle them better.
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: | |
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version | |
package_exists = importlib.util.find_spec(pkg_name) is not None | |
package_version = "N/A" | |
if package_exists: | |
try: | |
package_version = importlib.metadata.version(pkg_name) | |
package_exists = True | |
except importlib.metadata.PackageNotFoundError: | |
package_exists = False | |
logger.debug(f"Detected {pkg_name} version {package_version}") | |
if return_version: | |
return package_exists, package_version | |
else: | |
return package_exists | |
# Upper-cased environment-variable values that count as "enabled".
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# Same set plus "AUTO", which is the default used below when a variable is not set.
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

# Framework switches, read once at import time and normalised to upper case.
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
# Note: the JAX/Flax switch is read from the `USE_FLAX` environment variable.
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()

FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()

# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")

# Default minimum versions used by `is_accelerate_available` and `is_fsdp_available`.
ACCELERATE_MIN_VERSION = "0.21.0"
FSDP_MIN_VERSION = "1.12.0"
# Availability flags computed once at import time. Most use `_is_package_available` (import spec + distribution
# metadata); a few only use `importlib.util.find_spec` because their distribution name differs from the module name.
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_bitsandbytes_available = _is_package_available("bitsandbytes")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
# `importlib.metadata.version` doesn't work with `opencv-python-headless`.
_cv2_available = importlib.util.find_spec("cv2") is not None
_datasets_available = _is_package_available("datasets")
_decord_available = importlib.util.find_spec("decord") is not None
_detectron2_available = _is_package_available("detectron2")
# We need to check both `faiss` and `faiss-cpu`.
_faiss_available = importlib.util.find_spec("faiss") is not None
try:
    _faiss_version = importlib.metadata.version("faiss")
    logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib.metadata.PackageNotFoundError:
    try:
        _faiss_version = importlib.metadata.version("faiss-cpu")
        logger.debug(f"Successfully imported faiss version {_faiss_version}")
    except importlib.metadata.PackageNotFoundError:
        # Neither distribution has metadata: treat faiss as unavailable even if a `faiss` module was found.
        _faiss_available = False
_ftfy_available = _is_package_available("ftfy")
_g2p_en_available = _is_package_available("g2p_en")
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
_jieba_available = _is_package_available("jieba")
_jinja_available = _is_package_available("jinja2")
_kenlm_available = _is_package_available("kenlm")
_keras_nlp_available = _is_package_available("keras_nlp")
_levenshtein_available = _is_package_available("Levenshtein")
_librosa_available = _is_package_available("librosa")
_natten_available = _is_package_available("natten")
_nltk_available = _is_package_available("nltk")
_onnx_available = _is_package_available("onnx")
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
_psutil_available = _is_package_available("psutil")
_py3nvml_available = _is_package_available("py3nvml")
_pyctcdecode_available = _is_package_available("pyctcdecode")
_pytesseract_available = _is_package_available("pytesseract")
_pytest_available = _is_package_available("pytest")
_pytorch_quantization_available = _is_package_available("pytorch_quantization")
_rjieba_available = _is_package_available("rjieba")
_sacremoses_available = _is_package_available("sacremoses")
_safetensors_available = _is_package_available("safetensors")
_scipy_available = _is_package_available("scipy")
_sentencepiece_available = _is_package_available("sentencepiece")
_is_seqio_available = _is_package_available("seqio")
# The module is `sklearn` but the distribution is `scikit-learn`, so both lookups are done by hand.
_sklearn_available = importlib.util.find_spec("sklearn") is not None
if _sklearn_available:
    try:
        importlib.metadata.version("scikit-learn")
    except importlib.metadata.PackageNotFoundError:
        _sklearn_available = False
_smdistributed_available = importlib.util.find_spec("smdistributed") is not None
_soundfile_available = _is_package_available("soundfile")
_spacy_available = _is_package_available("spacy")
_sudachipy_available = _is_package_available("sudachipy")
_tensorflow_probability_available = _is_package_available("tensorflow_probability")
_tensorflow_text_available = _is_package_available("tensorflow_text")
_tf2onnx_available = _is_package_available("tf2onnx")
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
# PyTorch is probed only when USE_TORCH is enabled/auto and USE_TF is not explicitly enabled.
_torch_version = "N/A"
_torch_available = False
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available, _torch_version = _is_package_available("torch", return_version=True)
else:
    # NOTE(review): this branch is also taken when USE_TORCH itself is disabled, although the message only
    # mentions USE_TF.
    logger.info("Disabling PyTorch because USE_TF is set")
    _torch_available = False
# TensorFlow detection. FORCE_TF_AVAILABLE short-circuits every check below and leaves `_tf_version` at "N/A".
_tf_version = "N/A"
_tf_available = False
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
    _tf_available = True
else:
    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
        # Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below
        # with tensorflow-cpu to make sure it still works!
        _tf_available = importlib.util.find_spec("tensorflow") is not None
        if _tf_available:
            # Distribution names that can provide the `tensorflow` module.
            candidates = (
                "tensorflow",
                "tensorflow-cpu",
                "tensorflow-gpu",
                "tf-nightly",
                "tf-nightly-cpu",
                "tf-nightly-gpu",
                "tf-nightly-rocm",
                "intel-tensorflow",
                "intel-tensorflow-avx512",
                "tensorflow-rocm",
                "tensorflow-macos",
                "tensorflow-aarch64",
            )
            _tf_version = None
            # For the metadata, we have to look for both tensorflow and tensorflow-cpu
            for pkg in candidates:
                try:
                    _tf_version = importlib.metadata.version(pkg)
                    break
                except importlib.metadata.PackageNotFoundError:
                    pass
            # No candidate had metadata: `_tf_version` stays None and TensorFlow is reported unavailable.
            _tf_available = _tf_version is not None
        if _tf_available:
            if version.parse(_tf_version) < version.parse("2"):
                logger.info(
                    f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
                )
                _tf_available = False
    else:
        logger.info("Disabling Tensorflow because USE_TORCH is set")
# essentia: availability comes from the import spec only; the metadata lookup just records the version.
_essentia_available = importlib.util.find_spec("essentia") is not None
try:
    _essentia_version = importlib.metadata.version("essentia")
    logger.debug(f"Successfully imported essentia version {_essentia_version}")
except importlib.metadata.PackageNotFoundError:
    # NOTE(review): this sets the *version* to False and leaves `_essentia_available` untouched, unlike the
    # sibling probes which clear the availability flag here - confirm whether that is intentional.
    _essentia_version = False
# pretty_midi: requires both an import spec and distribution metadata.
_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
try:
    _pretty_midi_version = importlib.metadata.version("pretty_midi")
    logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
except importlib.metadata.PackageNotFoundError:
    # `_pretty_midi_version` is left undefined in this case.
    _pretty_midi_available = False
# oneCCL bindings: the module may be named `torch_ccl` or `oneccl_bindings_for_pytorch`, while the distribution
# metadata is looked up under `oneccl_bind_pt`.
ccl_version = "N/A"
_is_ccl_available = (
    importlib.util.find_spec("torch_ccl") is not None
    or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
)
try:
    ccl_version = importlib.metadata.version("oneccl_bind_pt")
    logger.debug(f"Detected oneccl_bind_pt version {ccl_version}")
except importlib.metadata.PackageNotFoundError:
    _is_ccl_available = False
# JAX/Flax detection. Defaults are assigned up front so that `_jax_available`, `_jax_version` and `_flax_version`
# always exist, even when detection is skipped (USE_FLAX disabled) or Flax is not installed.
_flax_available = _jax_available = False
_flax_version = _jax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available, _flax_version = _is_package_available("flax", return_version=True)
    if _flax_available:
        _jax_available, _jax_version = _is_package_available("jax", return_version=True)
        if _jax_available:
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
        else:
            # Flax without JAX is unusable: report both as unavailable.
            _flax_available = _jax_available = False
            _jax_version = _flax_version = "N/A"
# torch.fx support is decided by comparing only (major, minor) of the installed torch against
# TORCH_FX_REQUIRED_VERSION. Note that this also defines the module-level name `torch_version`.
_torch_fx_available = False
if _torch_available:
    torch_version = version.parse(_torch_version)
    _torch_fx_available = (torch_version.major, torch_version.minor) >= (
        TORCH_FX_REQUIRED_VERSION.major,
        TORCH_FX_REQUIRED_VERSION.minor,
    )
def is_kenlm_available() -> bool:
    """Return the `_kenlm_available` flag computed at import time."""
    return _kenlm_available
def is_cv2_available() -> bool:
    """Return the `_cv2_available` flag computed at import time."""
    return _cv2_available
def is_torch_available() -> bool:
    """Return the `_torch_available` flag computed at import time (depends on USE_TORCH / USE_TF)."""
    return _torch_available
def get_torch_version() -> str:
    """Return the detected torch version string, or `"N/A"` when torch was not detected."""
    return _torch_version
def is_torch_sdpa_available():
    """
    Report whether the installed torch is recent enough to use scaled-dot-product attention (SDPA).

    Returns `False` when torch is unavailable or its version is unknown (`"N/A"`).
    """
    if not is_torch_available() or _torch_version == "N/A":
        return False

    # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
    # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
    # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
    # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs:
    # https://github.com/pytorch/pytorch/issues/112577
    minimum_torch = version.parse("2.1.1")
    return version.parse(_torch_version) >= minimum_torch
def is_torchvision_available() -> bool:
    """Return the `_torchvision_available` flag computed at import time."""
    return _torchvision_available
def is_pyctcdecode_available() -> bool:
    """Return the `_pyctcdecode_available` flag computed at import time."""
    return _pyctcdecode_available
def is_librosa_available() -> bool:
    """Return the `_librosa_available` flag computed at import time."""
    return _librosa_available
def is_essentia_available() -> bool:
    """Return the `_essentia_available` flag computed at import time."""
    return _essentia_available
def is_pretty_midi_available() -> bool:
    """Return the `_pretty_midi_available` flag computed at import time."""
    return _pretty_midi_available
def is_torch_cuda_available():
    """Return `torch.cuda.is_available()`, or `False` when torch itself is not available."""
    if not is_torch_available():
        return False

    import torch

    return torch.cuda.is_available()
def is_torch_mps_available():
    """Return whether the torch MPS backend exists and reports itself as available."""
    if not is_torch_available():
        return False

    import torch

    # Torch builds without an `mps` attribute on `torch.backends` cannot use MPS.
    if not hasattr(torch.backends, "mps"):
        return False
    return torch.backends.mps.is_available()
def is_torch_bf16_gpu_available():
    """Return whether CUDA is available and `torch.cuda.is_bf16_supported()` is true."""
    if is_torch_available():
        import torch

        return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    return False
def is_torch_bf16_cpu_available():
    """Return whether `torch.cpu.amp.autocast` can be resolved on the installed torch."""
    if is_torch_available():
        import torch

        try:
            # Any missing level of the attribute chain raises AttributeError, so one lookup covers them all.
            _ = torch.cpu.amp.autocast
            return True
        except AttributeError:
            pass
    return False
def is_torch_bf16_available():
    """
    Deprecated alias of `is_torch_bf16_gpu_available`.

    The original bf16 check was GPU-only; since a CPU/bf16 combination also exists, this name is ambiguous. A
    `FutureWarning` is emitted on every call.
    """
    deprecation_message = (
        "The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available "
        "or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu"
    )
    warnings.warn(deprecation_message, FutureWarning)
    return is_torch_bf16_gpu_available()
def is_torch_fp16_available_on_device(device):
    """
    Probe whether a float16 matmul works on `device`.

    Args:
        device: Anything accepted by `torch.Tensor.to`.

    Returns:
        `True` if a 2x2 float16 tensor can be moved to `device` and multiplied with itself, `False` if torch is
        unavailable or the probe raises.
    """
    if not is_torch_available():
        return False

    import torch

    try:
        x = torch.zeros(2, 2, dtype=torch.float16).to(device)
        _ = x @ x
    except Exception:
        # TODO: more precise exception matching, if possible.
        # most backends should return `RuntimeError` however this is not guaranteed.
        # `Exception` (rather than a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
        return False

    return True
def is_torch_bf16_available_on_device(device):
    """
    Probe whether a bfloat16 matmul works on `device`.

    Args:
        device: Anything accepted by `torch.Tensor.to`. The exact string `"cuda"` is delegated to
            `is_torch_bf16_gpu_available` instead of being probed.

    Returns:
        `True` if the probe succeeds, `False` if torch is unavailable or the probe raises.
    """
    if not is_torch_available():
        return False

    import torch

    if device == "cuda":
        return is_torch_bf16_gpu_available()

    try:
        x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
        _ = x @ x
    except Exception:
        # TODO: more precise exception matching, if possible.
        # most backends should return `RuntimeError` however this is not guaranteed.
        # `Exception` (rather than a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
        return False

    return True
def is_torch_tf32_available():
    """
    Return whether the checks this function uses for TF32 all pass.

    Requires: torch available, CUDA available with a known CUDA version, current device major >= 8,
    CUDA major version >= 11 and torch base version >= 1.7.
    """
    if not is_torch_available():
        return False

    import torch

    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False

    device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
    if device_props.major < 8:
        return False

    cuda_major = int(torch.version.cuda.split(".")[0])
    if cuda_major < 11:
        return False

    # Compare on the base version so local/dev suffixes do not affect the result.
    torch_base_version = version.parse(version.parse(torch.__version__).base_version)
    return torch_base_version >= version.parse("1.7")
def is_torch_fx_available() -> bool:
    """Return the `_torch_fx_available` flag computed at import time."""
    return _torch_fx_available
def is_peft_available() -> bool:
    """Return the `_peft_available` flag computed at import time."""
    return _peft_available
def is_bs4_available() -> bool:
    """Return the `_bs4_available` flag computed at import time."""
    return _bs4_available
def is_tf_available() -> bool:
    """Return the `_tf_available` flag computed at import time (depends on USE_TF / USE_TORCH / FORCE_TF_AVAILABLE)."""
    return _tf_available
def is_coloredlogs_available() -> bool:
    """Return the `_coloredlogs_available` flag computed at import time."""
    return _coloredlogs_available
def is_tf2onnx_available() -> bool:
    """Return the `_tf2onnx_available` flag computed at import time."""
    return _tf2onnx_available
def is_onnx_available() -> bool:
    """Return the `_onnx_available` flag computed at import time."""
    return _onnx_available
def is_openai_available() -> bool:
    """Return the `_openai_available` flag computed at import time."""
    return _openai_available
def is_flax_available() -> bool:
    """Return the `_flax_available` flag computed at import time (false when Flax is found without JAX)."""
    return _flax_available
def is_ftfy_available() -> bool:
    """Return the `_ftfy_available` flag computed at import time."""
    return _ftfy_available
def is_g2p_en_available() -> bool:
    """Return the `_g2p_en_available` flag computed at import time."""
    return _g2p_en_available
def is_torch_tpu_available(check_device=True):
    """
    Checks if `torch_xla` is installed and potentially if a TPU is in the environment.

    Args:
        check_device (`bool`, defaults to `True`): When true, also try to obtain an XLA device.
    """
    if not _torch_available or importlib.util.find_spec("torch_xla") is None:
        return False
    if not check_device:
        return True

    # We need to check if `xla_device` can be found, will raise a RuntimeError if not
    try:
        import torch_xla.core.xla_model as xm

        _ = xm.xla_device()
    except RuntimeError:
        return False
    return True
def is_torch_neuroncore_available(check_device=True):
    """Return `is_torch_tpu_available(check_device)` when `torch_neuronx` can be found, else `False`."""
    if importlib.util.find_spec("torch_neuronx") is None:
        return False
    return is_torch_tpu_available(check_device)
def is_torch_npu_available(check_device=False):
    """
    Checks if `torch_npu` is installed and potentially if a NPU is in the environment.

    Args:
        check_device (`bool`, defaults to `False`): When true, query the device count first and treat a
            `RuntimeError` as "no NPU".
    """
    if not _torch_available:
        return False
    if importlib.util.find_spec("torch_npu") is None:
        return False

    import torch
    import torch_npu  # noqa: F401

    if not check_device:
        return hasattr(torch, "npu") and torch.npu.is_available()

    try:
        # Will raise a RuntimeError if no NPU is found
        _ = torch.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False
def is_torchdynamo_available():
    """Return whether `torch._dynamo` can be imported; any exception during import counts as unavailable."""
    if not is_torch_available():
        return False
    try:
        import torch._dynamo as dynamo  # noqa: F401
    except Exception:
        return False
    return True
def is_torch_compile_available():
    """Return whether the installed torch exposes a `compile` attribute."""
    if is_torch_available():
        import torch

        # We don't do any version check here to support nightlies marked as 1.14. Ultimately needs to check version
        # against 2.0 but let's do it later.
        return hasattr(torch, "compile")
    return False
def is_torchdynamo_compiling():
    """Return `torch._dynamo.is_compiling()`, or `False` if torch is unavailable or the call/import raises."""
    if not is_torch_available():
        return False
    try:
        import torch._dynamo as dynamo  # noqa: F401

        compiling = dynamo.is_compiling()
    except Exception:
        return False
    return compiling
def is_torch_tensorrt_fx_available():
    """Return whether both `torch_tensorrt` and its `fx` submodule can be found."""
    # The parent is checked first so that `find_spec` is never asked for a submodule of a missing package.
    return (
        importlib.util.find_spec("torch_tensorrt") is not None
        and importlib.util.find_spec("torch_tensorrt.fx") is not None
    )
def is_datasets_available() -> bool:
    """Return the `_datasets_available` flag computed at import time."""
    return _datasets_available
def is_detectron2_available() -> bool:
    """Return the `_detectron2_available` flag computed at import time."""
    return _detectron2_available
def is_rjieba_available() -> bool:
    """Return the `_rjieba_available` flag computed at import time."""
    return _rjieba_available
def is_psutil_available() -> bool:
    """Return the `_psutil_available` flag computed at import time."""
    return _psutil_available
def is_py3nvml_available() -> bool:
    """Return the `_py3nvml_available` flag computed at import time."""
    return _py3nvml_available
def is_sacremoses_available() -> bool:
    """Return the `_sacremoses_available` flag computed at import time."""
    return _sacremoses_available
def is_apex_available() -> bool:
    """Return the `_apex_available` flag computed at import time."""
    return _apex_available
def is_ninja_available():
    r"""
    Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
    [ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise.
    """
    # Running `ninja --version` succeeds only when the executable exists and exits with status 0.
    try:
        subprocess.check_output(["ninja", "--version"])
        return True
    except Exception:
        return False
def is_ipex_available():
    """
    Return whether Intel Extension for PyTorch is installed with the same major.minor version as torch.

    A warning is logged when both are installed but their major.minor versions differ.
    """

    def major_minor(full_version):
        # Reduce a full version string to "<major>.<minor>".
        parsed = version.parse(full_version)
        return str(parsed.major) + "." + str(parsed.minor)

    if not (is_torch_available() and _ipex_available):
        return False

    torch_release = major_minor(_torch_version)
    ipex_release = major_minor(_ipex_version)
    if torch_release == ipex_release:
        return True

    logger.warning(
        f"Intel Extension for PyTorch {ipex_release} needs to work with PyTorch {ipex_release}.*,"
        f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
    )
    return False
def is_torch_xpu_available(check_device=False):
    """
    Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment.

    Args:
        check_device (`bool`, defaults to `False`): When true, query the device count first and treat a
            `RuntimeError` as "no XPU".
    """
    if not is_ipex_available():
        return False

    import intel_extension_for_pytorch  # noqa: F401
    import torch

    if not check_device:
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    try:
        # Will raise a RuntimeError if no XPU is found
        _ = torch.xpu.device_count()
        return torch.xpu.is_available()
    except RuntimeError:
        return False
def is_bitsandbytes_available():
    """Return whether bitsandbytes was detected and CUDA is available."""
    if is_torch_available():
        # bitsandbytes throws an error if cuda is not available
        # let's avoid that by adding a simple check
        import torch

        return _bitsandbytes_available and torch.cuda.is_available()
    return False
def is_flash_attn_2_available():
    """
    Return whether a sufficiently recent `flash_attn` is installed and CUDA is available.

    The minimum version is 2.1.0 on CUDA builds of torch and 2.0.4 on HIP (ROCm) builds.
    """
    if not is_torch_available() or not _is_package_available("flash_attn"):
        return False

    # Let's add an extra check to see if cuda is available
    import torch

    if not torch.cuda.is_available():
        return False

    if torch.version.cuda:
        minimum_flash_attn = "2.1.0"
    elif torch.version.hip:
        # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
        minimum_flash_attn = "2.0.4"
    else:
        return False

    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(minimum_flash_attn)
def is_flash_attn_greater_or_equal_2_10():
    """Return whether `flash_attn` is installed with version >= 2.1.0."""
    if _is_package_available("flash_attn"):
        installed = version.parse(importlib.metadata.version("flash_attn"))
        return installed >= version.parse("2.1.0")
    return False
def is_flash_attn_available():
    """Deprecated alias of `is_flash_attn_2_available`; logs a deprecation warning on every call."""
    deprecation_message = (
        "Using `is_flash_attn_available` is deprecated and will be removed in v4.38. "
        "Please use `is_flash_attn_2_available` instead."
    )
    logger.warning(deprecation_message)
    return is_flash_attn_2_available()
def is_torchdistx_available() -> bool:
    """Return the `_torchdistx_available` flag computed at import time."""
    return _torchdistx_available
def is_faiss_available() -> bool:
    """Return the `_faiss_available` flag computed at import time."""
    return _faiss_available
def is_scipy_available() -> bool:
    """Return the `_scipy_available` flag computed at import time."""
    return _scipy_available
def is_sklearn_available() -> bool:
    """Return the `_sklearn_available` flag computed at import time."""
    return _sklearn_available
def is_sentencepiece_available() -> bool:
    """Return the `_sentencepiece_available` flag computed at import time."""
    return _sentencepiece_available
def is_seqio_available() -> bool:
    """Return the `_is_seqio_available` flag computed at import time."""
    return _is_seqio_available
def is_protobuf_available():
    """Return whether both the `google` namespace and `google.protobuf` can be found."""
    # The parent is checked first so that `find_spec` is never asked for a submodule of a missing package.
    return (
        importlib.util.find_spec("google") is not None
        and importlib.util.find_spec("google.protobuf") is not None
    )
def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
    """
    Return whether accelerate was detected, optionally requiring at least `min_version`.

    Args:
        min_version (`str`, defaults to `ACCELERATE_MIN_VERSION`): Minimum accepted version. Pass `None` to skip
            the version comparison.
    """
    if min_version is None:
        return _accelerate_available
    return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
def is_fsdp_available(min_version: str = FSDP_MIN_VERSION) -> bool:
    """Return whether torch is available and its detected version is at least `min_version`."""
    return is_torch_available() and version.parse(_torch_version) >= version.parse(min_version)
def is_optimum_available() -> bool:
    """Return the `_optimum_available` flag computed at import time."""
    return _optimum_available
def is_auto_awq_available() -> bool:
    """Return the `_auto_awq_available` flag computed at import time."""
    return _auto_awq_available
def is_auto_gptq_available() -> bool:
    """Return the `_auto_gptq_available` flag computed at import time."""
    return _auto_gptq_available
def is_levenshtein_available() -> bool:
    """Return the `_levenshtein_available` flag computed at import time."""
    return _levenshtein_available
def is_optimum_neuron_available() -> bool:
    """Return whether optimum was detected and `optimum.neuron` passes `_is_package_available`."""
    # `_optimum_available` is tested first, so the dotted lookup only runs when the parent package exists.
    return _optimum_available and _is_package_available("optimum.neuron")
def is_safetensors_available() -> bool:
    """Return the `_safetensors_available` flag computed at import time."""
    return _safetensors_available
def is_tokenizers_available() -> bool:
    """Return the `_tokenizers_available` flag computed at import time."""
    return _tokenizers_available
def is_vision_available():
    """
    Return whether PIL is importable and backed by a `Pillow` or `Pillow-SIMD` distribution.

    The detected version is logged at debug level.
    """
    if importlib.util.find_spec("PIL") is None:
        return False

    # Try the distribution names in order; the first one with metadata wins.
    for distribution in ("Pillow", "Pillow-SIMD"):
        try:
            package_version = importlib.metadata.version(distribution)
        except importlib.metadata.PackageNotFoundError:
            continue
        break
    else:
        # A `PIL` module exists but neither distribution is installed.
        return False

    logger.debug(f"Detected PIL version {package_version}")
    return True
def is_pytesseract_available() -> bool:
    """Return the `_pytesseract_available` flag computed at import time."""
    return _pytesseract_available
def is_pytest_available() -> bool:
    """Return the `_pytest_available` flag computed at import time."""
    return _pytest_available
def is_spacy_available() -> bool:
    """Return the `_spacy_available` flag computed at import time."""
    return _spacy_available
def is_tensorflow_text_available() -> bool:
    """Return whether TensorFlow is available and `tensorflow_text` was detected."""
    return is_tf_available() and _tensorflow_text_available
def is_keras_nlp_available() -> bool:
    """Return whether `tensorflow_text` is available (see `is_tensorflow_text_available`) and `keras_nlp` was detected."""
    return is_tensorflow_text_available() and _keras_nlp_available
def is_in_notebook():
    """
    Return whether the code appears to run inside a Jupyter-compatible notebook kernel.

    Returns `False` when IPython has not been imported, when the shell has no `IPKernelApp` config (plain console),
    when running under VS Code (`VSCODE_PID` set), or on a Databricks runtime older than 11.0.
    """
    try:
        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
        get_ipython = sys.modules["IPython"].get_ipython
        if "IPKernelApp" not in get_ipython().config:
            raise ImportError("console")
        if "VSCODE_PID" in os.environ:
            raise ImportError("vscode")
        databricks_runtime = os.environ.get("DATABRICKS_RUNTIME_VERSION")
        if databricks_runtime is not None:
            # Databricks Runtime 11.0 and above uses IPython kernel by default so it should be compatible with
            # Jupyter notebook
            # https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel
            try:
                # Compare the major version numerically: as plain strings, "9.1" sorts after "11.0".
                is_legacy_runtime = int(databricks_runtime.split(".")[0]) < 11
            except ValueError:
                # Non-numeric runtime identifiers keep the historical string comparison.
                is_legacy_runtime = databricks_runtime < "11.0"
            if is_legacy_runtime:
                raise ImportError("databricks")

        return importlib.util.find_spec("IPython") is not None
    except (AttributeError, ImportError, KeyError):
        return False
def is_pytorch_quantization_available() -> bool:
    """Return the `_pytorch_quantization_available` flag computed at import time."""
    return _pytorch_quantization_available
def is_tensorflow_probability_available() -> bool:
    """Return the `_tensorflow_probability_available` flag computed at import time."""
    return _tensorflow_probability_available
def is_pandas_available() -> bool:
    """Return the `_pandas_available` flag computed at import time."""
    return _pandas_available
def is_sagemaker_dp_enabled():
    """
    Return whether SageMaker distributed data parallelism is enabled.

    Reads the JSON in the `SM_FRAMEWORK_PARAMS` environment variable, requires a truthy
    `sagemaker_distributed_dataparallel_enabled` field, and finally requires the `smdistributed` module.
    """
    raw_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        framework_params = json.loads(raw_params)
    except json.JSONDecodeError:
        return False

    if not framework_params.get("sagemaker_distributed_dataparallel_enabled", False):
        return False

    # Lastly, check if the `smdistributed` module is present.
    return _smdistributed_available
def is_sagemaker_mp_enabled():
    """
    Return whether SageMaker model parallelism is enabled.

    Requires a `partitions` key in the JSON of `SM_HP_MP_PARAMETERS`, a truthy `sagemaker_mpi_enabled` field in the
    JSON of `SM_FRAMEWORK_PARAMS`, and finally the `smdistributed` module.
    """
    # Model-parallel options: the "partitions" field is required for model parallel.
    raw_smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
    try:
        smp_options = json.loads(raw_smp_options)
    except json.JSONDecodeError:
        return False
    if "partitions" not in smp_options:
        return False

    # Framework parameters: MPI must be enabled ("sagemaker_mpi_enabled").
    raw_mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        mpi_options = json.loads(raw_mpi_options)
    except json.JSONDecodeError:
        return False
    if not mpi_options.get("sagemaker_mpi_enabled", False):
        return False

    # Lastly, check if the `smdistributed` module is present.
    return _smdistributed_available
def is_training_run_on_sagemaker():
    """Return whether the `SAGEMAKER_JOB_NAME` environment variable is set."""
    job_name = os.environ.get("SAGEMAKER_JOB_NAME")
    return job_name is not None
def is_soundfile_availble() -> bool:
    """Return the `_soundfile_available` flag computed at import time."""
    # NOTE(review): the function name is misspelled ("availble"); it is left as-is because callers may import it.
    return _soundfile_available
def is_timm_available() -> bool:
    """Return the `_timm_available` flag computed at import time."""
    return _timm_available
def is_natten_available() -> bool:
    """Return the `_natten_available` flag computed at import time."""
    return _natten_available
def is_nltk_available() -> bool:
    """Return the `_nltk_available` flag computed at import time."""
    return _nltk_available
def is_torchaudio_available() -> bool:
    """Return the `_torchaudio_available` flag computed at import time."""
    return _torchaudio_available
def is_speech_available() -> bool:
    """Return whether speech features can be used; currently the same flag as `is_torchaudio_available`."""
    # For now this depends on torchaudio but the exact dependency might evolve in the future.
    return _torchaudio_available
def is_phonemizer_available() -> bool:
    """Return the `_phonemizer_available` flag computed at import time."""
    return _phonemizer_available
def torch_only_method(fn):
    """
    Decorator that makes `fn` raise an `ImportError` when PyTorch is not available.

    Args:
        fn: The callable to guard.

    Returns:
        A wrapper that forwards all arguments to `fn` when `_torch_available` is true. The wrapper keeps the
        name, docstring and other metadata of `fn`.

    Raises:
        ImportError: When the wrapper is called while PyTorch is unavailable or disabled.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # The flag is checked at call time, not at decoration time.
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        return fn(*args, **kwargs)

    return wrapper
def is_ccl_available() -> bool:
    """Return the `_is_ccl_available` flag computed at import time."""
    return _is_ccl_available
def is_decord_available() -> bool:
    """Return the `_decord_available` flag computed at import time."""
    return _decord_available
def is_sudachi_available() -> bool:
    """Return the `_sudachipy_available` flag computed at import time."""
    return _sudachipy_available
def is_jumanpp_available():
    """Return whether the `rhoknp` module can be found and a `jumanpp` executable is on the PATH."""
    if importlib.util.find_spec("rhoknp") is None:
        return False
    return shutil.which("jumanpp") is not None
def is_cython_available() -> bool:
    """Return whether the `pyximport` module can be found."""
    return importlib.util.find_spec("pyximport") is not None
def is_jieba_available() -> bool:
    """Return the `_jieba_available` flag computed at import time."""
    return _jieba_available
def is_jinja_available():
    """Return the module-level `_jinja_available` flag."""
    flag = _jinja_available
    return flag
# Import-error message templates. `requires_backends` calls `.format(name)` on them, so every `{0}` is replaced with
# the name of the object that needs the missing backend.
# docstyle-ignore | |
CV2_IMPORT_ERROR = """ | |
{0} requires the OpenCV library but it was not found in your environment. You can install it with: | |
``` | |
pip install opencv-python | |
``` | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
DATASETS_IMPORT_ERROR = """ | |
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with: | |
``` | |
pip install datasets | |
``` | |
In a notebook or a colab, you can install it by executing a cell with | |
``` | |
!pip install datasets | |
``` | |
then restarting your kernel. | |
Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current | |
working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or | |
that python file if that's the case. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
TOKENIZERS_IMPORT_ERROR = """ | |
{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with: | |
``` | |
pip install tokenizers | |
``` | |
In a notebook or a colab, you can install it by executing a cell with | |
``` | |
!pip install tokenizers | |
``` | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
SENTENCEPIECE_IMPORT_ERROR = """ | |
{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the | |
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones | |
that match your environment. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
PROTOBUF_IMPORT_ERROR = """ | |
{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the | |
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones | |
that match your environment. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
FAISS_IMPORT_ERROR = """ | |
{0} requires the faiss library but it was not found in your environment. Checkout the instructions on the | |
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones | |
that match your environment. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
PYTORCH_IMPORT_ERROR = """ | |
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the | |
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
TORCHVISION_IMPORT_ERROR = """ | |
{0} requires the Torchvision library but it was not found in your environment. Checkout the instructions on the | |
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# Used by `requires_backends` when "torch" (but not "tf") is requested, PyTorch is unavailable and TensorFlow is | |
# available. `{0}` appears twice: once as the object name and once inside the suggested "TF{0}" class name.
# docstyle-ignore | |
PYTORCH_IMPORT_ERROR_WITH_TF = """ | |
{0} requires the PyTorch library but it was not found in your environment. | |
However, we were able to find a TensorFlow installation. TensorFlow classes begin | |
with "TF", but are otherwise identically named to our PyTorch classes. This | |
means that the TF equivalent of the class you tried to import would be "TF{0}". | |
If you want to use TensorFlow, please use TF classes instead! | |
If you really do want to use PyTorch please go to | |
https://pytorch.org/get-started/locally/ and follow the instructions that | |
match your environment. | |
""" | |
# Used by `requires_backends` in the inverse case: "tf" (but not "torch") is requested, TensorFlow is unavailable
# and PyTorch is available.
# docstyle-ignore | |
TF_IMPORT_ERROR_WITH_PYTORCH = """ | |
{0} requires the TensorFlow library but it was not found in your environment. | |
However, we were able to find a PyTorch installation. PyTorch classes do not begin | |
with "TF", but are otherwise identically named to our TF classes. | |
If you want to use PyTorch, please use those classes instead! | |
If you really do want to use TensorFlow, please follow the instructions on the | |
installation page https://www.tensorflow.org/install that match your environment. | |
""" | |
# Import-error message templates. `requires_backends` calls `.format(name)` on them, so `{0}` is replaced with the
# name of the object that needs the missing backend.
# docstyle-ignore | |
BS4_IMPORT_ERROR = """ | |
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: | |
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
SKLEARN_IMPORT_ERROR = """ | |
{0} requires the scikit-learn library but it was not found in your environment. You can install it with: | |
``` | |
pip install -U scikit-learn | |
``` | |
In a notebook or a colab, you can install it by executing a cell with | |
``` | |
!pip install -U scikit-learn | |
``` | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
TENSORFLOW_IMPORT_ERROR = """ | |
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the | |
installation page: https://www.tensorflow.org/install and follow the ones that match your environment. | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
DETECTRON2_IMPORT_ERROR = """ | |
{0} requires the detectron2 library but it was not found in your environment. Checkout the instructions on the | |
installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones | |
that match your environment. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
FLAX_IMPORT_ERROR = """ | |
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the | |
installation page: https://github.com/google/flax and follow the ones that match your environment. | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
FTFY_IMPORT_ERROR = """ | |
{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the | |
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones | |
that match your environment. Please note that you may need to restart your runtime after installation. | |
""" | |
# NOTE(review): unlike its neighbours this template has no `# docstyle-ignore` marker, and its `pip install` command
# is split across two lines — confirm whether that is intentional.
LEVENSHTEIN_IMPORT_ERROR = """ | |
{0} requires the python-Levenshtein library but it was not found in your environment. You can install it with pip: `pip | |
install python-Levenshtein`. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
G2P_EN_IMPORT_ERROR = """ | |
{0} requires the g2p-en library but it was not found in your environment. You can install it with pip: | |
`pip install g2p-en`. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
PYTORCH_QUANTIZATION_IMPORT_ERROR = """ | |
{0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip: | |
`pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com` | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
TENSORFLOW_PROBABILITY_IMPORT_ERROR = """ | |
{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as | |
explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
TENSORFLOW_TEXT_IMPORT_ERROR = """ | |
{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as | |
explained here: https://www.tensorflow.org/text/guide/tf_text_intro. | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
PANDAS_IMPORT_ERROR = """ | |
{0} requires the pandas library but it was not found in your environment. You can install it with pip as | |
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html. | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
PHONEMIZER_IMPORT_ERROR = """ | |
{0} requires the phonemizer library but it was not found in your environment. You can install it with pip: | |
`pip install phonemizer`. Please note that you may need to restart your runtime after installation. | |
""" | |
# Import-error message templates. `requires_backends` calls `.format(name)` on them, so `{0}` is replaced with the
# name of the object that needs the missing backend.
# docstyle-ignore | |
SACREMOSES_IMPORT_ERROR = """ | |
{0} requires the sacremoses library but it was not found in your environment. You can install it with pip: | |
`pip install sacremoses`. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
SCIPY_IMPORT_ERROR = """ | |
{0} requires the scipy library but it was not found in your environment. You can install it with pip: | |
`pip install scipy`. Please note that you may need to restart your runtime after installation. | |
""" | |
# The "speech" backend is checked through `is_speech_available`, which returns the torchaudio flag; hence this
# message asks for torchaudio.
# docstyle-ignore | |
SPEECH_IMPORT_ERROR = """ | |
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip: | |
`pip install torchaudio`. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
TIMM_IMPORT_ERROR = """ | |
{0} requires the timm library but it was not found in your environment. You can install it with pip: | |
`pip install timm`. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
NATTEN_IMPORT_ERROR = """ | |
{0} requires the natten library but it was not found in your environment. You can install it by referring to: | |
shi-labs.com/natten . You can also install it with pip (may take longer to build): | |
`pip install natten`. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
NLTK_IMPORT_ERROR = """ | |
{0} requires the NLTK library but it was not found in your environment. You can install it by referring to: | |
https://www.nltk.org/install.html. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
VISION_IMPORT_ERROR = """ | |
{0} requires the PIL library but it was not found in your environment. You can install it with pip: | |
`pip install pillow`. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
PYTESSERACT_IMPORT_ERROR = """ | |
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip: | |
`pip install pytesseract`. Please note that you may need to restart your runtime after installation. | |
""" | |
# docstyle-ignore | |
PYCTCDECODE_IMPORT_ERROR = """ | |
{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip: | |
`pip install pyctcdecode`. Please note that you may need to restart your runtime after installation. | |
""" | |
# Import-error template for the "accelerate" backend. It is an f-string so the minimum version is interpolated when
# this module is imported, while the doubled braces leave a literal `{0}` placeholder for the later
# `msg.format(name)` call in `requires_backends`. As a plain string, that `.format(name)` call raised
# `KeyError: 'ACCELERATE_MIN_VERSION'` instead of producing the message.
# NOTE(review): relies on `ACCELERATE_MIN_VERSION` being defined earlier in this module — confirm.
# docstyle-ignore
ACCELERATE_IMPORT_ERROR = f"""
{{0}} requires the accelerate library >= {ACCELERATE_MIN_VERSION} but it was not found in your environment.
You can install or update it with pip: `pip install --upgrade accelerate`. Please note that you may need to restart your
runtime after installation.
"""
# Import-error template registered under the "oneccl_bind_pt" key of `BACKENDS_MAPPING`; `{0}` is replaced with the
# name of the requesting object by `requires_backends`.
# docstyle-ignore | |
CCL_IMPORT_ERROR = """ | |
{0} requires the torch ccl library but it was not found in your environment. You can install it with pip: | |
`pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable` | |
Please note that you may need to restart your runtime after installation. | |
""" | |
# Import-error template for the "essentia" backend; `{0}` is replaced with the name of the requesting object by
# `requires_backends`. Wording aligned with the other templates (the old text read "requires essentia library. But
# that was not found ... install them").
# docstyle-ignore
ESSENTIA_IMPORT_ERROR = """
{0} requires the essentia library but it was not found in your environment. You can install it with pip:
`pip install essentia==2.1b6.dev1034`
Please note that you may need to restart your runtime after installation.
"""
# Import-error template for the "librosa" backend; `{0}` is replaced with the name of the requesting object by
# `requires_backends`. Fixes the "requires thes librosa library. But that was ..." typo in the user-facing text.
# docstyle-ignore
LIBROSA_IMPORT_ERROR = """
{0} requires the librosa library but it was not found in your environment. You can install it with pip:
`pip install librosa`
Please note that you may need to restart your runtime after installation.
"""
# Import-error template for the "pretty_midi" backend; `{0}` is replaced with the name of the requesting object by
# `requires_backends`. Fixes the "requires thes pretty_midi library. But that was ..." typo in the user-facing text.
# docstyle-ignore
PRETTY_MIDI_IMPORT_ERROR = """
{0} requires the pretty_midi library but it was not found in your environment. You can install it with pip:
`pip install pretty_midi`
Please note that you may need to restart your runtime after installation.
"""
# Import-error message templates. `requires_backends` calls `.format(name)` on them, so `{0}` is replaced with the
# name of the object that needs the missing backend.
# NOTE(review): these five templates carry no `# docstyle-ignore` marker and their `pip install` commands are split
# across two lines — confirm whether that is intentional.
DECORD_IMPORT_ERROR = """ | |
{0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install | |
decord`. Please note that you may need to restart your runtime after installation. | |
""" | |
CYTHON_IMPORT_ERROR = """ | |
{0} requires the Cython library but it was not found in your environment. You can install it with pip: `pip install | |
Cython`. Please note that you may need to restart your runtime after installation. | |
""" | |
JIEBA_IMPORT_ERROR = """ | |
{0} requires the jieba library but it was not found in your environment. You can install it with pip: `pip install | |
jieba`. Please note that you may need to restart your runtime after installation. | |
""" | |
PEFT_IMPORT_ERROR = """ | |
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install | |
peft`. Please note that you may need to restart your runtime after installation. | |
""" | |
JINJA_IMPORT_ERROR = """ | |
{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install | |
jinja2`. Please note that you may need to restart your runtime after installation. | |
""" | |
# Backend key -> (availability check, import-error template), in declaration order. `requires_backends` looks up each
# requested backend here, calls the check, and formats the template when the check fails.
BACKENDS_MAPPING = OrderedDict(
    (backend, (availability_check, error_template))
    for backend, availability_check, error_template in (
        ("bs4", is_bs4_available, BS4_IMPORT_ERROR),
        ("cv2", is_cv2_available, CV2_IMPORT_ERROR),
        ("datasets", is_datasets_available, DATASETS_IMPORT_ERROR),
        ("detectron2", is_detectron2_available, DETECTRON2_IMPORT_ERROR),
        ("essentia", is_essentia_available, ESSENTIA_IMPORT_ERROR),
        ("faiss", is_faiss_available, FAISS_IMPORT_ERROR),
        ("flax", is_flax_available, FLAX_IMPORT_ERROR),
        ("ftfy", is_ftfy_available, FTFY_IMPORT_ERROR),
        ("g2p_en", is_g2p_en_available, G2P_EN_IMPORT_ERROR),
        ("pandas", is_pandas_available, PANDAS_IMPORT_ERROR),
        ("phonemizer", is_phonemizer_available, PHONEMIZER_IMPORT_ERROR),
        ("pretty_midi", is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR),
        ("levenshtein", is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR),
        ("librosa", is_librosa_available, LIBROSA_IMPORT_ERROR),
        ("protobuf", is_protobuf_available, PROTOBUF_IMPORT_ERROR),
        ("pyctcdecode", is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR),
        ("pytesseract", is_pytesseract_available, PYTESSERACT_IMPORT_ERROR),
        ("sacremoses", is_sacremoses_available, SACREMOSES_IMPORT_ERROR),
        ("pytorch_quantization", is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR),
        ("sentencepiece", is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR),
        ("sklearn", is_sklearn_available, SKLEARN_IMPORT_ERROR),
        ("speech", is_speech_available, SPEECH_IMPORT_ERROR),
        ("tensorflow_probability", is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR),
        ("tf", is_tf_available, TENSORFLOW_IMPORT_ERROR),
        ("tensorflow_text", is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR),
        ("timm", is_timm_available, TIMM_IMPORT_ERROR),
        ("natten", is_natten_available, NATTEN_IMPORT_ERROR),
        ("nltk", is_nltk_available, NLTK_IMPORT_ERROR),
        ("tokenizers", is_tokenizers_available, TOKENIZERS_IMPORT_ERROR),
        ("torch", is_torch_available, PYTORCH_IMPORT_ERROR),
        ("torchvision", is_torchvision_available, TORCHVISION_IMPORT_ERROR),
        ("vision", is_vision_available, VISION_IMPORT_ERROR),
        ("scipy", is_scipy_available, SCIPY_IMPORT_ERROR),
        ("accelerate", is_accelerate_available, ACCELERATE_IMPORT_ERROR),
        ("oneccl_bind_pt", is_ccl_available, CCL_IMPORT_ERROR),
        ("decord", is_decord_available, DECORD_IMPORT_ERROR),
        ("cython", is_cython_available, CYTHON_IMPORT_ERROR),
        ("jieba", is_jieba_available, JIEBA_IMPORT_ERROR),
        ("peft", is_peft_available, PEFT_IMPORT_ERROR),
        ("jinja", is_jinja_available, JINJA_IMPORT_ERROR),
    )
)
def requires_backends(obj, backends):
    """
    Raise an `ImportError` if any backend needed by `obj` is reported as unavailable.

    Args:
        obj: The object that needs the backends. Its `__name__` (or, failing that, its class name) is inserted into
            the error message.
        backends (`str`, `list` or `tuple`): One backend key, or several, as registered in `BACKENDS_MAPPING`.

    Raises:
        `ImportError`: With a dedicated hint when only one of PyTorch/TensorFlow is requested while only the other
            one is available; otherwise with the concatenated messages of every unavailable backend.
        `KeyError`: If a requested backend is not a key of `BACKENDS_MAPPING`.
    """
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    try:
        name = obj.__name__
    except AttributeError:
        name = obj.__class__.__name__

    wants_torch = "torch" in backends
    wants_tf = "tf" in backends

    # Raise an error for users who might not realize that classes without "TF" are torch-only
    if wants_torch and not wants_tf and not is_torch_available() and is_tf_available():
        raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))

    # Raise the inverse error for PyTorch users trying to load TF classes
    if wants_tf and not wants_torch and is_torch_available() and not is_tf_available():
        raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))

    # Look up, check and format one backend at a time, in the order given by the caller.
    missing = []
    for backend in backends:
        is_available, template = BACKENDS_MAPPING[backend]
        if not is_available():
            missing.append(template.format(name))
    if missing:
        raise ImportError("".join(missing))
class DummyObject(type):
    """
    Metaclass for placeholder ("dummy") classes. Looking up an attribute on such a class calls `requires_backends`
    with the class's `_backends`, which raises an `ImportError` when a backend is missing. Names starting with an
    underscore, except `_from_config`, bypass that check and are resolved normally.
    """

    def __getattribute__(cls, key):
        needs_backend_check = key == "_from_config" or not key.startswith("_")
        if not needs_backend_check:
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)
def torch_required(func):
    """
    Deprecated decorator that only lets `func` run when `is_torch_available()` is true.

    A `FutureWarning` is emitted when the decorator is applied, not when the wrapped function is called.

    Args:
        func (`Callable`): The function or method to guard.

    Returns:
        `Callable`: A wrapper that forwards all arguments to `func`. `functools.wraps` keeps `func`'s `__name__`,
        `__doc__` and `__wrapped__` visible on the wrapper.

    Raises:
        `ImportError`: When the wrapper is called and `is_torch_available()` is false.
    """
    warnings.warn(
        "The method `torch_required` is deprecated and will be removed in v4.36. Use `requires_backends` instead.",
        FutureWarning,
    )

    # Chose a different decorator name than in tests so it's clear they are not the same.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if is_torch_available():
            return func(*args, **kwargs)
        else:
            raise ImportError(f"Method `{func.__name__}` requires PyTorch.")

    return wrapper
def tf_required(func):
    """
    Deprecated decorator that only lets `func` run when `is_tf_available()` is true.

    A `FutureWarning` is emitted when the decorator is applied, not when the wrapped function is called.

    Args:
        func (`Callable`): The function or method to guard.

    Returns:
        `Callable`: A wrapper that forwards all arguments to `func`. `functools.wraps` keeps `func`'s `__name__`,
        `__doc__` and `__wrapped__` visible on the wrapper.

    Raises:
        `ImportError`: When the wrapper is called and `is_tf_available()` is false.
    """
    warnings.warn(
        "The method `tf_required` is deprecated and will be removed in v4.36. Use `requires_backends` instead.",
        FutureWarning,
    )

    # Chose a different decorator name than in tests so it's clear they are not the same.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if is_tf_available():
            return func(*args, **kwargs)
        else:
            raise ImportError(f"Method `{func.__name__}` requires TF.")

    return wrapper
def is_torch_fx_proxy(x):
    """Return `True` if `x` is a `torch.fx.Proxy`; always `False` when `is_torch_fx_available()` is false."""
    if not is_torch_fx_available():
        return False

    # Imported lazily so that torch.fx is only touched once it is known to be available.
    import torch.fx

    return isinstance(x, torch.fx.Proxy)
class _LazyModule(ModuleType): | |
""" | |
Module class that surfaces all objects but only performs associated imports when the objects are requested. | |
""" | |
# Very heavily inspired by optuna.integration._IntegrationModule | |
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py | |
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None): | |
super().__init__(name) | |
self._modules = set(import_structure.keys()) | |
self._class_to_module = {} | |
for key, values in import_structure.items(): | |
for value in values: | |
self._class_to_module[value] = key | |
# Needed for autocompletion in an IDE | |
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) | |
self.__file__ = module_file | |
self.__spec__ = module_spec | |
self.__path__ = [os.path.dirname(module_file)] | |
self._objects = {} if extra_objects is None else extra_objects | |
self._name = name | |
self._import_structure = import_structure | |
# Needed for autocompletion in an IDE | |
def __dir__(self): | |
result = super().__dir__() | |
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether | |
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. | |
for attr in self.__all__: | |
if attr not in result: | |
result.append(attr) | |
return result | |
def __getattr__(self, name: str) -> Any: | |
if name in self._objects: | |
return self._objects[name] | |
if name in self._modules: | |
value = self._get_module(name) | |
elif name in self._class_to_module.keys(): | |
module = self._get_module(self._class_to_module[name]) | |
value = getattr(module, name) | |
else: | |
raise AttributeError(f"module {self.__name__} has no attribute {name}") | |
setattr(self, name, value) | |
return value | |
def _get_module(self, module_name: str): | |
try: | |
return importlib.import_module("." + module_name, self.__name__) | |
except Exception as e: | |
raise RuntimeError( | |
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its" | |
f" traceback):\n{e}" | |
) from e | |
def __reduce__(self): | |
return (self.__class__, (self._name, self.__file__, self._import_structure)) | |
class OptionalDependencyNotAvailable(BaseException):
    """
    Error class used internally to signal that an optional dependency was not found.

    It derives from `BaseException`, so an `except Exception` clause does not catch it.
    """
def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
    """Load the `transformers` package directly from a source directory.

    Args:
        path (`str`): Directory containing the sources; also used as the submodule search location.
        file (`str`, *optional*, defaults to `"__init__.py"`): File inside `path` that gets executed.

    Returns:
        `ModuleType`: The object registered as `sys.modules["transformers"]` once `file` has been executed.
    """
    module_name = "transformers"
    spec = importlib.util.spec_from_file_location(
        module_name,
        os.path.join(path, file),
        submodule_search_locations=[path],
    )
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    # Return the `sys.modules` entry rather than `loaded`: the executed file may have registered a different object
    # under this name.
    return sys.modules[module_name]