File size: 6,832 Bytes
8a37e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# pyright: reportPrivateUsage=false
from contextlib import suppress

from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from tests.test_nodes import PromptTestInvocation


def test_invocation_cache_memory_max_cache_size():
    """A cache built with the default size stores nothing and records no hit/miss stats."""
    zero_sized = MemoryInvocationCache()
    assert zero_sized._max_cache_size == 0

    image_output = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    zero_sized.save(1, image_output)

    # Nothing was stored, so the lookup comes back empty...
    assert zero_sized.get(1) is None
    # ...and neither counter moves.
    assert zero_sized._hits == 0
    assert zero_sized._misses == 0  # TODO: when cache size is zero, should we consider it a miss?
    assert len(zero_sized._cache) == 0


def test_invocation_cache_memory_creates_deterministic_keys():
    """create_key gives equal keys for equal invocations and different keys for different ones."""
    key_foo_first, key_foo_second, key_bar = (
        MemoryInvocationCache.create_key(PromptTestInvocation(prompt=prompt)) for prompt in ("foo", "foo", "bar")
    )

    # Two invocations built from the same prompt must map to the same key...
    assert key_foo_first == key_foo_second
    # ...while a different prompt must map to a different key.
    assert key_foo_first != key_bar


def test_invocation_cache_memory_adds_invocation():
    """Outputs saved under distinct keys can each be read back unchanged."""
    foo_output, bar_output = (
        ImageOutput(image=ImageField(image_name=name), width=512, height=512) for name in ("foo", "bar")
    )
    cache = MemoryInvocationCache(max_cache_size=5)
    for key, output in ((1, foo_output), (2, bar_output)):
        cache.save(key, output)

    assert cache.get(1) == foo_output
    assert cache.get(2) == bar_output


def test_invocation_cache_memory_tracks_hits():
    """Lookups of stored keys count as hits; lookups of absent keys count as misses."""
    stored = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    cache = MemoryInvocationCache(max_cache_size=5)
    cache.save(1, stored)

    for _ in range(3):
        cache.get(1)  # hit
    for _ in range(2):
        cache.get(2)  # miss

    assert (cache._hits, cache._misses) == (3, 2)


def test_invocation_cache_memory_is_lru():
    """Saving past capacity evicts the oldest entry, and get() refreshes an entry's recency."""
    foo_out = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    bar_out = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512)
    baz_out = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512)
    cache = MemoryInvocationCache(max_cache_size=2)
    for key, out in enumerate((foo_out, bar_out, baz_out), start=1):
        cache.save(key, out)

    # Capacity is 2, so the third save must have pushed out key 1.
    assert cache.get(1) is None
    assert cache.get(2) == bar_out
    assert cache.get(3) == baz_out
    assert len(cache._cache) == 2
    assert list(cache._cache.keys()) == [2, 3]

    # Reading key 2 again moves it to the most-recently-used end.
    cache.get(2)
    assert list(cache._cache.keys()) == [3, 2]


def test_invocation_cache_memory_disables_and_enables():
    """While disabled the cache neither serves nor stores entries; re-enabling restores both."""
    first = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    second = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512)
    cache = MemoryInvocationCache(max_cache_size=2)
    cache.save(1, first)

    cache.disable()
    # Reads return nothing, writes are dropped, and no stats are recorded.
    assert cache.get(1) is None
    cache.save(2, second)
    assert cache.get(2) is None
    assert (len(cache._cache), cache._hits, cache._misses) == (1, 0, 0)

    cache.enable()
    # Once re-enabled, a save is stored and the following read is a hit.
    cache.save(2, second)
    assert cache.get(2) is second
    assert (len(cache._cache), cache._hits, cache._misses) == (2, 1, 0)


def test_invocation_cache_memory_deletes_by_match():
    """_delete_by_match(name) evicts only the entry saved with that image name."""
    cache = MemoryInvocationCache(max_cache_size=5)

    def delete_by_match(to_match: str) -> None:
        # The _delete_by_match method attempts to log but the logger is not set up in the test
        # environment, which surfaces as an AttributeError. The suppression is scoped to each call
        # rather than to the whole test: one such error would otherwise exit the `with` block and
        # silently skip every assertion after it, letting the test pass without checking anything.
        with suppress(AttributeError):
            cache._delete_by_match(to_match)

    output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512)
    output_3 = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512)
    cache.save(1, output_1)
    cache.save(2, output_2)
    cache.save(3, output_3)

    delete_by_match("bar")
    assert cache.get(1) == output_1
    assert cache.get(2) is None
    assert cache.get(3) == output_3
    assert len(cache._cache) == 2
    assert list(cache._cache.keys()) == [1, 3]

    delete_by_match("foo")
    assert cache.get(1) is None
    assert cache.get(2) is None
    assert cache.get(3) == output_3
    assert len(cache._cache) == 1
    assert list(cache._cache.keys()) == [3]

    delete_by_match("baz")
    assert cache.get(1) is None
    assert cache.get(2) is None
    assert cache.get(3) is None
    assert len(cache._cache) == 0
    assert list(cache._cache.keys()) == []

    # Deleting from an empty cache must be a harmless no-op.
    delete_by_match("foo")
    assert len(cache._cache) == 0


def test_invocation_cache_memory_clears():
    """clear() empties the cache and resets both the hit and miss counters."""
    cache = MemoryInvocationCache(max_cache_size=5)
    output_1 = ImageOutput(image=ImageField(image_name="foo"), width=512, height=512)
    output_2 = ImageOutput(image=ImageField(image_name="bar"), width=512, height=512)
    output_3 = ImageOutput(image=ImageField(image_name="baz"), width=512, height=512)
    cache.save(1, output_1)
    cache.save(2, output_2)
    cache.save(3, output_3)
    cache.get(1)
    cache.get(2)
    cache.get(3)
    cache.get("foo")  # miss
    cache.get("bar")  # miss

    # Confirm there is state to reset; otherwise the post-clear checks below would prove nothing.
    assert len(cache._cache) == 3
    assert cache._hits == 3
    assert cache._misses == 2

    cache.clear()
    assert len(cache._cache) == 0
    assert cache._hits == 0
    assert cache._misses == 0
    assert cache.get(1) is None
    assert cache.get(2) is None
    assert cache.get(3) is None


def test_invocation_cache_memory_status():
    """get_status() reports hits, misses, size, capacity and enabled state across the cache lifecycle."""
    cache = MemoryInvocationCache(max_cache_size=5)
    outputs = [
        ImageOutput(image=ImageField(image_name=name), width=512, height=512) for name in ("foo", "bar", "baz")
    ]
    for key, output in enumerate(outputs, start=1):
        cache.save(key, output)
    for key in (1, 2, 3):
        cache.get(key)  # hit
    for key in ("foo", "bar"):
        cache.get(key)  # miss

    status = cache.get_status()
    assert status.hits == 3
    assert status.misses == 2
    assert status.enabled
    assert status.size == 3
    assert status.max_size == 5

    # Toggling the enabled flag is reflected in the next status snapshot.
    cache.disable()
    assert not cache.get_status().enabled
    cache.enable()
    assert cache.get_status().enabled

    # clear() drops entries and counters but leaves the cache enabled at the same capacity.
    cache.clear()
    status = cache.get_status()
    assert status.size == 0
    assert status.hits == 0
    assert status.misses == 0
    assert status.enabled
    assert status.max_size == 5

    # cache should be disabled when max_cache_size is zero
    cache._max_cache_size = 0
    status = cache.get_status()
    assert not status.enabled
    assert status.size == 0
    assert status.hits == 0
    assert status.misses == 0
    assert status.max_size == 0