# pyright:reportPrivateUsage=false

from pathlib import Path
from typing import Iterable

import pytest
from blake3 import blake3

from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, MODEL_FILE_EXTENSIONS, ModelHash

test_cases: list[tuple[HASHING_ALGORITHMS, str]] = [
    ("md5", "md5:a0cd925fc063f98dbf029eee315060c3"),
    ("sha1", "sha1:9e362940e5603fdc60566ea100a288ba2fe48b8c"),
    ("sha256", "sha256:6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
    (
        "sha512",
        "sha512:c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
    ),
    ("blake3_multi", "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
    ("blake3_single", "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
]


@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
def test_model_hash_hashes_file(tmp_path: Path, algorithm: HASHING_ALGORITHMS, expected_hash: str):
    file = tmp_path / "test"
    file.write_text("model data")
    hash_ = ModelHash(algorithm).hash(file)
    assert hash_ == expected_hash


@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3_multi", "blake3_single"])
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: HASHING_ALGORITHMS):
    model_hash = ModelHash(algorithm)
    files = [Path(tmp_path, f"{i}.bin") for i in range(5)]

    for f in files:
        f.write_text("data")

    hash_ = model_hash.hash(tmp_path)

    # Manually recompute the composite hash - component digests are always combined with BLAKE3,
    # regardless of the per-file hashing algorithm
    component_hashes: list[str] = []
    for f in sorted(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)):
        component_hashes.append(model_hash._hash_file(f))

    composite_hasher = blake3()
    for h in component_hashes:
        composite_hasher.update(h.encode("utf-8"))

    assert hash_ == ModelHash._get_prefix(algorithm) + composite_hasher.hexdigest()


@pytest.mark.parametrize(
    "algorithm,expected_prefix",
    [
        ("md5", "md5:"),
        ("sha1", "sha1:"),
        ("sha256", "sha256:"),
        ("sha512", "sha512:"),
        ("blake3_multi", "blake3:"),
        ("blake3_single", "blake3:"),
    ],
)
def test_model_hash_gets_prefix(algorithm: HASHING_ALGORITHMS, expected_prefix: str):
    assert ModelHash._get_prefix(algorithm) == expected_prefix


def test_model_hash_blake3_matches_blake3_single(tmp_path: Path):
    # Multi-threaded and single-threaded BLAKE3 should produce identical digests for the same file
    model_hash = ModelHash("blake3_multi")
    model_hash_simple = ModelHash("blake3_single")

    file = tmp_path / "test.bin"
    file.write_text("model data")

    assert model_hash.hash(file) == model_hash_simple.hash(file)


def test_model_hash_random_algorithm(tmp_path: Path):
    # The "random" algorithm returns a different value on every call, so two hashes of the same file never match
    model_hash = ModelHash("random")
    file = tmp_path / "test.bin"
    file.write_text("model data")

    assert model_hash.hash(file) != model_hash.hash(file)


def test_model_hash_raises_error_on_invalid_algorithm():
    with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
        ModelHash("invalid_algorithm")  # pyright: ignore [reportArgumentType]


def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
    return {str(p) for p in paths}


def test_model_hash_filters_out_non_model_files(tmp_path: Path):
    model_files = {Path(tmp_path, f"{i}{ext}") for i, ext in enumerate(MODEL_FILE_EXTENSIONS)}

    for i, f in enumerate(model_files):
        f.write_text(f"data{i}")

    assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
        model_files
    )

    # Add a file with an unrecognized extension - it should be excluded from the returned file paths
    file = tmp_path / "test.icecream"
    file.write_text("data")

    assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
        model_files
    )

    # Add a file with a recognized extension - it should be included in the returned file paths
    file = tmp_path / "test.bin"
    file.write_text("more data")
    model_files.add(file)

    assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
        model_files
    )


def test_model_hash_uses_custom_filter(tmp_path: Path):
    model_files = {Path(tmp_path, f"file{ext}") for ext in [".pickme", ".ignoreme"]}

    for i, f in enumerate(model_files):
        f.write_text(f"data{i}")

    def file_filter(file_path: str) -> bool:
        return file_path.endswith(".pickme")

    assert {p.name for p in ModelHash._get_file_paths(tmp_path, file_filter)} == {"file.pickme"}