File size: 9,773 Bytes
4a51346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import pytest
from typing import Generator, List, Callable, Dict, Union
from chromadb.types import Collection, Segment, SegmentScope
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.config import System, Settings
from chromadb.db.system import SysDB
from chromadb.db.base import NotFoundError, UniqueConstraintError
from pytest import FixtureRequest
import uuid


def sqlite() -> Generator[SysDB, None, None]:
    """Yield a started in-memory ``SqliteDB`` and stop it once resumed."""
    settings = Settings(sqlite_database=":memory:", allow_reset=True)
    database = SqliteDB(System(settings))
    database.start()
    yield database
    # Reached only when the consumer resumes the generator past the yield.
    database.stop()


def db_fixtures() -> List[Callable[[], Generator[SysDB, None, None]]]:
    """Return the generator functions that each produce a SysDB implementation."""
    fixtures: List[Callable[[], Generator[SysDB, None, None]]] = [sqlite]
    return fixtures


@pytest.fixture(scope="module", params=db_fixtures())
def sysdb(request: FixtureRequest) -> Generator[SysDB, None, None]:
    """Module-scoped SysDB fixture, parametrized over ``db_fixtures()``.

    ``request.param`` is one of the generator functions returned by
    ``db_fixtures()``. Delegating with ``yield from`` (rather than pulling a
    single value with ``next``) means that when pytest finalizes this fixture
    the inner generator is resumed too, so any teardown code it has after its
    own ``yield`` (e.g. stopping the database) actually runs.
    """
    yield from request.param()


# Three sample collections numbered 1-3. Each gets a fresh random id at import
# time; the list is in ascending name order.
sample_collections = [
    Collection(
        id=uuid.uuid4(),
        name=f"test_collection_{n}",
        topic=f"test_topic_{n}",
        metadata={"test_str": f"str{n}", "test_int": n, "test_float": test_float},
    )
    for n, test_float in ((1, 1.3), (2, 2.3), (3, 3.3))
]


def test_create_get_delete_collections(sysdb: SysDB) -> None:
    """Exercise collection create, lookup and delete on a freshly reset SysDB.

    Covers: listing everything, duplicate-create rejection, lookup by name,
    topic, id and id+topic (matching and non-matching), deletion, and
    duplicate-delete rejection.
    """
    sysdb.reset()

    for collection in sample_collections:
        sysdb.create_collection(collection)

    # Sort once by name; sample_collections is already listed in name order.
    results = sorted(sysdb.get_collections(), key=lambda c: c["name"])
    assert results == sample_collections

    # Duplicate create fails
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(sample_collections[0])

    # Find by name
    for collection in sample_collections:
        result = sysdb.get_collections(name=collection["name"])
        assert result == [collection]

    # Find by topic
    for collection in sample_collections:
        result = sysdb.get_collections(topic=collection["topic"])
        assert result == [collection]

    # Find by id
    for collection in sample_collections:
        result = sysdb.get_collections(id=collection["id"])
        assert result == [collection]

    # Find by id and topic (positive case)
    for collection in sample_collections:
        result = sysdb.get_collections(id=collection["id"], topic=collection["topic"])
        assert result == [collection]

    # Find by id and topic (negative case)
    for collection in sample_collections:
        result = sysdb.get_collections(id=collection["id"], topic="other_topic")
        assert result == []

    # Delete
    c1 = sample_collections[0]
    sysdb.delete_collection(c1["id"])

    results = sysdb.get_collections()
    assert c1 not in results
    assert len(results) == len(sample_collections) - 1
    assert sorted(results, key=lambda c: c["name"]) == sample_collections[1:]

    by_id_result = sysdb.get_collections(id=c1["id"])
    assert by_id_result == []

    # Duplicate delete throws an exception
    with pytest.raises(NotFoundError):
        sysdb.delete_collection(c1["id"])


def test_update_collections(sysdb: SysDB) -> None:
    """Check that name, topic and metadata updates show up in later reads.

    ``metadata`` is the very dict stored in ``coll``, so mutating it in place
    also changes the expected collection used in the assertions.
    """
    metadata: Dict[str, Union[str, int, float]] = dict(
        test_str="str1", test_int=1, test_float=1.3
    )
    coll = Collection(
        id=uuid.uuid4(),
        name="test_collection_1",
        topic="test_topic_1",
        metadata=metadata,
    )

    sysdb.reset()
    sysdb.create_collection(coll)
    coll_id = coll["id"]

    # Rename the collection, then look it up under the new name
    coll["name"] = "new_name"
    sysdb.update_collection(coll_id, name="new_name")
    assert sysdb.get_collections(name="new_name") == [coll]

    # Change the topic, then look it up under the new topic
    coll["topic"] = "new_topic"
    sysdb.update_collection(coll_id, topic="new_topic")
    assert sysdb.get_collections(topic="new_topic") == [coll]

    # Add a new metadata key
    metadata.update(test_str2="str2")
    sysdb.update_collection(coll_id, metadata={"test_str2": "str2"})
    assert sysdb.get_collections(id=coll_id) == [coll]

    # Overwrite an existing metadata key
    metadata.update(test_str="str3")
    sysdb.update_collection(coll_id, metadata={"test_str": "str3"})
    assert sysdb.get_collections(id=coll_id) == [coll]

    # Passing None as a value removes that metadata key
    metadata.pop("test_str")
    sysdb.update_collection(coll_id, metadata={"test_str": None})
    assert sysdb.get_collections(id=coll_id) == [coll]

    # Passing metadata=None removes all metadata
    coll["metadata"] = None
    sysdb.update_collection(coll_id, metadata=None)
    assert sysdb.get_collections(id=coll_id) == [coll]


# Three sample segments numbered 1-3, listed in ascending id order. Each row is
# (id, type, scope, topic, collection id, test_float); the third segment
# belongs to no collection and the first has no topic.
sample_segments = [
    Segment(
        id=uuid.UUID(segment_id),
        type=segment_type,
        scope=segment_scope,
        topic=segment_topic,
        collection=collection_id,
        metadata={"test_str": f"str{n}", "test_int": n, "test_float": test_float},
    )
    for n, (
        segment_id,
        segment_type,
        segment_scope,
        segment_topic,
        collection_id,
        test_float,
    ) in enumerate(
        (
            (
                "00000000-d7d7-413b-92e1-731098a6e492",
                "test_type_a",
                SegmentScope.VECTOR,
                None,
                sample_collections[0]["id"],
                1.3,
            ),
            (
                "11111111-d7d7-413b-92e1-731098a6e492",
                "test_type_b",
                SegmentScope.VECTOR,
                "test_topic_2",
                sample_collections[1]["id"],
                2.3,
            ),
            (
                "22222222-d7d7-413b-92e1-731098a6e492",
                "test_type_b",
                SegmentScope.METADATA,
                "test_topic_3",
                None,
                3.3,
            ),
        ),
        start=1,
    )
]


def test_create_get_delete_segments(sysdb: SysDB) -> None:
    """Exercise segment create, lookup and delete on a freshly reset SysDB.

    Every query that can return more than one segment is sorted by segment id
    before comparison, so the assertions do not depend on the order in which
    the database returns rows. ``sample_segments`` is listed in ascending id
    order, so its slices are already the expected sorted results.
    """

    def by_id(segment: Segment) -> uuid.UUID:
        return segment["id"]

    sysdb.reset()

    for collection in sample_collections:
        sysdb.create_collection(collection)

    for segment in sample_segments:
        sysdb.create_segment(segment)

    results = sorted(sysdb.get_segments(), key=by_id)
    assert results == sample_segments

    # Duplicate create fails
    with pytest.raises(UniqueConstraintError):
        sysdb.create_segment(sample_segments[0])

    # Find by id
    for segment in sample_segments:
        result = sysdb.get_segments(id=segment["id"])
        assert result == [segment]

    # Find by type
    result = sysdb.get_segments(type="test_type_a")
    assert result == sample_segments[:1]

    # Two segments share this type, so normalise the order before comparing.
    result = sysdb.get_segments(type="test_type_b")
    assert sorted(result, key=by_id) == sample_segments[1:]

    # Find by collection ID
    result = sysdb.get_segments(collection=sample_collections[0]["id"])
    assert result == sample_segments[:1]

    # Find by type and collection ID (positive case)
    result = sysdb.get_segments(
        type="test_type_a", collection=sample_collections[0]["id"]
    )
    assert result == sample_segments[:1]

    # Find by type and collection ID (negative case)
    result = sysdb.get_segments(
        type="test_type_b", collection=sample_collections[0]["id"]
    )
    assert result == []

    # Delete
    s1 = sample_segments[0]
    sysdb.delete_segment(s1["id"])

    results = sysdb.get_segments()
    assert s1 not in results
    assert len(results) == len(sample_segments) - 1
    # Both remaining segments have the same type, so sorting by type would not
    # fix their order; sort by id instead.
    assert sorted(results, key=by_id) == sample_segments[1:]

    # Duplicate delete throws an exception
    with pytest.raises(NotFoundError):
        sysdb.delete_segment(s1["id"])


def test_update_segment(sysdb: SysDB) -> None:
    """Check that topic, collection and metadata updates show up in later reads.

    ``metadata`` is the very dict stored in ``segment``, so mutating it in
    place also changes the expected segment used in the assertions.
    """
    metadata: Dict[str, Union[str, int, float]] = dict(
        test_str="str1", test_int=1, test_float=1.3
    )
    segment = Segment(
        id=uuid.uuid4(),
        type="test_type_a",
        scope=SegmentScope.VECTOR,
        topic="test_topic_a",
        collection=sample_collections[0]["id"],
        metadata=metadata,
    )

    sysdb.reset()
    for sample in sample_collections:
        sysdb.create_collection(sample)

    sysdb.create_segment(segment)
    segment_id = segment["id"]

    # Topic: first to a new value, then to None
    for new_topic in ("new_topic", None):
        segment["topic"] = new_topic
        sysdb.update_segment(segment_id, topic=new_topic)
        assert sysdb.get_segments(id=segment_id) == [segment]

    # Collection: first to another collection's id, then to None
    for new_collection in (sample_collections[1]["id"], None):
        segment["collection"] = new_collection
        sysdb.update_segment(segment_id, collection=new_collection)
        assert sysdb.get_segments(id=segment_id) == [segment]

    # Add a new metadata key
    metadata.update(test_str2="str2")
    sysdb.update_segment(segment_id, metadata={"test_str2": "str2"})
    assert sysdb.get_segments(id=segment_id) == [segment]

    # Overwrite an existing metadata key
    metadata.update(test_str="str3")
    sysdb.update_segment(segment_id, metadata={"test_str": "str3"})
    assert sysdb.get_segments(id=segment_id) == [segment]

    # Passing None as a value removes that metadata key
    metadata.pop("test_str")
    sysdb.update_segment(segment_id, metadata={"test_str": None})
    assert sysdb.get_segments(id=segment_id) == [segment]

    # Passing metadata=None removes all metadata
    segment["metadata"] = None
    sysdb.update_segment(segment_id, metadata=None)
    assert sysdb.get_segments(id=segment_id) == [segment]