|
|
|
|
|
import os |
|
import shutil |
|
from pathlib import Path |
|
|
|
import pytest |
|
from requests.sessions import Session |
|
from requests_testadapter import TestAdapter, TestSession |
|
|
|
from invokeai.app.services.config import InvokeAIAppConfig |
|
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase |
|
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase |
|
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase |
|
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase |
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL |
|
from invokeai.backend.model_manager.config import ( |
|
BaseModelType, |
|
LoRADiffusersConfig, |
|
MainCheckpointConfig, |
|
MainDiffusersConfig, |
|
ModelFormat, |
|
ModelSourceType, |
|
ModelType, |
|
ModelVariantType, |
|
VAEDiffusersConfig, |
|
) |
|
from invokeai.backend.model_manager.load import ModelCache |
|
from invokeai.backend.util.logging import InvokeAILogger |
|
from tests.backend.model_manager.model_metadata.metadata_examples import ( |
|
HFTestLoraMetadata, |
|
RepoCivitaiModelMetadata1, |
|
RepoCivitaiVersionMetadata1, |
|
RepoHFMetadata1, |
|
RepoHFMetadata1_nofp16, |
|
RepoHFModelJson1, |
|
) |
|
from tests.fixtures.sqlite_database import create_mock_sqlite_database |
|
from tests.test_nodes import TestEventService |
|
|
|
|
|
|
|
@pytest.fixture
def mm2_root_dir(tmp_path_factory) -> Path:
    """Copy the ``data/invokeai_root`` template into a fresh temporary directory.

    The template is looked up relative to the directory containing this file.

    Returns:
        Path of the copied ``invokeai_root`` directory.
    """
    template_dir = Path(__file__).resolve().parent / "data" / "invokeai_root"
    scratch_root: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
    shutil.copytree(template_dir, scratch_root)
    return scratch_root
|
|
|
|
|
@pytest.fixture
def mm2_model_files(tmp_path_factory) -> Path:
    """Copy the ``data/test_files`` template into a fresh temporary directory.

    The template is looked up relative to the directory containing this file.

    Returns:
        Path of the copied ``test_files`` directory.
    """
    template_dir = Path(__file__).resolve().parent / "data" / "test_files"
    scratch_files: Path = tmp_path_factory.mktemp("data") / "test_files"
    shutil.copytree(template_dir, scratch_files)
    return scratch_files
|
|
|
|
|
@pytest.fixture
def embedding_file(mm2_model_files: Path) -> Path:
    """Return the path of ``test_embedding.safetensors`` inside the copied test files."""
    return mm2_model_files.joinpath("test_embedding.safetensors")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def diffusers_dir(mm2_model_files: Path) -> Path:
    """Return the path of the ``test-diffusers-main`` directory inside the copied test files."""
    return mm2_model_files.joinpath("test-diffusers-main")
|
|
|
|
|
@pytest.fixture
def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
    """Build an app config rooted at the temporary InvokeAI root directory.

    The models directory is ``<root>/models`` and the log level is ``"info"``.
    """
    models_path = mm2_root_dir / "models"
    config = InvokeAIAppConfig(models_dir=models_path, log_level="info")
    # The root is assigned through the private attribute after construction.
    config._root = mm2_root_dir
    return config
|
|
|
|
|
@pytest.fixture
def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
    """Yield a started download queue that uses the mock session.

    The queue is started before it is handed to the test and stopped once the
    fixture is resumed after the ``yield``.
    """
    queue_service = DownloadQueueService(requests_session=mm2_session)
    queue_service.start()
    yield queue_service
    # Teardown: runs when the fixture is resumed after the test.
    queue_service.stop()
|
|
|
|
|
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
    """Build a model load service backed by a ``ModelCache``.

    The cache limits are taken from the ``ram`` and ``vram`` settings of the
    supplied app config.
    """
    cache_logger = InvokeAILogger.get_logger()
    model_cache = ModelCache(
        logger=cache_logger,
        max_cache_size=mm2_app_config.ram,
        max_vram_cache_size=mm2_app_config.vram,
    )
    load_service = ModelLoadService(app_config=mm2_app_config, ram_cache=model_cache)
    return load_service
|
|
|
|
|
@pytest.fixture
def mm2_installer(
    mm2_app_config: InvokeAIAppConfig,
    mm2_download_queue: DownloadQueueServiceBase,
    mm2_session: Session,
) -> ModelInstallServiceBase:
    """Yield a started model install service wired to mock collaborators.

    The service is given a record store on a mock SQLite database, the download
    queue fixture, a ``TestEventService`` event bus and the mock session. It is
    started before being handed to the test and stopped once the fixture is
    resumed after the ``yield``.
    """
    log = InvokeAILogger.get_logger()
    database = create_mock_sqlite_database(mm2_app_config, log)
    event_bus = TestEventService()
    record_store = ModelRecordServiceSQL(database, log)

    install_service = ModelInstallService(
        app_config=mm2_app_config,
        record_store=record_store,
        download_queue=mm2_download_queue,
        event_bus=event_bus,
        session=mm2_session,
    )
    install_service.start()
    yield install_service
    # Teardown: runs when the fixture is resumed after the test.
    install_service.stop()
|
|
|
|
|
@pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
    """Return a SQL model record store pre-populated with five sample configs.

    The store sits on a mock SQLite database created from the app config. The
    samples (keys ``test_config_1`` .. ``test_config_5``) are one VAE, one
    checkpoint main model, one diffusers main model and two LoRAs; they are
    added to the store in key order.
    """
    log = InvokeAILogger.get_logger(config=mm2_app_config)
    database = create_mock_sqlite_database(mm2_app_config, log)
    record_store = ModelRecordServiceSQL(database, log)

    # All samples are constructed up front, then inserted in list order.
    sample_configs = [
        VAEDiffusersConfig(
            key="test_config_1",
            path="/tmp/foo1",
            format=ModelFormat.Diffusers,
            name="test2",
            base=BaseModelType.StableDiffusion2,
            type=ModelType.VAE,
            hash="111222333444",
            source="stabilityai/sdxl-vae",
            source_type=ModelSourceType.HFRepoID,
        ),
        MainCheckpointConfig(
            key="test_config_2",
            path="/tmp/foo2.ckpt",
            name="model1",
            format=ModelFormat.Checkpoint,
            base=BaseModelType.StableDiffusion1,
            type=ModelType.Main,
            config_path="/tmp/foo.yaml",
            variant=ModelVariantType.Normal,
            hash="111222333444",
            source="https://civitai.com/models/206883/split",
            source_type=ModelSourceType.Url,
        ),
        MainDiffusersConfig(
            key="test_config_3",
            path="/tmp/foo3",
            format=ModelFormat.Diffusers,
            name="test3",
            base=BaseModelType.StableDiffusionXL,
            type=ModelType.Main,
            hash="111222333444",
            source="author3/model3",
            description="This is test 3",
            source_type=ModelSourceType.HFRepoID,
        ),
        LoRADiffusersConfig(
            key="test_config_4",
            path="/tmp/foo4",
            format=ModelFormat.Diffusers,
            name="test4",
            base=BaseModelType.StableDiffusionXL,
            type=ModelType.LoRA,
            hash="111222333444",
            source="author4/model4",
            source_type=ModelSourceType.HFRepoID,
        ),
        LoRADiffusersConfig(
            key="test_config_5",
            path="/tmp/foo5",
            format=ModelFormat.Diffusers,
            name="test5",
            base=BaseModelType.StableDiffusion1,
            type=ModelType.LoRA,
            hash="111222333444",
            source="author4/model5",
            source_type=ModelSourceType.HFRepoID,
        ),
    ]
    for sample in sample_configs:
        record_store.add_model(sample)
    return record_store
|
|
|
|
|
@pytest.fixture
def mm2_model_manager(
    mm2_record_store: ModelRecordServiceBase,
    mm2_installer: ModelInstallServiceBase,
    mm2_loader: ModelLoadServiceBase,
) -> ModelManagerServiceBase:
    """Assemble a model manager from the record store, installer and loader fixtures."""
    manager = ModelManagerService(
        store=mm2_record_store,
        install=mm2_installer,
        load=mm2_loader,
    )
    return manager
|
|
|
|
|
@pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
    """Define a series of mock URLs for testing download and installation.

    Each URL is mounted on a ``TestSession`` with a ``TestAdapter`` carrying a
    canned payload and headers.

    Args:
        embedding_file: File whose bytes are used as the payload of the two
            embedding download URLs.
        diffusers_dir: Directory that is walked recursively; every file in it
            is mounted under
            ``https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/``
            using its path relative to this directory.

    Returns:
        The session with all mock URLs mounted.
    """
    json_utf8 = "application/json; charset=utf-8"
    sess: Session = TestSession()

    def mount_payload(url, payload, content_type=None):
        """Mount ``payload`` at ``url`` with Content-Length and an optional Content-Type."""
        headers = {}
        if content_type is not None:
            headers["Content-Type"] = content_type
        headers["Content-Length"] = len(payload)
        sess.mount(url, TestAdapter(payload, headers=headers))

    sess.mount(
        "https://test.com/missing_model.safetensors",
        TestAdapter(b"missing", status=404),
    )

    # Repository metadata endpoints.
    mount_payload(
        "https://huggingface.co/api/models/stabilityai/sdxl-turbo",
        RepoHFMetadata1,
        json_utf8,
    )
    mount_payload(
        "https://huggingface.co/api/models/stabilityai/sdxl-turbo-nofp16",
        RepoHFMetadata1_nofp16,
        json_utf8,
    )
    mount_payload(
        "https://civitai.com/api/v1/model-versions/242807",
        RepoCivitaiVersionMetadata1,
    )
    mount_payload(
        "https://civitai.com/api/v1/models/215485",
        RepoCivitaiModelMetadata1,
    )
    mount_payload(
        "https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model_index.json",
        RepoHFModelJson1,
    )

    # The embedding file's bytes back both embedding download URLs below.
    embedding_bytes = embedding_file.read_bytes()
    mount_payload(
        "https://www.test.foo/download/test_embedding.safetensors",
        embedding_bytes,
        "application/octet-stream",
    )
    mount_payload(
        "https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
        HFTestLoraMetadata,
        json_utf8,
    )
    mount_payload(
        "https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
        embedding_bytes,
        json_utf8,
    )

    # Serve every file of the diffusers test model from the sdxl-turbo repo URL.
    for root, _, files in os.walk(diffusers_dir):
        for name in files:
            path = Path(root, name)
            url_base = path.relative_to(diffusers_dir).as_posix()
            mount_payload(
                f"https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/{url_base}",
                path.read_bytes(),
                json_utf8,
            )

    # Mock civitai downloads: a short marker followed by 32,000 zero bytes.
    for i in ["12345", "9999", "54321"]:
        content = b"I am a safetensors file " + i.encode("utf-8") + bytes(32_000)
        sess.mount(
            f"http://www.civitai.com/models/{i}",
            TestAdapter(
                content,
                headers={
                    "Content-Length": len(content),
                    "Content-Disposition": f'filename="mock{i}.safetensors"',
                },
            ),
        )

    # NOTE: ``content`` is intentionally the payload built in the last loop
    # iteration above (the "54321" one).
    sess.mount(
        "http://www.huggingface.co/foo.txt",
        TestAdapter(
            content,
            headers={
                "Content-Length": len(content),
                "Content-Disposition": 'filename="foo.safetensors"',
            },
        ),
    )

    # This response carries no Content-Length header.
    sess.mount(
        "http://www.civitai.com/models/missing",
        TestAdapter(
            b"Missing content length",
            headers={
                "Content-Disposition": 'filename="missing.txt"',
            },
        ),
    )

    # This URL answers with a 404 status.
    sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))

    return sess
|
|