""" |
|
Routines for loading DeepSpeech model. |
|
""" |
|
|
|
__all__ = ['get_deepspeech_model_file'] |
|
|
|
import os |
|
import zipfile |
|
import logging |
|
import hashlib |
|
|
|
|
|
deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features' |
|
|
|
|
|
def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
    """
    Return the location of the pretrained model on the local file system. This function will download the model
    from the online model zoo when it cannot be found locally or its hash does not match. The root directory will
    be created if it doesn't exist.

    Parameters
    ----------
    local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
        Location for keeping the model parameters.

    Returns
    -------
    str
        Path to the requested pretrained model file.
    """
    sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
    file_name = "deepspeech-0_1_0-b90017e8.pb"
    local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
    file_path = os.path.join(local_model_store_dir_path, file_name)
    if os.path.exists(file_path):
        if _check_sha1(file_path, sha1_hash):
            return file_path
        else:
            logging.warning("Mismatch in the content of model file detected. Downloading again.")
    else:
        logging.info("Model file not found. Downloading to {}.".format(file_path))

    if not os.path.exists(local_model_store_dir_path):
        os.makedirs(local_model_store_dir_path)

    zip_file_path = file_path + ".zip"
    _download(
        url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
            repo_url=deepspeech_features_repo_url,
            repo_release_tag="v0.0.1",
            file_name=file_name),
        path=zip_file_path,
        overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(local_model_store_dir_path)
    os.remove(zip_file_path)

    if _check_sha1(file_path, sha1_hash):
        return file_path
    else:
        raise ValueError("Downloaded file has different hash. Please try again.")
def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
    """
    Download a given URL.

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store the downloaded file. By default, the file is stored in the current directory
        under the same name as in the URL.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. An existing file is ignored when the hash is specified but
        doesn't match.
    retries : int, default 5
        The number of times to attempt the download in case of failure or non-200 return codes.
    verify_ssl : bool, default True
        Verify SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.
    """
    import warnings
    try:
        import requests
    except ImportError:
        # Defer the failure until `requests` is actually used.
        class requests_failed_to_import(object):
            pass
        requests = requests_failed_to_import

    if path is None:
        fname = url.split("/")[-1]
        # Empty filenames are invalid.
        assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            "Unverified HTTPS request is being made (verify_ssl=False). "
            "Adding certificate verification is strongly advised.")

    if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        while retries + 1 > 0:
            try:
                print("Downloading {} from {}...".format(fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                if r.status_code != 200:
                    raise RuntimeError("Failed downloading url {}".format(url))
                with open(fname, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:  # filter out keep-alive chunks
                            f.write(chunk)
                if sha1_hash and not _check_sha1(fname, sha1_hash):
                    raise UserWarning("File {} is downloaded but the content hash does not match."
                                      " The repo may be outdated or download may be incomplete. "
                                      "If the `repo_url` is overridden, consider switching to "
                                      "the default repo.".format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e
                else:
                    print("download failed, retrying, {} attempt{} left"
                          .format(retries, "s" if retries > 1 else ""))

    return fname
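

# A hedged sketch of calling `_download` directly; the URL, path, and hash shown here are
# placeholders, not real release assets:
#
#     _download(
#         url="https://example.com/files/model.zip",
#         path="/tmp/model.zip",
#         sha1_hash=None,   # pass an expected sha1 hex digest to skip re-downloading a valid file
#         retries=3,
#         verify_ssl=True)

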
def _check_sha1(filename, sha1_hash):
    """
    Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, "rb") as f:
        while True:
            # Read the file in 1 MiB chunks to keep memory usage bounded.
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash
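

# A small optional smoke test: a sketch of how this module might be exercised from the
# command line. It assumes network access to the GitHub release referenced above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    model_path = get_deepspeech_model_file()
    print("Model file: {}".format(model_path))
    print("SHA1 matches: {}".format(
        _check_sha1(model_path, "b90017e816572ddce84f5843f1fa21e6a377975e")))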