|
from pathlib import Path |
|
|
|
import pytest |
|
|
|
from invokeai.backend.model_manager.search import ModelSearch |
|
|
|
|
|
@pytest.fixture
def model_search(tmp_path: Path) -> tuple[ModelSearch, Path]:
    """Provide a freshly constructed ModelSearch plus pytest's temporary directory.

    Returns:
        A ``(search, tmp_path)`` pair; tests unpack it, attach callbacks to
        ``search`` and populate ``tmp_path`` with files to be searched.
    """
    return ModelSearch(), tmp_path
|
|
|
|
|
def test_model_search_on_search_started(model_search: tuple[ModelSearch, Path]):
    """The on_search_started callback receives the directory handed to search()."""
    search, tmp_path = model_search
    started_path: Path | None = None

    def record_start(path: Path) -> None:
        nonlocal started_path
        started_path = path

    search.on_search_started = record_start
    search.search(tmp_path)

    assert started_path == tmp_path
|
|
|
|
|
def test_model_search_on_completed(model_search: tuple[ModelSearch, Path]):
    """on_search_completed receives the same set of models that search() returns."""
    search, tmp_path = model_search

    # A single empty .ckpt file is enough to be picked up as a model.
    model_file = tmp_path / "file1.ckpt"
    model_file.write_text("")

    completed_with: set[Path] | None = None

    def record_completion(models: set[Path]) -> None:
        nonlocal completed_with
        completed_with = models

    search.on_search_completed = record_completion
    expected = {model_file}

    found = search.search(tmp_path)

    assert found == expected
    assert completed_with == expected
|
|
|
|
|
def test_model_search_handles_files(model_search: tuple[ModelSearch, Path]):
    """Model files are found at any depth and non-model files are ignored.

    The ``.ckpt`` files sit at the root, one level down and two levels down.
    All of them must be reported, both through ``on_model_found`` and in the
    set returned by ``search()``. The ``.txt`` file must not be reported.
    """
    search, tmp_path = model_search
    on_model_found_called_with: set[Path] = set()

    file1 = tmp_path / "file1.ckpt"
    file2 = tmp_path / "file2.ckpt"
    file3 = tmp_path / "subfolder" / "file3.ckpt"
    file4 = tmp_path / "subfolder" / "subfolder" / "file4.ckpt"
    file5 = tmp_path / "not_a_model_file.txt"

    # parents=True also creates tmp_path / "subfolder", which file3 needs.
    file4.parent.mkdir(parents=True)
    for path in (file1, file2, file3, file4, file5):
        path.touch()  # empty files suffice; only names/extensions matter here

    def on_model_found_callback(path: Path) -> bool:
        on_model_found_called_with.add(path)
        return True  # accept every candidate so nothing is filtered out

    search.on_model_found = on_model_found_callback

    expected = {file1, file2, file3, file4}
    found = search.search(tmp_path)

    assert on_model_found_called_with == expected
    assert found == expected
    # Every candidate was accepted, so both counters must equal the expected count.
    assert search.stats.models_found == len(expected)
    assert search.stats.models_filtered == len(expected)
|
|
|
|
|
def test_model_search_filters_by_on_model_found(model_search: tuple[ModelSearch, Path]):
    """Candidates rejected by on_model_found are excluded from the results.

    The callback records every path it is offered, not just the accepted ones.
    The test can therefore verify two things: the rejected file really was
    presented to the callback, and it was then dropped from the returned set.
    """
    search, tmp_path = model_search
    on_model_found_called_with: set[Path] = set()

    file1 = tmp_path / "file1.ckpt"
    file2 = tmp_path / "file2.ckpt"
    for path in (file1, file2):
        path.touch()

    def on_model_found_callback(path: Path) -> bool:
        # Record before deciding, so rejected candidates are visible too.
        on_model_found_called_with.add(path)
        return path != file2

    search.on_model_found = on_model_found_callback

    expected = {file1}
    found = search.search(tmp_path)

    # Both files were offered to the callback...
    assert on_model_found_called_with == {file1, file2}
    # ...but only the accepted one is returned and counted in models_filtered.
    assert found == expected
    assert search.stats.models_filtered == len(expected)
    assert search.stats.models_found == 2
|
|
|
|
|
def test_model_search_handles_diffusers_model_dirs(model_search: tuple[ModelSearch, Path]):
    """A directory containing ``model_index.json`` is reported as a single model.

    The test covers three directories:

    * a top-level diffusers directory;
    * a nested diffusers directory whose extra ``ignore_me.ckpt`` must not be
      reported separately;
    * a directory whose JSON file has a different name, which must not be
      reported at all.
    """
    search, tmp_path = model_search
    seen: set[Path] = set()

    top_level_dir = tmp_path / "diffusers_dir"
    top_level_dir.mkdir()
    (top_level_dir / "model_index.json").touch()

    nested_dir = tmp_path / "subfolder" / "nested_diffusers_dir"
    nested_dir.mkdir(parents=True)
    (nested_dir / "model_index.json").touch()
    (nested_dir / "ignore_me.ckpt").touch()

    decoy_dir = tmp_path / "not_a_diffusers_dir"
    decoy_dir.mkdir()
    (decoy_dir / "not_model_index.json").touch()

    def accept_all(path: Path) -> bool:
        seen.add(path)
        return True

    search.on_model_found = accept_all

    expected = {top_level_dir, nested_dir}
    found = search.search(tmp_path)

    assert found == expected
    assert seen == expected
    assert search.stats.models_found == 2
    assert search.stats.models_filtered == 2
|
|