# scripts/download.py
import os
from typing import Optional
from urllib.request import urlretrieve

# Reference implementation files for the original LLaMA model, hosted as public gists.
files = {
    "original_model.py": "https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/7dd20f51c2a1ff2886387f0e25c1750a485a08e1/llama_model.py",
    "original_adapter.py": "https://gist.githubusercontent.com/awaelchli/546f33fcdb84cc9f1b661ca1ca18418d/raw/e81d8f35fb1fec53af1099349b0c455fc8c9fb01/original_adapter.py",
}


def download_original(wd: str) -> None:
    """Download the original implementation files into `wd`, skipping any that already exist."""
    for file, url in files.items():
        filepath = os.path.join(wd, file)
        if not os.path.isfile(filepath):
            print(f"Downloading original implementation to {filepath!r}")
            urlretrieve(url=url, filename=filepath)
            print("Done")
        else:
            print("Original implementation found. Skipping download.")


def download_from_hub(repo_id: Optional[str] = None, local_dir: str = "checkpoints/hf-llama/7B") -> None:
    """Download a full model checkpoint from the Hugging Face Hub into `local_dir`."""
    if repo_id is None:
        raise ValueError("Please pass `--repo_id=...`. You can try googling 'huggingface hub llama' for options.")
    # Imported here so huggingface_hub is only required for this code path.
    from huggingface_hub import snapshot_download
    snapshot_download(repo_id, local_dir=local_dir)


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(download_from_hub)
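
# Example invocation (sketch only; the repo id below is a placeholder, not a specific
# recommended checkpoint):
#
#   python scripts/download.py --repo_id <hf-org>/<llama-repo> --local_dir checkpoints/hf-llama/7B
#
# jsonargparse's CLI() exposes the parameters of `download_from_hub` as command-line flags,
# so `--repo_id` and `--local_dir` map directly to the function arguments above.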