"""
Routines for loading DeepSpeech model.
"""
__all__ = ['get_deepspeech_model_file']
import os
import zipfile
import logging
import hashlib
deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'
def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
    """
    Return location for the pretrained DeepSpeech model on the local file system.

    This function will download from the online model zoo when the model
    cannot be found or has a hash mismatch. The root directory will be
    created if it doesn't exist.

    Parameters
    ----------
    local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
        Location for keeping the model parameters.

    Returns
    -------
    str
        Path to the requested pretrained model file.

    Raises
    ------
    ValueError
        If the downloaded file's SHA-1 hash does not match the expected one.
    """
    sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
    file_name = "deepspeech-0_1_0-b90017e8.pb"
    local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
    file_path = os.path.join(local_model_store_dir_path, file_name)
    if os.path.exists(file_path):
        if _check_sha1(file_path, sha1_hash):
            return file_path
        logging.warning("Mismatch in the content of model file detected. Downloading again.")
    else:
        logging.info("Model file not found. Downloading to {}.".format(file_path))

    # exist_ok avoids a race with other processes creating the directory
    # between an exists() check and makedirs().
    os.makedirs(local_model_store_dir_path, exist_ok=True)

    # The release hosts a zip archive; download it next to the target file,
    # extract, then remove the archive.
    zip_file_path = file_path + ".zip"
    _download(
        url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
            repo_url=deepspeech_features_repo_url,
            repo_release_tag="v0.0.1",
            file_name=file_name),
        path=zip_file_path,
        overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(local_model_store_dir_path)
    os.remove(zip_file_path)

    if _check_sha1(file_path, sha1_hash):
        return file_path
    raise ValueError("Downloaded file has different hash. Please try again.")
def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
    """
    Download a given URL.

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt the download in case of failure or non 200 return codes.
    verify_ssl : bool, default True
        Verify SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    ImportError
        If the `requests` package is not installed.
    RuntimeError
        If the server responds with a non-200 status and retries are exhausted.
    """
    import warnings
    try:
        import requests
    except ImportError:
        # Fail loudly at call time instead of surfacing a confusing
        # AttributeError later when requests.get is accessed.
        raise ImportError(
            "The `requests` package is required to download files. "
            "Please install it, e.g. `pip install requests`.")

    if path is None:
        fname = url.split("/")[-1]
        # Empty filenames are invalid
        assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            "Unverified HTTPS request is being made (verify_ssl=False). "
            "Adding certificate verification is strongly advised.")

    # Download only when forced, when the file is absent, or when an expected
    # hash is given and the existing file does not match it.
    if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids a race with concurrent processes creating the directory.
        os.makedirs(dirname, exist_ok=True)
        while retries + 1 > 0:
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                print("Downloading {} from {}...".format(fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                try:
                    if r.status_code != 200:
                        raise RuntimeError("Failed downloading url {}".format(url))
                    with open(fname, "wb") as f:
                        for chunk in r.iter_content(chunk_size=1024):
                            if chunk:  # filter out keep-alive new chunks
                                f.write(chunk)
                finally:
                    # Always release the streamed connection, even on errors.
                    r.close()
                if sha1_hash and not _check_sha1(fname, sha1_hash):
                    raise UserWarning("File {} is downloaded but the content hash does not match."
                                      " The repo may be outdated or download may be incomplete. "
                                      "If the `repo_url` is overridden, consider switching to "
                                      "the default repo.".format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e
                else:
                    print("download failed, retrying, {} attempt{} left"
                          .format(retries, "s" if retries > 1 else ""))
    return fname
def _check_sha1(filename, sha1_hash): | |
""" | |
Check whether the sha1 hash of the file content matches the expected hash. | |
Parameters | |
---------- | |
filename : str | |
Path to the file. | |
sha1_hash : str | |
Expected sha1 hash in hexadecimal digits. | |
Returns | |
------- | |
bool | |
Whether the file content matches the expected hash. | |
""" | |
sha1 = hashlib.sha1() | |
with open(filename, "rb") as f: | |
while True: | |
data = f.read(1048576) | |
if not data: | |
break | |
sha1.update(data) | |
return sha1.hexdigest() == sha1_hash | |