import urllib3

from huggingface_hub import snapshot_download


def change_default_timeout(new_timeout: int) -> None:
    """
    Changes the default timeout for downloading repositories from the Hugging Face Hub.
    Prevents the following errors:
    urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='huggingface.co', port=443):
    Read timed out. (read timeout=10)

    The value is stored as ``DEFAULT_TIMEOUT`` on the ``urllib3.util.timeout`` module,
    so the change is global and stays in effect after this function returns.

    :param new_timeout: the new timeout in seconds; must be greater than zero.
    :raises ValueError: if ``new_timeout`` is zero or negative.
    """
    # Fail fast: a non-positive timeout is never usable, and because the value is
    # stored globally the resulting failure would otherwise surface far from here.
    if new_timeout <= 0:
        raise ValueError(f"new_timeout must be greater than zero, got {new_timeout!r}")
    # NOTE(review): confirm that the installed urllib3 version actually reads this
    # module-level attribute (it may only consult ``Timeout.DEFAULT_TIMEOUT``).
    urllib3.util.timeout.DEFAULT_TIMEOUT = new_timeout


def download_pytorch_model(name: str, timeout: int = 86_400) -> None:
    """
    Downloads a pytorch model and the small files from the model's repository.
    Other model formats (tensorflow, tflite, safetensors, msgpack and ot) and
    markdown files are not downloaded.

    :param name: id of the model repository on the Hugging Face Hub,
        e.g. "facebook/opt-125m".
    :param timeout: timeout in seconds, applied both as the new default timeout and
        as the ``etag_timeout`` of the download. Defaults to one day (86 400 seconds).
    """
    # Raise the global default first so the download below does not hit the
    # short read timeout described in change_default_timeout.
    change_default_timeout(timeout)
    snapshot_download(
        repo_id=name,
        etag_timeout=timeout,
        resume_download=True,
        repo_type="model",
        library_name="pt",
        # h5, tflite, safetensors, msgpack and ot model files are not needed,
        # and neither are the markdown files.
        ignore_patterns=[
            "*.h5",
            "*.tflite",
            "*.safetensors",
            "*.msgpack",
            "*.ot",
            "*.md",
        ],
    )


if __name__ == "__main__":
    # Running this file as a script downloads one example model repository.
    example_repo_id = "facebook/opt-125m"
    download_pytorch_model(example_repo_id)