nsthorat committed on
Commit
80dadd4
1 Parent(s): dbe8c69
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. .dockerignore +10 -0
  2. .env +40 -0
  3. .env.demo +4 -0
  4. .gitattributes +0 -35
  5. .gitignore +5 -0
  6. Dockerfile +29 -0
  7. LICENSE +201 -0
  8. README.md +6 -9
  9. lilac/.gitignore +1 -0
  10. lilac/__init__.py +33 -0
  11. lilac/auth.py +87 -0
  12. lilac/batch_utils.py +92 -0
  13. lilac/cli.py +39 -0
  14. lilac/concepts/__init__.py +0 -0
  15. lilac/concepts/concept.py +330 -0
  16. lilac/concepts/db_concept.py +520 -0
  17. lilac/config.py +80 -0
  18. lilac/conftest.py +28 -0
  19. lilac/data/__init__.py +9 -0
  20. lilac/data/dataset.py +485 -0
  21. lilac/data/dataset_duckdb.py +1717 -0
  22. lilac/data/dataset_test_utils.py +127 -0
  23. lilac/data/dataset_utils.py +308 -0
  24. lilac/data/duckdb_utils.py +25 -0
  25. lilac/data_loader.py +110 -0
  26. lilac/db_manager.py +42 -0
  27. lilac/embeddings/__init__.py +0 -0
  28. lilac/embeddings/cohere.py +59 -0
  29. lilac/embeddings/default_vector_stores.py +10 -0
  30. lilac/embeddings/embedding.py +110 -0
  31. lilac/embeddings/gte.py +63 -0
  32. lilac/embeddings/openai.py +68 -0
  33. lilac/embeddings/palm.py +62 -0
  34. lilac/embeddings/sbert.py +38 -0
  35. lilac/embeddings/transformer_utils.py +35 -0
  36. lilac/embeddings/vector_store.py +200 -0
  37. lilac/embeddings/vector_store_hnsw.py +106 -0
  38. lilac/embeddings/vector_store_numpy.py +92 -0
  39. lilac/env.py +63 -0
  40. lilac/load.py +214 -0
  41. lilac/make_openapi.py +29 -0
  42. lilac/parquet_writer.py +70 -0
  43. lilac/router_concept.py +209 -0
  44. lilac/router_data_loader.py +80 -0
  45. lilac/router_dataset.py +303 -0
  46. lilac/router_google_login.py +60 -0
  47. lilac/router_signal.py +105 -0
  48. lilac/router_tasks.py +14 -0
  49. lilac/router_utils.py +54 -0
  50. lilac/schema.py +600 -0
.dockerignore ADDED
@@ -0,0 +1,10 @@
+ # Python
+ **/__pycache__
+ **/*.pyc
+ **/*.pyo
+ **/*.pyd
+ # Ignore unit tests.
+ **/*_test.py
+
+ # Mac OS.
+ .DS_Store
.env ADDED
@@ -0,0 +1,40 @@
+ # To overwrite these variables, create a .env.local file.
+
+ # The path to the directory where the data will be downloaded on the machine.
+ LILAC_DATA_PATH=./data
+
+ # Set to 1 for DuckDB to use views instead of materialized tables (lower memory usage, but slower).
+ DUCKDB_USE_VIEWS=0
+
+ # Set to true to enable read-only mode, disabling the ability to add datasets & compute dataset
+ # signals.
+ # LILAC_AUTH_ENABLED=true
+
+ # Variables that can be set in .env.local
+ #
+ # Get key from https://dashboard.cohere.ai/api-keys
+ # COHERE_API_KEY=
+
+ # GCS_REGION=
+ # GCS_ACCESS_KEY=
+ # GCS_SECRET_KEY=
+
+ # Get key from https://platform.openai.com/account/api-keys
+ # OPENAI_API_KEY=
+ # Get key from https://makersuite.google.com/app/apikey
+ # PALM_API_KEY=
+
+ # HuggingFace demos: machine that uploads to HuggingFace.
+
+ # For authenticating with HuggingFace to deploy to a Space.
+ # HF_USERNAME=
+ # The default repo to deploy to for a staging demo. Can be overridden by a command line flag.
+ # HF_STAGING_DEMO_REPO='HF_ORG/HF_REPO_NAME'
+
+ # For Google login. This is generated from the Google Cloud Console for a web client.
+ # See: https://developers.google.com/identity/protocols/oauth2
+ GOOGLE_CLIENT_ID='279475920249-i8llm8vbos1vj5m1qocir8narb3r0enu.apps.googleusercontent.com'
+ # The client secret of the above client.
+ # GOOGLE_CLIENT_SECRET=
+ # A random string for oauth sessions.
+ # LILAC_OAUTH_SECRET_KEY=
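The .env file above only declares defaults; machine-specific values are meant to go in .env.local. As a rough sketch of that layering (assuming the python-dotenv package; the env() helper below is hypothetical and only stands in for lilac/env.py, which is added later in this commit):

# Minimal sketch of layered env loading, assuming python-dotenv is installed.
# The real lilac/env.py in this commit may load variables differently.
import os

from dotenv import load_dotenv

# Load the checked-in defaults first, then let .env.local override them.
load_dotenv('.env')
load_dotenv('.env.local', override=True)


def env(name: str, default: str = '') -> str:
  """Return an environment variable, falling back to a default (hypothetical helper)."""
  return os.environ.get(name, default)


print(env('LILAC_DATA_PATH', './data'))
print(env('DUCKDB_USE_VIEWS', '0'))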
.env.demo ADDED
@@ -0,0 +1,4 @@
+ LILAC_DATA_PATH='/data'
+ HF_HOME='/data/.huggingface'
+ TRANSFORMERS_CACHE='/data/.cache'
+ XDG_CACHE_HOME='/data/.cache'
.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
+ __pycache__/
+ **/*.pyc
+ **/*.pyo
+ **/*.pyd
+ **/*_test.py
Dockerfile ADDED
@@ -0,0 +1,29 @@
+ # NOTE: When we upgrade to 3.11 we can use a slimmer docker image which comes with gcc.
+ FROM python:3.9-bullseye
+
+ # Allow statements and log messages to immediately appear in the Knative logs.
+ ENV PYTHONUNBUFFERED True
+
+ # Set the working directory in the container.
+ WORKDIR /server
+
+ # Install the dependencies. This requires exporting requirements.txt from poetry first, which
+ # happens from ./build_docker.sh.
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY .env .
+ COPY .env.demo .
+ COPY LICENSE .
+
+ # Copy python files.
+ COPY /lilac ./lilac/
+
+ # Copy the data files. We use a glob so the docker copy won't fail if the directory doesn't exist.
+ COPY /dat[a] ./data/
+
+ CMD [ \
+   "gunicorn", "lilac.server:app", \
+   "--bind", "0.0.0.0:5432", \
+   "-k", "uvicorn.workers.UvicornWorker" \
+ ]
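The image serves lilac.server:app with gunicorn's UvicornWorker on port 5432. For a local, non-Docker run, an equivalent sketch (assuming uvicorn is installed; this is not the project's documented workflow) is:

# Sketch only: run the same ASGI app locally without gunicorn.
# Assumes the lilac package and uvicorn are installed; the import string
# 'lilac.server:app' matches the gunicorn CMD above.
import uvicorn

if __name__ == '__main__':
  # Bind to the same host/port the Dockerfile uses.
  uvicorn.run('lilac.server:app', host='0.0.0.0', port=5432)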
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2023 Lilac AI Inc.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,11 +1,8 @@
  ---
- title: Lilac
- emoji: 🌖
- colorFrom: indigo
- colorTo: blue
+ title: Lilac Blueprint
+ emoji: 🌷
+ colorFrom: purple
+ colorTo: purple
  sdk: docker
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ app_port: 5432
+ ---
lilac/.gitignore ADDED
@@ -0,0 +1 @@
+ web/
lilac/__init__.py ADDED
@@ -0,0 +1,33 @@
+ from importlib import metadata
+
+ from .data import * # noqa: F403
+ from .data.dataset_duckdb import DatasetDuckDB
+ from .data_loader import create_dataset
+ from .db_manager import get_dataset, set_default_dataset_cls
+ from .embeddings.default_vector_stores import register_default_vector_stores
+ from .server import start_server, stop_server
+ from .signals import * # noqa: F403
+ from .signals.default_signals import register_default_signals
+ from .sources import * # noqa: F403
+ from .sources.default_sources import register_default_sources
+
+ try:
+   __version__ = metadata.version('lilacai')
+ except metadata.PackageNotFoundError:
+   __version__ = ''
+
+ register_default_sources()
+ register_default_signals()
+ register_default_vector_stores()
+ set_default_dataset_cls(DatasetDuckDB)
+
+ # Avoids polluting the results of dir(__package__).
+ del (metadata, register_default_sources, register_default_signals, set_default_dataset_cls,
+      DatasetDuckDB)
+
+ __all__ = [
+   'start_server',
+   'stop_server',
+   'create_dataset',
+   'get_dataset',
+ ]
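Importing the package registers the default sources, signals and vector stores and exposes a small top-level API. A hedged usage sketch (argument values below are illustrative, not taken from this commit; see lilac/data_loader.py and lilac/db_manager.py for the real create_dataset/get_dataset signatures):

# Sketch of the re-exported top-level API.
import lilac

print(lilac.__version__)  # '' when running from a source checkout rather than an installed wheel.

# Start the web server exported from lilac.server (the same entry point the CLI uses).
lilac.start_server(host='127.0.0.1', port=5432)
# ... interact with the UI, load datasets, etc. ...
lilac.stop_server()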
lilac/auth.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Authentication and ACL configuration."""
2
+
3
+ from typing import Optional
4
+
5
+ from fastapi import Request
6
+ from pydantic import BaseModel, ValidationError
7
+
8
+ from .env import env
9
+
10
+
11
+ class ConceptAuthorizationException(Exception):
12
+ """Authorization exceptions thrown by the concept database."""
13
+ pass
14
+
15
+
16
+ class DatasetUserAccess(BaseModel):
17
+ """User access for datasets."""
18
+ # Whether the user can compute a signal.
19
+ compute_signals: bool
20
+ # Whether the user can delete a dataset.
21
+ delete_dataset: bool
22
+ # Whether the user can delete a signal.
23
+ delete_signals: bool
24
+ # Whether the user can update settings.
25
+ update_settings: bool
26
+
27
+
28
+ class ConceptUserAccess(BaseModel):
29
+ """User access for concepts."""
30
+ # Whether the user can delete any concept (not their own).
31
+ delete_any_concept: bool
32
+
33
+
34
+ class UserAccess(BaseModel):
35
+ """User access."""
36
+ create_dataset: bool
37
+
38
+ # TODO(nsthorat): Make this keyed to each dataset and concept.
39
+ dataset: DatasetUserAccess
40
+ concept: ConceptUserAccess
41
+
42
+
43
+ class UserInfo(BaseModel):
44
+ """User information."""
45
+ id: str
46
+ email: str
47
+ name: str
48
+ given_name: str
49
+ family_name: str
50
+
51
+
52
+ class AuthenticationInfo(BaseModel):
53
+ """Authentication information for the user."""
54
+ user: Optional[UserInfo] = None
55
+ access: UserAccess
56
+ auth_enabled: bool
57
+
58
+
59
+ def get_session_user(request: Request) -> Optional[UserInfo]:
60
+ """Get the user from the session."""
61
+ if not env('LILAC_AUTH_ENABLED'):
62
+ return None
63
+ user_info_dict = request.session.get('user', None)
64
+ if user_info_dict:
65
+ try:
66
+ return UserInfo.parse_obj(user_info_dict)
67
+ except ValidationError:
68
+ return None
69
+ return None
70
+
71
+
72
+ def get_user_access() -> UserAccess:
73
+ """Get the user access."""
74
+ auth_enabled = env('LILAC_AUTH_ENABLED')
75
+ if isinstance(auth_enabled, str):
76
+ auth_enabled = auth_enabled.lower() == 'true'
77
+ if auth_enabled:
78
+ return UserAccess(
79
+ create_dataset=False,
80
+ dataset=DatasetUserAccess(
81
+ compute_signals=False, delete_dataset=False, delete_signals=False, update_settings=False),
82
+ concept=ConceptUserAccess(delete_any_concept=False))
83
+ return UserAccess(
84
+ create_dataset=True,
85
+ dataset=DatasetUserAccess(
86
+ compute_signals=True, delete_dataset=True, delete_signals=True, update_settings=True),
87
+ concept=ConceptUserAccess(delete_any_concept=True))
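When LILAC_AUTH_ENABLED is truthy, get_user_access() returns an all-False ACL (a read-only deployment); otherwise every permission is granted. A small illustrative sketch, assuming env() reflects the process environment at call time (it may instead cache values loaded from .env):

# Illustrative only: shows how the ACL flips with LILAC_AUTH_ENABLED.
import os

from lilac.auth import get_user_access

os.environ['LILAC_AUTH_ENABLED'] = 'true'
print(get_user_access().create_dataset)  # False: everything is read-only.

os.environ['LILAC_AUTH_ENABLED'] = ''
print(get_user_access().create_dataset)  # True: full access when auth is disabled.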
lilac/batch_utils.py ADDED
@@ -0,0 +1,92 @@
+ """Utils for the python server."""
+ import itertools
+ from typing import Any, Callable, Generator, Iterable, Iterator, TypeVar, Union, cast
+
+ from .schema import Item
+ from .utils import chunks, is_primitive
+
+
+ def _deep_flatten(input: Union[Iterator, object],
+                   is_primitive_predicate: Callable[[object], bool]) -> Generator:
+   """Flattens a nested iterable."""
+   if is_primitive_predicate(input):
+     yield input
+   elif isinstance(input, dict):
+     yield input
+   elif is_primitive(input):
+     yield input
+   else:
+     for elem in cast(Iterator, input):
+       yield from _deep_flatten(elem, is_primitive_predicate)
+
+
+ def deep_flatten(input: Union[Iterator, Iterable],
+                  is_primitive_predicate: Callable[[object], bool] = is_primitive) -> Iterator:
+   """Flattens a deeply nested iterator.
+
+   Primitives and dictionaries are not flattened. The user can also provide a predicate to
+   determine what is a primitive.
+   """
+   return _deep_flatten(input, is_primitive_predicate)
+
+
+ def _deep_unflatten(flat_input: Iterator[list[object]], original_input: Union[Iterable, object],
+                     is_primitive_predicate: Callable[[object], bool]) -> Union[list, dict]:
+   """Unflattens a deeply flattened iterable according to the original iterable's structure."""
+   if is_primitive_predicate(original_input):
+     return next(flat_input)
+   else:
+     values: Iterable
+     if isinstance(original_input, dict):
+       values = original_input.values()
+     else:
+       values = cast(Iterable, original_input)
+     return [_deep_unflatten(flat_input, orig_elem, is_primitive_predicate) for orig_elem in values]
+
+
+ def deep_unflatten(flat_input: Union[Iterable, Iterator],
+                    original_input: Union[Iterable, object],
+                    is_primitive_predicate: Callable[[object], bool] = is_primitive) -> list:
+   """Unflattens a deeply flattened iterable according to the original iterable's structure."""
+   return cast(list, _deep_unflatten(iter(flat_input), original_input, is_primitive_predicate))
+
+
+ TFlatten = TypeVar('TFlatten')
+
+
+ def flatten(inputs: Iterable[Iterable[TFlatten]]) -> Iterator[TFlatten]:
+   """Flattens a nested iterator.
+
+   Only supports flattening one level deep.
+   """
+   for input in inputs:
+     yield from input
+
+
+ TUnflatten = TypeVar('TUnflatten')
+
+
+ def unflatten(flat_inputs: Union[Iterable[TUnflatten], Iterator[TUnflatten]],
+               original_inputs: Iterable[Iterable[Any]]) -> Iterator[list[TUnflatten]]:
+   """Unflattens a flattened iterable according to the original iterable's structure."""
+   flat_inputs_iter = iter(flat_inputs)
+   for original_input in original_inputs:
+     yield [next(flat_inputs_iter) for _ in original_input]
+
+
+ TFlatBatchedInput = TypeVar('TFlatBatchedInput')
+ TFlatBatchedOutput = TypeVar('TFlatBatchedOutput')
+
+
+ def flat_batched_compute(input: Iterable[Iterable[TFlatBatchedInput]],
+                          f: Callable[[list[TFlatBatchedInput]], Iterable[TFlatBatchedOutput]],
+                          batch_size: int) -> Iterable[Iterable[TFlatBatchedOutput]]:
+   """Flatten the input, call f in batches, and return the output unflattened."""
+   # Tee the input so we can use it twice for the input and output shapes.
+   input_1, input_2 = itertools.tee(input, 2)
+   batches = chunks(flatten(input_1), batch_size)
+   batched_outputs = flatten((f(batch) for batch in batches))
+   return unflatten(batched_outputs, input_2)
+
+
+ TBatchSpanVectorOutput = TypeVar('TBatchSpanVectorOutput', bound=Item)
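flat_batched_compute flattens ragged per-row inputs, runs a batched function over the flat stream, and reassembles the outputs using the original row lengths. A self-contained sketch of that round trip (chunks is re-implemented here only for illustration; the real helper lives in lilac/utils.py):

# Self-contained sketch of the flatten / batched-compute / unflatten round trip.
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')


def chunks(items: Iterable[T], size: int) -> Iterator[list[T]]:
  """Yield successive lists of at most `size` items (stand-in for lilac.utils.chunks)."""
  it = iter(items)
  while batch := list(islice(it, size)):
    yield batch


rows = [[1, 2], [], [3, 4, 5]]           # Ragged per-row inputs.
flat = [x for row in rows for x in row]  # flatten(): [1, 2, 3, 4, 5]
doubled = [x * 2 for batch in chunks(flat, 2) for x in batch]  # "Batched" computation.

# unflatten(): reassemble the outputs using the original row lengths.
it = iter(doubled)
print([[next(it) for _ in row] for row in rows])  # [[2, 4], [], [6, 8, 10]]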
lilac/cli.py ADDED
@@ -0,0 +1,39 @@
+ """Lilac CLI."""
+
+ import click
+
+ from . import __version__
+ from .load import load_command as load
+ from .server import start_server
+
+
+ @click.command()
+ @click.option(
+   '--host',
+   help='The host address where the web server will listen.',
+   default='0.0.0.0',
+   type=str)
+ @click.option('--port', help='The port number of the web server.', type=int, default=5432)
+ def start(host: str, port: int) -> None:
+   """Starts the Lilac web server."""
+   start_server(host=host, port=port, open=True)
+
+
+ @click.command()
+ def version() -> None:
+   """Prints the version of Lilac."""
+   print(__version__)
+
+
+ @click.group()
+ def cli() -> None:
+   """Lilac CLI."""
+   pass
+
+
+ cli.add_command(start)
+ cli.add_command(version)
+ cli.add_command(load)
+
+ if __name__ == '__main__':
+   cli()
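Because the CLI is a plain click group, it can also be driven in-process with click's test runner; a short sketch:

# Sketch: invoke the click group in-process instead of from a shell.
from click.testing import CliRunner

from lilac.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ['version'])
print(result.output)  # The installed lilac version, or '' in a source checkout.

# A shell invocation like `lilac start --port 5432` maps to:
# runner.invoke(cli, ['start', '--port', '5432'])  # Would block while the server runs.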
lilac/concepts/__init__.py ADDED
File without changes
lilac/concepts/concept.py ADDED
@@ -0,0 +1,330 @@
1
+ """Defines the concept and the concept models."""
2
+ import dataclasses
3
+ from enum import Enum
4
+ from typing import Callable, Literal, Optional, Union
5
+
6
+ import numpy as np
7
+ from joblib import Parallel, delayed
8
+ from pydantic import BaseModel, validator
9
+ from scipy.interpolate import interp1d
10
+ from sklearn.base import clone
11
+ from sklearn.linear_model import LogisticRegression
12
+ from sklearn.metrics import precision_recall_curve, roc_auc_score
13
+ from sklearn.model_selection import KFold
14
+
15
+ from ..embeddings.embedding import get_embed_fn
16
+ from ..schema import SignalInputType
17
+ from ..signals.signal import TextEmbeddingSignal, get_signal_cls
18
+ from ..utils import DebugTimer
19
+
20
+ LOCAL_CONCEPT_NAMESPACE = 'local'
21
+
22
+ # The maximum number of cross-validation models to train.
23
+ MAX_NUM_CROSS_VAL_MODELS = 15
24
+ # The β weight to use for the F-beta score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html
25
+ # β = 0.5 means we value precision 2x as much as recall.
26
+ # β = 2 means we value recall 2x as much as precision.
27
+ F_BETA_WEIGHT = 0.5
28
+
29
+
30
+ class ExampleOrigin(BaseModel):
31
+ """The origin of an example."""
32
+ # The namespace that holds the dataset.
33
+ dataset_namespace: str
34
+
35
+ # The name of the dataset.
36
+ dataset_name: str
37
+
38
+ # The id of row in the dataset that the example was added from.
39
+ dataset_row_id: str
40
+
41
+
42
+ DraftId = Union[Literal['main'], str]
43
+ DRAFT_MAIN = 'main'
44
+
45
+
46
+ class ExampleIn(BaseModel):
47
+ """An example in a concept without the id (used for adding new examples)."""
48
+ label: bool
49
+ text: Optional[str] = None
50
+ img: Optional[bytes] = None
51
+ origin: Optional[ExampleOrigin] = None
52
+ # The name of the draft to put the example in. If None, puts it in the main draft.
53
+ draft: Optional[DraftId] = DRAFT_MAIN
54
+
55
+ @validator('text')
56
+ def parse_text(cls, text: str) -> str:
57
+ """Fixes surrogate errors in text: https://github.com/ijl/orjson/blob/master/README.md#str ."""
58
+ return text.encode('utf-8', 'replace').decode('utf-8')
59
+
60
+
61
+ class Example(ExampleIn):
62
+ """A single example in a concept used for training a concept model."""
63
+ id: str
64
+
65
+
66
+ class Concept(BaseModel):
67
+ """A concept is a collection of examples."""
68
+ # The namespace of the concept.
69
+ namespace: str
70
+ # The name of the concept.
71
+ concept_name: str
72
+ # The type of the data format that this concept represents.
73
+ type: SignalInputType
74
+ data: dict[str, Example]
75
+ version: int = 0
76
+
77
+ description: Optional[str] = None
78
+
79
+ def drafts(self) -> list[DraftId]:
80
+ """Gets all the drafts for the concept."""
81
+ drafts: set[DraftId] = set([DRAFT_MAIN]) # Always return the main draft.
82
+ for example in self.data.values():
83
+ if example.draft:
84
+ drafts.add(example.draft)
85
+ return list(sorted(drafts))
86
+
87
+
88
+ class OverallScore(str, Enum):
89
+ """Enum holding the overall score."""
90
+ NOT_GOOD = 'not_good'
91
+ OK = 'ok'
92
+ GOOD = 'good'
93
+ VERY_GOOD = 'very_good'
94
+ GREAT = 'great'
95
+
96
+
97
+ def _get_overall_score(f1_score: float) -> OverallScore:
98
+ if f1_score < 0.5:
99
+ return OverallScore.NOT_GOOD
100
+ if f1_score < 0.8:
101
+ return OverallScore.OK
102
+ if f1_score < 0.9:
103
+ return OverallScore.GOOD
104
+ if f1_score < 0.95:
105
+ return OverallScore.VERY_GOOD
106
+ return OverallScore.GREAT
107
+
108
+
109
+ class ConceptMetrics(BaseModel):
110
+ """Metrics for a concept."""
111
+ # The average F1 score for the concept computed using cross validation.
112
+ f1: float
113
+ precision: float
114
+ recall: float
115
+ roc_auc: float
116
+ overall: OverallScore
117
+
118
+
119
+ @dataclasses.dataclass
120
+ class LogisticEmbeddingModel:
121
+ """A model that uses logistic regression with embeddings."""
122
+
123
+ _metrics: Optional[ConceptMetrics] = None
124
+ _threshold: float = 0.5
125
+
126
+ def __post_init__(self) -> None:
127
+ # See `notebooks/Toxicity.ipynb` for an example of training a concept model.
128
+ self._model = LogisticRegression(
129
+ class_weight='balanced', C=30, tol=1e-5, warm_start=True, max_iter=5_000, n_jobs=-1)
130
+
131
+ def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
132
+ """Get the scores for the provided embeddings."""
133
+ y_probs = self._model.predict_proba(embeddings)[:, 1]
134
+ # Map [0, threshold, 1] to [0, 0.5, 1].
135
+ interpolate_fn = interp1d([0, self._threshold, 1], [0, 0.4999, 1])
136
+ return interpolate_fn(y_probs)
137
+
138
+ def _setup_training(self, X_train: np.ndarray,
139
+ labels: Union[list[bool], np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
140
+ y_train = np.array(labels)
141
+ # Shuffle the data in unison.
142
+ p = np.random.permutation(len(X_train))
143
+ X_train = X_train[p]
144
+ y_train = y_train[p]
145
+ return X_train, y_train
146
+
147
+ def fit(self, embeddings: np.ndarray, labels: list[bool]) -> None:
148
+ """Fit the model to the provided embeddings and labels."""
149
+ label_set = set(labels)
150
+ if len(label_set) < 2:
151
+ dim = embeddings.shape[1]
152
+ random_vector = np.random.randn(dim).astype(np.float32)
153
+ random_vector /= np.linalg.norm(random_vector)
154
+ embeddings = np.vstack([embeddings, random_vector])
155
+ labels.append(False if True in label_set else True)
156
+
157
+ if len(labels) != len(embeddings):
158
+ raise ValueError(
159
+ f'Length of embeddings ({len(embeddings)}) must match length of labels ({len(labels)})')
160
+ X_train, y_train = self._setup_training(embeddings, labels)
161
+ self._model.fit(X_train, y_train)
162
+ self._metrics, self._threshold = self._compute_metrics(embeddings, labels)
163
+
164
+ def _compute_metrics(self, embeddings: np.ndarray,
165
+ labels: list[bool]) -> tuple[Optional[ConceptMetrics], float]:
166
+ """Return the concept metrics."""
167
+ labels_np = np.array(labels)
168
+ n_splits = min(len(labels_np), MAX_NUM_CROSS_VAL_MODELS)
169
+ fold = KFold(n_splits, shuffle=True, random_state=42)
170
+
171
+ def _fit_and_score(model: LogisticRegression, X_train: np.ndarray, y_train: np.ndarray,
172
+ X_test: np.ndarray, y_test: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
173
+ if len(set(y_train)) < 2:
174
+ return np.array([]), np.array([])
175
+ model.fit(X_train, y_train)
176
+ y_pred = model.predict_proba(X_test)[:, 1]
177
+ return y_test, y_pred
178
+
179
+ # Compute the metrics for each validation fold in parallel.
180
+ jobs: list[Callable] = []
181
+ for (train_index, test_index) in fold.split(embeddings):
182
+ X_train, y_train = embeddings[train_index], labels_np[train_index]
183
+ X_train, y_train = self._setup_training(X_train, y_train)
184
+ X_test, y_test = embeddings[test_index], labels_np[test_index]
185
+ model = clone(self._model)
186
+ jobs.append(delayed(_fit_and_score)(model, X_train, y_train, X_test, y_test))
187
+ results = Parallel(n_jobs=-1)(jobs)
188
+
189
+ y_test = np.concatenate([y_test for y_test, _ in results], axis=0)
190
+ y_pred = np.concatenate([y_pred for _, y_pred in results], axis=0)
191
+ if len(set(y_test)) < 2:
192
+ return None, 0.5
193
+ roc_auc_val = roc_auc_score(y_test, y_pred)
194
+ precision, recall, thresholds = precision_recall_curve(y_test, y_pred)
195
+ numerator = (1 + F_BETA_WEIGHT**2) * precision * recall
196
+ denom = (F_BETA_WEIGHT**2 * precision) + recall
197
+ f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))
198
+ max_f1: float = np.max(f1_scores)
199
+ max_f1_index = np.argmax(f1_scores)
200
+ max_f1_thresh: float = thresholds[max_f1_index]
201
+ max_f1_prec: float = precision[max_f1_index]
202
+ max_f1_recall: float = recall[max_f1_index]
203
+ metrics = ConceptMetrics(
204
+ f1=max_f1,
205
+ precision=max_f1_prec,
206
+ recall=max_f1_recall,
207
+ roc_auc=float(roc_auc_val),
208
+ overall=_get_overall_score(max_f1))
209
+ return metrics, max_f1_thresh
210
+
211
+
212
+ def draft_examples(concept: Concept, draft: DraftId) -> dict[str, Example]:
213
+ """Get the examples in the provided draft by overriding the main draft."""
214
+ draft_examples: dict[str, dict[str, Example]] = {}
215
+ for id, example in concept.data.items():
216
+ draft_examples.setdefault(example.draft or DRAFT_MAIN, {})[example.id] = example
217
+
218
+ if draft == DRAFT_MAIN:
219
+ return draft_examples.get(DRAFT_MAIN, {})
220
+
221
+ if draft not in draft_examples:
222
+ raise ValueError(
223
+ f'Draft {draft} not found in concept. Found drafts: {list(draft_examples.keys())}')
224
+
225
+ # Map the text of the draft to its id so we can dedupe with main.
226
+ draft_text_ids = {example.text: id for id, example in draft_examples[draft].items()}
227
+
228
+ # Write each of examples from main to the draft examples only if the text does not appear in the
229
+ # draft.
230
+ for id, example in draft_examples[DRAFT_MAIN].items():
231
+ if example.text not in draft_text_ids:
232
+ draft_examples[draft][id] = example
233
+
234
+ return draft_examples[draft]
235
+
236
+
237
+ @dataclasses.dataclass
238
+ class ConceptModel:
239
+ """A concept model. Stores all concept model drafts and manages syncing."""
240
+ # The concept that this model is for.
241
+ namespace: str
242
+ concept_name: str
243
+
244
+ # The name of the embedding for this model.
245
+ embedding_name: str
246
+ version: int = 0
247
+
248
+ batch_size = 4096
249
+
250
+ # The following fields are excluded from JSON serialization, but still pickle-able.
251
+ # Maps a concept id to the embeddings.
252
+ _embeddings: dict[str, np.ndarray] = dataclasses.field(default_factory=dict)
253
+ _logistic_models: dict[DraftId, LogisticEmbeddingModel] = dataclasses.field(default_factory=dict)
254
+
255
+ def get_metrics(self, concept: Concept) -> Optional[ConceptMetrics]:
256
+ """Return the metrics for this model."""
257
+ return self._get_logistic_model(DRAFT_MAIN)._metrics
258
+
259
+ def score_embeddings(self, draft: DraftId, embeddings: np.ndarray) -> np.ndarray:
260
+ """Get the scores for the provided embeddings."""
261
+ return self._get_logistic_model(draft).score_embeddings(embeddings)
262
+
263
+ def coef(self, draft: DraftId) -> np.ndarray:
264
+ """Get the coefficients of the underlying ML model."""
265
+ return self._get_logistic_model(draft)._model.coef_.reshape(-1)
266
+
267
+ def _get_logistic_model(self, draft: DraftId) -> LogisticEmbeddingModel:
268
+ """Get the logistic model for the provided draft."""
269
+ if draft not in self._logistic_models:
270
+ self._logistic_models[draft] = LogisticEmbeddingModel()
271
+ return self._logistic_models[draft]
272
+
273
+ def sync(self, concept: Concept) -> bool:
274
+ """Update the model with the latest labeled concept data."""
275
+ if concept.version == self.version:
276
+ # The model is up to date.
277
+ return False
278
+
279
+ concept_path = (f'{self.namespace}/{self.concept_name}/'
280
+ f'{self.embedding_name}')
281
+ with DebugTimer(f'Computing embeddings for "{concept_path}"'):
282
+ self._compute_embeddings(concept)
283
+
284
+ # Fit each of the drafts, sort by draft name for deterministic behavior.
285
+ for draft in concept.drafts():
286
+ examples = draft_examples(concept, draft)
287
+ embeddings = np.array([self._embeddings[id] for id in examples.keys()])
288
+ labels = [example.label for example in examples.values()]
289
+ model = self._get_logistic_model(draft)
290
+ with DebugTimer(f'Fitting model for "{concept_path}"'):
291
+ model.fit(embeddings, labels)
292
+
293
+ # Synchronize the model version with the concept version.
294
+ self.version = concept.version
295
+
296
+ return True
297
+
298
+ def _compute_embeddings(self, concept: Concept) -> None:
299
+ signal_cls = get_signal_cls(self.embedding_name)
300
+ if not signal_cls:
301
+ raise ValueError(f'Embedding signal "{self.embedding_name}" not found in the registry.')
302
+ embedding_signal = signal_cls()
303
+ if not isinstance(embedding_signal, TextEmbeddingSignal):
304
+ raise ValueError(f'Only text embedding signals are currently supported for concepts. '
305
+ f'"{self.embedding_name}" is a {type(embedding_signal)}.')
306
+
307
+ embed_fn = get_embed_fn(self.embedding_name, split=False)
308
+ concept_embeddings: dict[str, np.ndarray] = {}
309
+
310
+ examples = concept.data.items()
311
+ if not examples:
312
+ raise ValueError(f'Cannot sync concept "{concept.concept_name}". It has no examples.')
313
+
314
+ # Compute the embeddings for the examples with cache miss.
315
+ texts_of_missing_embeddings: dict[str, str] = {}
316
+ for id, example in examples:
317
+ if id in self._embeddings:
318
+ # Cache hit.
319
+ concept_embeddings[id] = self._embeddings[id]
320
+ else:
321
+ # Cache miss.
322
+ # TODO(smilkov): Support images.
323
+ texts_of_missing_embeddings[id] = example.text or ''
324
+
325
+ missing_ids = texts_of_missing_embeddings.keys()
326
+ missing_embeddings = embed_fn(list(texts_of_missing_embeddings.values()))
327
+
328
+ for id, (embedding,) in zip(missing_ids, missing_embeddings):
329
+ concept_embeddings[id] = embedding['vector']
330
+ self._embeddings = concept_embeddings
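LogisticEmbeddingModel._compute_metrics above cross-validates a logistic regression over the example embeddings and picks the decision threshold that maximizes the F-beta score with beta = 0.5, i.e. F_beta = (1 + beta^2) * P * R / (beta^2 * P + R), weighting precision twice as heavily as recall. A standalone toy version of that threshold selection on synthetic data (no cross-validation folds; a sketch of the idea, not the library's code):

# Toy version of the F-beta threshold selection used in _compute_metrics.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 8))
y = (X[:, 0] + 0.25 * rng.normal(size=200)) > 0  # Synthetic labels.

model = LogisticRegression(class_weight='balanced', max_iter=5_000)
model.fit(X, y)
y_prob = model.predict_proba(X)[:, 1]

beta = 0.5  # Precision weighted 2x over recall, as in F_BETA_WEIGHT.
precision, recall, thresholds = precision_recall_curve(y, y_prob)
numerator = (1 + beta**2) * precision * recall
denom = beta**2 * precision + recall
fbeta = np.divide(numerator, denom, out=np.zeros_like(denom), where=denom != 0)
fbeta = fbeta[:-1]  # The final curve point has no associated threshold.
best = int(np.argmax(fbeta))
print(f'best F{beta} = {fbeta[best]:.3f} at threshold {thresholds[best]:.3f}')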
lilac/concepts/db_concept.py ADDED
@@ -0,0 +1,520 @@
1
+ """The concept database."""
2
+
3
+ import abc
4
+ import glob
5
+ import json
6
+ import os
7
+ import pathlib
8
+ import pickle
9
+ import shutil
10
+
11
+ # NOTE: We have to import the module for uuid so it can be mocked.
12
+ import uuid
13
+ from typing import Any, List, Optional, Union, cast
14
+
15
+ from pydantic import BaseModel
16
+ from typing_extensions import override
17
+
18
+ from ..auth import ConceptAuthorizationException, UserInfo
19
+ from ..env import data_path, env
20
+ from ..schema import SignalInputType
21
+ from ..signals.signal import get_signal_cls
22
+ from ..utils import delete_file, file_exists, open_file
23
+ from .concept import DRAFT_MAIN, Concept, ConceptModel, DraftId, Example, ExampleIn
24
+
25
+ CONCEPTS_DIR = 'concept'
26
+ CONCEPT_JSON_FILENAME = 'concept.json'
27
+
28
+
29
+ class ConceptNamespaceACL(BaseModel):
30
+ """The access control list for a namespace."""
31
+ # Whether the current user can read concepts in the namespace.
32
+ read: bool
33
+ # Whether the current user can add concepts to the namespace.
34
+ write: bool
35
+
36
+
37
+ class ConceptACL(BaseModel):
38
+ """The access control list for an individual concept."""
39
+ # Whether the current user can read the concept.
40
+ read: bool
41
+ # Whether the current user can edit the concept, including adding examples or deleting the
42
+ # concept.
43
+ write: bool
44
+
45
+
46
+ class ConceptInfo(BaseModel):
47
+ """Information about a concept."""
48
+ namespace: str
49
+ name: str
50
+ description: Optional[str] = None
51
+ type: SignalInputType
52
+ drafts: list[DraftId]
53
+
54
+ acls: ConceptACL
55
+
56
+
57
+ class ConceptUpdate(BaseModel):
58
+ """An update to a concept."""
59
+ # List of examples to be inserted.
60
+ insert: Optional[list[ExampleIn]] = []
61
+
62
+ # List of examples to be updated.
63
+ update: Optional[list[Example]] = []
64
+
65
+ # The ids of the examples to be removed.
66
+ remove: Optional[list[str]] = []
67
+
68
+
69
+ class ConceptDB(abc.ABC):
70
+ """Interface for the concept database."""
71
+
72
+ @abc.abstractmethod
73
+ def list(self, user: Optional[UserInfo] = None) -> list[ConceptInfo]:
74
+ """List all the concepts."""
75
+ pass
76
+
77
+ @abc.abstractmethod
78
+ def namespace_acls(self, namespace: str, user: Optional[UserInfo] = None) -> ConceptNamespaceACL:
79
+ """Return the ACL for a namespace."""
80
+ pass
81
+
82
+ @abc.abstractmethod
83
+ def concept_acls(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> ConceptACL:
84
+ """Return the ACL for a concept."""
85
+ pass
86
+
87
+ @abc.abstractmethod
88
+ def get(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> Optional[Concept]:
89
+ """Return a concept or None if there isn't one."""
90
+ pass
91
+
92
+ @abc.abstractmethod
93
+ def create(self,
94
+ namespace: str,
95
+ name: str,
96
+ type: SignalInputType,
97
+ description: Optional[str] = None,
98
+ user: Optional[UserInfo] = None) -> Concept:
99
+ """Create a concept.
100
+
101
+ Args:
102
+ namespace: The namespace of the concept.
103
+ name: The name of the concept.
104
+ type: The input type of the concept.
105
+ description: The description of the concept.
106
+ user: The user creating the concept, if authentication is enabled.
107
+ """
108
+ pass
109
+
110
+ @abc.abstractmethod
111
+ def edit(self,
112
+ namespace: str,
113
+ name: str,
114
+ change: ConceptUpdate,
115
+ user: Optional[UserInfo] = None) -> Concept:
116
+ """Edit a concept. If the concept doesn't exist, throw an error."""
117
+ pass
118
+
119
+ @abc.abstractmethod
120
+ def remove(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> None:
121
+ """Remove a concept."""
122
+ pass
123
+
124
+ @abc.abstractmethod
125
+ def merge_draft(self,
126
+ namespace: str,
127
+ name: str,
128
+ draft: DraftId,
129
+ user: Optional[UserInfo] = None) -> Concept:
130
+ """Merge a draft concept.."""
131
+ pass
132
+
133
+
134
+ class ConceptModelDB(abc.ABC):
135
+ """Interface for the concept model database."""
136
+
137
+ _concept_db: ConceptDB
138
+
139
+ def __init__(self, concept_db: ConceptDB) -> None:
140
+ self._concept_db = concept_db
141
+
142
+ @abc.abstractmethod
143
+ def create(self,
144
+ namespace: str,
145
+ concept_name: str,
146
+ embedding_name: str,
147
+ user: Optional[UserInfo] = None) -> ConceptModel:
148
+ """Create the concept model."""
149
+ pass
150
+
151
+ @abc.abstractmethod
152
+ def get(self,
153
+ namespace: str,
154
+ concept_name: str,
155
+ embedding_name: str,
156
+ user: Optional[UserInfo] = None) -> Optional[ConceptModel]:
157
+ """Get the model associated with the provided concept the embedding.
158
+
159
+ Returns None if the model does not exist.
160
+ """
161
+ pass
162
+
163
+ @abc.abstractmethod
164
+ def _save(self, model: ConceptModel) -> None:
165
+ """Save the concept model."""
166
+ pass
167
+
168
+ def in_sync(self, model: ConceptModel, user: Optional[UserInfo] = None) -> bool:
169
+ """Return True if the model is up to date with the concept."""
170
+ concept = self._concept_db.get(model.namespace, model.concept_name, user=user)
171
+ if not concept:
172
+ raise ValueError(f'Concept "{model.namespace}/{model.concept_name}" does not exist.')
173
+ return concept.version == model.version
174
+
175
+ def sync(self, model: ConceptModel, user: Optional[UserInfo] = None) -> bool:
176
+ """Sync the concept model. Returns true if the model was updated."""
177
+ concept = self._concept_db.get(model.namespace, model.concept_name, user=user)
178
+ if not concept:
179
+ raise ValueError(f'Concept "{model.namespace}/{model.concept_name}" does not exist.')
180
+ model_updated = model.sync(concept)
181
+ if model_updated:
182
+ self._save(model)
183
+ return model_updated
184
+
185
+ @abc.abstractmethod
186
+ def remove(self, namespace: str, concept_name: str, embedding_name: str) -> None:
187
+ """Remove the model of a concept."""
188
+ pass
189
+
190
+ @abc.abstractmethod
191
+ def get_models(self, namespace: str, concept_name: str) -> list[ConceptModel]:
192
+ """List all the models associated with a concept."""
193
+ pass
194
+
195
+
196
+ class DiskConceptModelDB(ConceptModelDB):
197
+ """Interface for the concept model database."""
198
+
199
+ def __init__(self,
200
+ concept_db: ConceptDB,
201
+ base_dir: Optional[Union[str, pathlib.Path]] = None) -> None:
202
+ super().__init__(concept_db)
203
+ self._base_dir = base_dir
204
+
205
+ def _get_base_dir(self) -> str:
206
+ return str(self._base_dir) if self._base_dir else data_path()
207
+
208
+ @override
209
+ def create(self,
210
+ namespace: str,
211
+ concept_name: str,
212
+ embedding_name: str,
213
+ user: Optional[UserInfo] = None) -> ConceptModel:
214
+ if self.get(namespace, concept_name, embedding_name, user=user):
215
+ raise ValueError('Concept model already exists.')
216
+ concept = self._concept_db.get(namespace, concept_name, user=user)
217
+ if not concept:
218
+ raise ValueError(f'Concept "{namespace}/{concept_name}" does not exist.')
219
+
220
+ return ConceptModel(
221
+ namespace=namespace, concept_name=concept_name, embedding_name=embedding_name)
222
+
223
+ @override
224
+ def get(self,
225
+ namespace: str,
226
+ concept_name: str,
227
+ embedding_name: str,
228
+ user: Optional[UserInfo] = None) -> Optional[ConceptModel]:
229
+ # Make sure the concept exists.
230
+ concept = self._concept_db.get(namespace, concept_name, user=user)
231
+ if not concept:
232
+ raise ValueError(f'Concept "{namespace}/{concept_name}" does not exist.')
233
+
234
+ # Make sure that the embedding signal exists.
235
+ if not get_signal_cls(embedding_name):
236
+ raise ValueError(f'Embedding signal "{embedding_name}" not found in the registry.')
237
+
238
+ concept_model_path = _concept_model_path(self._get_base_dir(), namespace, concept_name,
239
+ embedding_name)
240
+ if not file_exists(concept_model_path):
241
+ return None
242
+
243
+ with open_file(concept_model_path, 'rb') as f:
244
+ return pickle.load(f)
245
+
246
+ def _save(self, model: ConceptModel) -> None:
247
+ """Save the concept model."""
248
+ concept_model_path = _concept_model_path(self._get_base_dir(), model.namespace,
249
+ model.concept_name, model.embedding_name)
250
+ with open_file(concept_model_path, 'wb') as f:
251
+ pickle.dump(model, f)
252
+
253
+ @override
254
+ def remove(self,
255
+ namespace: str,
256
+ concept_name: str,
257
+ embedding_name: str,
258
+ user: Optional[UserInfo] = None) -> None:
259
+ concept_model_path = _concept_model_path(self._get_base_dir(), namespace, concept_name,
260
+ embedding_name)
261
+
262
+ if not file_exists(concept_model_path):
263
+ raise ValueError(f'Concept model {namespace}/{concept_name}/{embedding_name} does not exist.')
264
+
265
+ delete_file(concept_model_path)
266
+
267
+ @override
268
+ def get_models(self,
269
+ namespace: str,
270
+ concept_name: str,
271
+ user: Optional[UserInfo] = None) -> list[ConceptModel]:
272
+ """List all the models associated with a concept."""
273
+ model_files = glob.iglob(
274
+ os.path.join(get_concept_output_dir(self._get_base_dir(), namespace, concept_name), '*.pkl'))
275
+ models: list[ConceptModel] = []
276
+ for model_file in model_files:
277
+ embedding_name = os.path.basename(model_file)[:-len('.pkl')]
278
+ model = self.get(namespace, concept_name, embedding_name, user=user)
279
+ if model:
280
+ models.append(model)
281
+ return models
282
+
283
+
284
+ def get_concept_output_dir(base_dir: str, namespace: str, name: str) -> str:
285
+ """Return the output directory for a given concept."""
286
+ return os.path.join(base_dir, CONCEPTS_DIR, namespace, name)
287
+
288
+
289
+ def _concept_json_path(base_dir: str, namespace: str, name: str) -> str:
290
+ return os.path.join(get_concept_output_dir(base_dir, namespace, name), CONCEPT_JSON_FILENAME)
291
+
292
+
293
+ def _concept_model_path(base_dir: str, namespace: str, concept_name: str,
294
+ embedding_name: str) -> str:
295
+
296
+ return os.path.join(
297
+ get_concept_output_dir(base_dir, namespace, concept_name), f'{embedding_name}.pkl')
298
+
299
+
300
+ class DiskConceptDB(ConceptDB):
301
+ """A concept database."""
302
+
303
+ def __init__(self, base_dir: Optional[Union[str, pathlib.Path]] = None) -> None:
304
+ self._base_dir = base_dir
305
+
306
+ def _get_base_dir(self) -> str:
307
+ return str(self._base_dir) if self._base_dir else data_path()
308
+
309
+ @override
310
+ def namespace_acls(self, namespace: str, user: Optional[UserInfo] = None) -> ConceptNamespaceACL:
311
+ if not env('LILAC_AUTH_ENABLED'):
312
+ return ConceptNamespaceACL(read=True, write=True)
313
+
314
+ if namespace == 'lilac':
315
+ return ConceptNamespaceACL(read=True, write=False)
316
+ if user and user.id == namespace:
317
+ return ConceptNamespaceACL(read=True, write=True)
318
+
319
+ return ConceptNamespaceACL(read=False, write=False)
320
+
321
+ @override
322
+ def concept_acls(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> ConceptACL:
323
+ namespace_acls = self.namespace_acls(namespace, user=user)
324
+ # Concept ACL inherit from the namespace ACL. We currently don't have concept-specific
325
+ # ACL.
326
+ return ConceptACL(read=namespace_acls.read, write=namespace_acls.write)
327
+
328
+ @override
329
+ def list(self, user: Optional[UserInfo] = None) -> list[ConceptInfo]:
330
+ namespaces: Optional[list[str]] = None
331
+ if env('LILAC_AUTH_ENABLED'):
332
+ namespaces = ['lilac']
333
+ if user:
334
+ namespaces += [user.id]
335
+
336
+ # Read the concepts and return a ConceptInfo containing the namespace and name.
337
+ concept_infos = []
338
+ for root, _, files in os.walk(self._get_base_dir()):
339
+ for file in files:
340
+ if file == CONCEPT_JSON_FILENAME:
341
+ namespace, name = root.split('/')[-2:]
342
+ if namespaces and namespace not in namespaces:
343
+ # Ignore concepts that are not in the namespace, if provided.
344
+ continue
345
+
346
+ concept = cast(Concept, self.get(namespace, name, user=user))
347
+ concept_infos.append(
348
+ ConceptInfo(
349
+ namespace=namespace,
350
+ name=name,
351
+ description=concept.description,
352
+ type=SignalInputType.TEXT,
353
+ drafts=concept.drafts(),
354
+ acls=self.concept_acls(namespace, name, user=user)))
355
+
356
+ return concept_infos
357
+
358
+ @override
359
+ def get(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> Optional[Concept]:
360
+ # If the user does not have access to the concept, return None.
361
+ acls = self.concept_acls(namespace, name, user=user)
362
+ if not acls.read:
363
+ raise ConceptAuthorizationException(
364
+ f'Concept "{namespace}/{name}" does not exist or user does not have access.')
365
+
366
+ concept_json_path = _concept_json_path(self._get_base_dir(), namespace, name)
367
+ if not file_exists(concept_json_path):
368
+ return None
369
+
370
+ with open_file(concept_json_path) as f:
371
+ obj: dict[str, Any] = json.load(f)
372
+ if 'namespace' not in obj:
373
+ obj['namespace'] = namespace
374
+ return Concept.parse_obj(obj)
375
+
376
+ @override
377
+ def create(self,
378
+ namespace: str,
379
+ name: str,
380
+ type: SignalInputType,
381
+ description: Optional[str] = None,
382
+ user: Optional[UserInfo] = None) -> Concept:
383
+ """Create a concept."""
384
+ # If the user does not have access to the write to the concept namespace, throw.
385
+ acls = self.namespace_acls(namespace, user=user)
386
+ if not acls.write:
387
+ raise ConceptAuthorizationException(
388
+ f'Concept namespace "{namespace}" does not exist or user does not have access.')
389
+
390
+ concept_json_path = _concept_json_path(self._get_base_dir(), namespace, name)
391
+ if file_exists(concept_json_path):
392
+ raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" already exists.')
393
+
394
+ concept = Concept(
395
+ namespace=namespace, concept_name=name, type=type, data={}, description=description)
396
+ self._save(concept)
397
+ return concept
398
+
399
+ def _validate_examples(self, examples: List[Union[ExampleIn, Example]],
400
+ type: SignalInputType) -> None:
401
+ for example in examples:
402
+ inferred_type = 'text' if example.text else 'img'
403
+ if inferred_type != type:
404
+ raise ValueError(f'Example type "{inferred_type}" does not match concept type "{type}".')
405
+
406
+ @override
407
+ def edit(self,
408
+ namespace: str,
409
+ name: str,
410
+ change: ConceptUpdate,
411
+ user: Optional[UserInfo] = None) -> Concept:
412
+ # If the user does not have write access to the concept, raise an authorization error.
413
+ acls = self.concept_acls(namespace, name, user=user)
414
+ if not acls.write:
415
+ raise ConceptAuthorizationException(
416
+ f'Concept "{namespace}/{name}" does not exist or user does not have access.')
417
+
418
+ concept_json_path = _concept_json_path(self._get_base_dir(), namespace, name)
419
+
420
+ if not file_exists(concept_json_path):
421
+ raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" does not exist. '
422
+ 'Please call create() first.')
423
+
424
+ inserted_points = change.insert or []
425
+ updated_points = change.update or []
426
+ removed_points = change.remove or []
427
+
428
+ concept = cast(Concept, self.get(namespace, name, user=user))
429
+
430
+ self._validate_examples([*inserted_points, *updated_points], concept.type)
431
+
432
+ for remove_example in removed_points:
433
+ if remove_example not in concept.data:
434
+ raise ValueError(f'Example with id "{remove_example}" does not exist.')
435
+ concept.data.pop(remove_example)
436
+
437
+ for example in inserted_points:
438
+ id = uuid.uuid4().hex
439
+ concept.data[id] = Example(id=id, **example.dict())
440
+
441
+ for example in updated_points:
442
+ if example.id not in concept.data:
443
+ raise ValueError(f'Example with id "{example.id}" does not exist.')
444
+
445
+ # Replace the old example with the updated copy, keeping the same id so references stay valid.
446
+ concept.data.pop(example.id)
447
+ concept.data[example.id] = example.copy()
448
+
449
+ concept.version += 1
450
+
451
+ self._save(concept)
452
+
453
+ return concept
454
+
455
+ def _save(self, concept: Concept) -> None:
456
+ concept_json_path = _concept_json_path(self._get_base_dir(), concept.namespace,
457
+ concept.concept_name)
458
+ with open_file(concept_json_path, 'w') as f:
459
+ f.write(concept.json(exclude_none=True, indent=2, exclude_defaults=True))
460
+
461
+ @override
462
+ def remove(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> None:
463
+ # If the user does not have write access to the concept, raise an authorization error.
464
+ acls = self.concept_acls(namespace, name, user=user)
465
+ if not acls.write:
466
+ raise ConceptAuthorizationException(
467
+ f'Concept "{namespace}/{name}" does not exist or user does not have access.')
468
+
469
+ concept_dir = get_concept_output_dir(self._get_base_dir(), namespace, name)
470
+
471
+ if not file_exists(concept_dir):
472
+ raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" does not exist.')
473
+
474
+ shutil.rmtree(concept_dir, ignore_errors=True)
475
+
476
+ @override
477
+ def merge_draft(self,
478
+ namespace: str,
479
+ name: str,
480
+ draft: DraftId,
481
+ user: Optional[UserInfo] = None) -> Concept:
482
+ """Merge a draft concept."""
483
+ # If the user does not have write access to the concept, raise an authorization error.
484
+ acls = self.concept_acls(namespace, name, user=user)
485
+ if not acls.write:
486
+ raise ConceptAuthorizationException(
487
+ f'Concept "{namespace}/{name}" does not exist or user does not have access.')
488
+
489
+ concept = self.get(namespace, name, user=user)
490
+ if not concept:
491
+ raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" does not exist.')
492
+
493
+ if draft == DRAFT_MAIN:
494
+ return concept
495
+
496
+ # Map the text of examples in main so we can remove them if they are duplicates.
497
+ main_text_ids: dict[Optional[str], str] = {
498
+ example.text: id for id, example in concept.data.items() if example.draft == DRAFT_MAIN
499
+ }
500
+
501
+ draft_examples: dict[str, Example] = {
502
+ id: example for id, example in concept.data.items() if example.draft == draft
503
+ }
504
+ for example in draft_examples.values():
505
+ example.draft = DRAFT_MAIN
506
+ # Remove duplicates in main.
507
+ main_text_id = main_text_ids.get(example.text)
508
+ if main_text_id:
509
+ del concept.data[main_text_id]
510
+
511
+ concept.version += 1
512
+
513
+ self._save(concept)
514
+
515
+ return concept
516
+
517
+
518
+ # A singleton concept database.
519
+ DISK_CONCEPT_DB = DiskConceptDB()
520
+ DISK_CONCEPT_MODEL_DB = DiskConceptModelDB(DISK_CONCEPT_DB)
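An illustrative usage sketch of the disk-backed concept DB above (not part of the diff): the import locations, the `ExampleIn` field names and the home of `SignalInputType` are assumptions inferred from the file layout and the validation code in this file.

# Sketch only: import paths and ExampleIn fields are assumptions.
from lilac.concepts.concept import ExampleIn
from lilac.concepts.db_concept import DISK_CONCEPT_DB, ConceptUpdate
from lilac.signals.signal import SignalInputType  # Assumed location of SignalInputType.

# With LILAC_AUTH_ENABLED unset, every namespace is readable and writable.
concept = DISK_CONCEPT_DB.create(
  'local', 'toxicity', type=SignalInputType.TEXT, description='Toxic language examples.')

# Insert two labeled examples; edit() bumps the concept version and saves it to disk.
DISK_CONCEPT_DB.edit(
  'local', 'toxicity',
  ConceptUpdate(insert=[
    ExampleIn(label=True, text='You are the worst.'),
    ExampleIn(label=False, text='Thanks, that was helpful.'),
  ]))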
lilac/config.py ADDED
@@ -0,0 +1,80 @@
1
+ """Configurations for a dataset run."""
2
+ from typing import Optional
3
+
4
+ from pydantic import BaseModel, validator
5
+
6
+ from .data.dataset import DatasetSettings
7
+ from .schema import Path, PathTuple, normalize_path
8
+ from .signals.signal import Signal, TextEmbeddingSignal, get_signal_by_type, resolve_signal
9
+ from .sources.source import Source
10
+ from .sources.source_registry import resolve_source
11
+
12
+
13
+ class SignalConfig(BaseModel):
14
+ """Configures a signal on a source path."""
15
+ path: PathTuple
16
+ signal: Signal
17
+
18
+ @validator('path', pre=True)
19
+ def parse_path(cls, path: Path) -> PathTuple:
20
+ """Parse a path."""
21
+ return normalize_path(path)
22
+
23
+ @validator('signal', pre=True)
24
+ def parse_signal(cls, signal: dict) -> Signal:
25
+ """Parse a signal to its specific subclass instance."""
26
+ return resolve_signal(signal)
27
+
28
+
29
+ class EmbeddingConfig(BaseModel):
30
+ """Configures an embedding on a source path."""
31
+ path: PathTuple
32
+ embedding: str
33
+
34
+ @validator('path', pre=True)
35
+ def parse_path(cls, path: Path) -> PathTuple:
36
+ """Parse a path."""
37
+ return normalize_path(path)
38
+
39
+ @validator('embedding', pre=True)
40
+ def validate_embedding(cls, embedding: str) -> str:
41
+ """Validate the embedding is registered."""
42
+ get_signal_by_type(embedding, TextEmbeddingSignal)
43
+ return embedding
44
+
45
+
46
+ class DatasetConfig(BaseModel):
47
+ """Configures a dataset with a source and transformations."""
48
+ # The namespace and name of the dataset.
49
+ namespace: str
50
+ name: str
51
+
52
+ # The source configuration.
53
+ source: Source
54
+
55
+ # Model configuration: embeddings and signals on paths.
56
+ embeddings: Optional[list[EmbeddingConfig]]
57
+ # When defined, uses this list of signals instead of running all signals.
58
+ signals: Optional[list[SignalConfig]]
59
+
60
+ # Dataset settings, default embeddings and UI settings like media paths.
61
+ settings: Optional[DatasetSettings]
62
+
63
+ @validator('source', pre=True)
64
+ def parse_source(cls, source: dict) -> Source:
65
+ """Parse a source to its specific subclass instance."""
66
+ return resolve_source(source)
67
+
68
+
69
+ class Config(BaseModel):
70
+ """Configures a set of datasets for a lilac instance."""
71
+ datasets: list[DatasetConfig]
72
+
73
+ # When defined, uses this list of signals to run over every dataset, over all media paths, unless
74
+ # signals is overridden by a specific dataset.
75
+ signals: list[Signal] = []
76
+
77
+ @validator('signals', pre=True)
78
+ def parse_signal(cls, signals: list[dict]) -> list[Signal]:
79
+ """Parse alist of signals to their specific subclass instances."""
80
+ return [resolve_signal(signal) for signal in signals]
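To make the config shapes above concrete, a hypothetical payload is sketched below (not part of the diff). The discriminator keys (`source_name`, `signal_name`), the `json` source fields and the `sbert` / `lang_detection` names are assumptions that depend on what is registered at runtime.

# Hypothetical config payload; the dict fields inside 'source' and 'signal' are assumptions.
from lilac.config import Config

config = Config.parse_obj({
  'datasets': [{
    'namespace': 'local',
    'name': 'movies',
    # resolve_source() picks the Source subclass based on this dict.
    'source': {'source_name': 'json', 'filepaths': ['movies.jsonl']},
    # validate_embedding() requires 'sbert' to be a registered TextEmbeddingSignal.
    'embeddings': [{'path': 'overview', 'embedding': 'sbert'}],
    # resolve_signal() picks the Signal subclass based on this dict.
    'signals': [{'path': 'overview', 'signal': {'signal_name': 'lang_detection'}}],
  }]
})
print(config.datasets[0].source)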
lilac/conftest.py ADDED
@@ -0,0 +1,28 @@
1
+ """Fixtures for dataset tests."""
2
+ import os
3
+ import pathlib
4
+ from typing import Generator, Optional, Type
5
+
6
+ import pytest
7
+ from pytest_mock import MockerFixture
8
+
9
+ from .data.dataset import Dataset
10
+ from .data.dataset_duckdb import DatasetDuckDB
11
+ from .data.dataset_test_utils import make_dataset
12
+ from .db_manager import set_default_dataset_cls
13
+ from .schema import Item, Schema
14
+
15
+
16
+ @pytest.fixture(scope='function', params=[DatasetDuckDB])
17
+ def make_test_data(tmp_path: pathlib.Path, mocker: MockerFixture,
18
+ request: pytest.FixtureRequest) -> Generator:
19
+ """A pytest fixture for creating temporary test datasets."""
20
+ mocker.patch.dict(os.environ, {'LILAC_DATA_PATH': str(tmp_path)})
21
+ dataset_cls: Type[Dataset] = request.param
22
+ set_default_dataset_cls(dataset_cls)
23
+
24
+ def _make_test_data(items: list[Item], schema: Optional[Schema] = None) -> Dataset:
25
+ return make_dataset(dataset_cls, tmp_path, items, schema)
26
+
27
+ # Return the factory for datasets that test methods can use.
28
+ yield _make_test_data
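A sketch of a test that consumes the fixture above (illustrative; it assumes `make_dataset` assigns row UUIDs and infers a schema when none is provided):

# Illustrative test using the `make_test_data` fixture.
from lilac.data.dataset import Dataset


def test_select_rows_round_trip(make_test_data) -> None:
  dataset: Dataset = make_test_data([{'text': 'hello'}, {'text': 'world'}])
  rows = list(dataset.select_rows(['text']))
  assert len(rows) == 2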
lilac/data/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .dataset import Column, ConceptQuery, KeywordQuery, Search, SemanticQuery
2
+
3
+ __all__ = [
4
+ 'Column',
5
+ 'Search',
6
+ 'KeywordQuery',
7
+ 'ConceptQuery',
8
+ 'SemanticQuery',
9
+ ]
lilac/data/dataset.py ADDED
@@ -0,0 +1,485 @@
1
+ """The interface for the database."""
2
+ import abc
3
+ import enum
4
+ import pathlib
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from datetime import datetime
7
+ from typing import Any, Iterator, Literal, Optional, Sequence, Union
8
+
9
+ import pandas as pd
10
+ from pydantic import BaseModel
11
+ from pydantic import Field as PydanticField
12
+ from pydantic import StrictBool, StrictBytes, StrictFloat, StrictInt, StrictStr, validator
13
+
14
+ from ..auth import UserInfo
15
+ from ..schema import VALUE_KEY, Bin, DataType, Path, PathTuple, Schema, normalize_path
16
+ from ..signals.signal import Signal, TextEmbeddingSignal, get_signal_by_type, resolve_signal
17
+ from ..tasks import TaskStepId
18
+
19
+ # Threshold for rejecting certain queries (e.g. group by) for columns with large cardinality.
20
+ TOO_MANY_DISTINCT = 1_000_000
21
+
22
+
23
+ class SelectRowsResult:
24
+ """The result of a select rows query."""
25
+
26
+ def __init__(self, df: pd.DataFrame, total_num_rows: int) -> None:
27
+ """Initialize the result."""
28
+ self._df = df
29
+ self.total_num_rows = total_num_rows
30
+
31
+ def __iter__(self) -> Iterator:
32
+ return (row.to_dict() for _, row in self._df.iterrows())
33
+
34
+ def df(self) -> pd.DataFrame:
35
+ """Convert the result to a pandas DataFrame."""
36
+ return self._df
37
+
38
+
39
+ class StatsResult(BaseModel):
40
+ """The result of a stats() query."""
41
+ path: PathTuple
42
+ # The number of leaf values.
43
+ total_count: int
44
+ # The approximate number of distinct leaf values.
45
+ approx_count_distinct: int
46
+
47
+ # Defined for ordinal features.
48
+ min_val: Optional[Union[float, datetime]] = None
49
+ max_val: Optional[Union[float, datetime]] = None
50
+
51
+ # Defined for text features.
52
+ avg_text_length: Optional[float] = None
53
+
54
+
55
+ class MediaResult(BaseModel):
56
+ """The result of a media() query."""
57
+ data: bytes
58
+
59
+
60
+ class BinaryOp(str, enum.Enum):
61
+ """The comparison operator between a column and a feature value."""
62
+ EQUALS = 'equals'
63
+ NOT_EQUAL = 'not_equal'
64
+ GREATER = 'greater'
65
+ GREATER_EQUAL = 'greater_equal'
66
+ LESS = 'less'
67
+ LESS_EQUAL = 'less_equal'
68
+
69
+
70
+ SearchType = Union[Literal['keyword'], Literal['semantic'], Literal['concept']]
71
+
72
+
73
+ class UnaryOp(str, enum.Enum):
74
+ """A unary operator on a feature."""
75
+ EXISTS = 'exists'
76
+
77
+
78
+ class ListOp(str, enum.Enum):
79
+ """A list operator on a feature."""
80
+ IN = 'in'
81
+
82
+
83
+ class SortOrder(str, enum.Enum):
84
+ """The sort order for a database query."""
85
+ DESC = 'DESC'
86
+ ASC = 'ASC'
87
+
88
+
89
+ class GroupsSortBy(str, enum.Enum):
90
+ """The sort for groups queries.
91
+
92
+ Either "count" which sorts by the count of feature value, or "value" which sorts by the
93
+ feature value itself.
94
+ """
95
+ COUNT = 'count'
96
+ VALUE = 'value'
97
+
98
+
99
+ class SortResult(BaseModel):
100
+ """The information about what is sorted after combining searches and explicit sorts."""
101
+ # The column that was sorted.
102
+ path: PathTuple
103
+ # The sort order.
104
+ order: SortOrder
105
+ # The alias of the column if it was aliased.
106
+ alias: Optional[str] = None
107
+ # The search index if the sort is by a search.
108
+ search_index: Optional[int] = None
109
+
110
+
111
+ class SearchResultInfo(BaseModel):
112
+ """The resulting sort order returned by the select rows schema."""
113
+ # The input path to the search.
114
+ search_path: PathTuple
115
+ # The resulting column that was searched.
116
+ result_path: PathTuple
117
+ # The alias of the UDF.
118
+ alias: Optional[str] = None
119
+
120
+
121
+ class SelectRowsSchemaUDF(BaseModel):
122
+ """The UDF for a select rows schema query."""
123
+ path: PathTuple
124
+ alias: Optional[str] = None
125
+
126
+
127
+ class SelectRowsSchemaResult(BaseModel):
128
+ """The result of a select rows schema query."""
129
+ data_schema: Schema
130
+ udfs: list[SelectRowsSchemaUDF] = []
131
+ search_results: list[SearchResultInfo] = []
132
+ sorts: Optional[list[SortResult]] = None
133
+
134
+
135
+ class Column(BaseModel):
136
+ """A column in the dataset."""
137
+ path: PathTuple
138
+ alias: Optional[str] = None # This is the renamed column during querying and response.
139
+
140
+ # Defined when the feature is another column.
141
+ signal_udf: Optional[Signal] = None
142
+
143
+ class Config:
144
+ smart_union = True
145
+
146
+ def __init__(self,
147
+ path: Path,
148
+ alias: Optional[str] = None,
149
+ signal_udf: Optional[Signal] = None,
150
+ **kwargs: Any):
151
+ """Initialize a column. We override __init__ to allow positional arguments for brevity."""
152
+ super().__init__(path=normalize_path(path), alias=alias, signal_udf=signal_udf, **kwargs)
153
+
154
+ @validator('signal_udf', pre=True)
155
+ def parse_signal_udf(cls, signal_udf: Optional[dict]) -> Optional[Signal]:
156
+ """Parse a signal to its specific subclass instance."""
157
+ if not signal_udf:
158
+ return None
159
+ return resolve_signal(signal_udf)
160
+
161
+
162
+ ColumnId = Union[Path, Column]
163
+
164
+
165
+ class DatasetUISettings(BaseModel):
166
+ """The UI persistent settings for a dataset."""
167
+ media_paths: list[PathTuple] = []
168
+ markdown_paths: list[PathTuple] = []
169
+
170
+ @validator('media_paths', pre=True)
171
+ def parse_media_paths(cls, media_paths: list) -> list:
172
+ """Parse a path, ensuring it is a tuple."""
173
+ return [normalize_path(path) for path in media_paths]
174
+
175
+
176
+ class DatasetSettings(BaseModel):
177
+ """The persistent settings for a dataset."""
178
+ ui: Optional[DatasetUISettings] = None
179
+ preferred_embedding: Optional[str] = None
180
+
181
+
182
+ class DatasetManifest(BaseModel):
183
+ """The manifest for a dataset."""
184
+ namespace: str
185
+ dataset_name: str
186
+ data_schema: Schema
187
+ # Number of items in the dataset.
188
+ num_items: int
189
+
190
+
191
+ def column_from_identifier(column: ColumnId) -> Column:
192
+ """Create a column from a column identifier."""
193
+ if isinstance(column, Column):
194
+ return column.copy()
195
+ return Column(path=column)
196
+
197
+
198
+ FeatureValue = Union[StrictInt, StrictFloat, StrictBool, StrictStr, StrictBytes, datetime]
199
+ FeatureListValue = list[StrictStr]
200
+ BinaryFilterTuple = tuple[Path, BinaryOp, FeatureValue]
201
+ ListFilterTuple = tuple[Path, ListOp, FeatureListValue]
202
+ UnaryFilterTuple = tuple[Path, UnaryOp]
203
+
204
+ FilterOp = Union[BinaryOp, UnaryOp, ListOp]
205
+
206
+
207
+ class SelectGroupsResult(BaseModel):
208
+ """The result of a select groups query."""
209
+ too_many_distinct: bool
210
+ counts: list[tuple[Optional[FeatureValue], int]]
211
+ bins: Optional[list[Bin]] = None
212
+
213
+
214
+ class Filter(BaseModel):
215
+ """A filter on a column."""
216
+ path: PathTuple
217
+ op: FilterOp
218
+ value: Optional[Union[FeatureValue, FeatureListValue]] = None
219
+
220
+
221
+ FilterLike = Union[Filter, BinaryFilterTuple, UnaryFilterTuple, ListFilterTuple]
222
+
223
+ SearchValue = StrictStr
224
+
225
+
226
+ class KeywordQuery(BaseModel):
227
+ """A keyword search query on a column."""
228
+ type: Literal['keyword'] = 'keyword'
229
+ search: SearchValue
230
+
231
+
232
+ class SemanticQuery(BaseModel):
233
+ """A semantic search on a column."""
234
+ type: Literal['semantic'] = 'semantic'
235
+ search: SearchValue
236
+ embedding: str
237
+
238
+
239
+ class ConceptQuery(BaseModel):
240
+ """A concept search query on a column."""
241
+ type: Literal['concept'] = 'concept'
242
+ concept_namespace: str
243
+ concept_name: str
244
+ embedding: str
245
+
246
+
247
+ class Search(BaseModel):
248
+ """A search on a column."""
249
+ path: Path
250
+ query: Union[KeywordQuery, SemanticQuery, ConceptQuery] = PydanticField(discriminator='type')
251
+
252
+
253
+ class Dataset(abc.ABC):
254
+ """The database implementation to query a dataset."""
255
+
256
+ namespace: str
257
+ dataset_name: str
258
+
259
+ def __init__(self, namespace: str, dataset_name: str):
260
+ """Initialize a dataset.
261
+
262
+ Args:
263
+ namespace: The dataset namespace.
264
+ dataset_name: The dataset name.
265
+ """
266
+ self.namespace = namespace
267
+ self.dataset_name = dataset_name
268
+
269
+ @abc.abstractmethod
270
+ def delete(self) -> None:
271
+ """Deletes the dataset."""
272
+ pass
273
+
274
+ @abc.abstractmethod
275
+ def manifest(self) -> DatasetManifest:
276
+ """Return the manifest for the dataset."""
277
+ pass
278
+
279
+ @abc.abstractmethod
280
+ def settings(self) -> DatasetSettings:
281
+ """Return the persistent settings for the dataset."""
282
+ pass
283
+
284
+ @abc.abstractmethod
285
+ def update_settings(self, settings: DatasetSettings) -> None:
286
+ """Update the settings for the dataset."""
287
+ pass
288
+
289
+ @abc.abstractmethod
290
+ def compute_signal(self,
291
+ signal: Signal,
292
+ leaf_path: Path,
293
+ task_step_id: Optional[TaskStepId] = None) -> None:
294
+ """Compute a signal for a column.
295
+
296
+ Args:
297
+ signal: The signal to compute over the given columns.
298
+ leaf_path: The leaf path to compute the signal on.
299
+ task_step_id: The TaskManager `task_step_id` for this process run. This is used to update the
300
+ progress of the task.
301
+ """
302
+ pass
303
+
304
+ def compute_embedding(self,
305
+ embedding: str,
306
+ path: Path,
307
+ task_step_id: Optional[TaskStepId] = None) -> None:
308
+ """Compute an embedding for a given field path."""
309
+ signal = get_signal_by_type(embedding, TextEmbeddingSignal)()
310
+ self.compute_signal(signal, path)
311
+
312
+ @abc.abstractmethod
313
+ def delete_signal(self, signal_path: Path) -> None:
314
+ """Delete a computed signal from the dataset.
315
+
316
+ Args:
317
+ signal_path: The path holding the computed data of the signal.
318
+ """
319
+ pass
320
+
321
+ @abc.abstractmethod
322
+ def select_groups(
323
+ self,
324
+ leaf_path: Path,
325
+ filters: Optional[Sequence[FilterLike]] = None,
326
+ sort_by: Optional[GroupsSortBy] = None,
327
+ sort_order: Optional[SortOrder] = SortOrder.DESC,
328
+ limit: Optional[int] = None,
329
+ bins: Optional[Union[Sequence[Bin], Sequence[float]]] = None) -> SelectGroupsResult:
330
+ """Select grouped columns to power a histogram.
331
+
332
+ Args:
333
+ leaf_path: The leaf path to group by. The path can be a dot-separated string path, or a tuple
334
+ of fields.
335
+ filters: The filters to apply to the query.
336
+ sort_by: What to sort by, either "count" or "value".
337
+ sort_order: The sort order.
338
+ limit: The maximum number of rows to return.
339
+ bins: The bins to use when bucketizing a float column.
340
+
341
+ Returns
342
+ A `SelectGroupsResult` iterator where each row is a group.
343
+ """
344
+ raise NotImplementedError
345
+
346
+ @abc.abstractmethod
347
+ def select_rows(self,
348
+ columns: Optional[Sequence[ColumnId]] = None,
349
+ searches: Optional[Sequence[Search]] = None,
350
+ filters: Optional[Sequence[FilterLike]] = None,
351
+ sort_by: Optional[Sequence[Path]] = None,
352
+ sort_order: Optional[SortOrder] = SortOrder.DESC,
353
+ limit: Optional[int] = 100,
354
+ offset: Optional[int] = 0,
355
+ task_step_id: Optional[TaskStepId] = None,
356
+ resolve_span: bool = False,
357
+ combine_columns: bool = False,
358
+ user: Optional[UserInfo] = None) -> SelectRowsResult:
359
+ """Select grouped columns to power a histogram.
360
+
361
+ Args:
362
+ columns: The columns to select. A column is an instance of `Column` which can either
363
+ define a path to a feature, or a column with an applied Transform, e.g. a Concept. If none,
364
+ it selects all columns.
365
+ searches: The searches to apply to the query.
366
+ filters: The filters to apply to the query.
367
+ sort_by: An ordered list of what to sort by. When defined, this is a list of aliases of column
368
+ names defined by the "alias" field in Column. If no alias is provided for a column, an
369
+ automatic alias is generated by combining each path element with a "."
370
+ For example, ('person', 'name') => person.name. For columns that are transform columns,
371
+ an alias must be provided explicitly. When sorting by a (nested) list of values, the sort
372
+ takes the minimum value when `sort_order` is `ASC`, and the maximum value when `sort_order`
373
+ is `DESC`.
374
+ sort_order: The sort order.
375
+ limit: The maximum number of rows to return.
376
+ offset: The offset to start returning rows from.
377
+ task_step_id: The TaskManager `task_step_id` for this process run. This is used to update the
378
+ progress.
379
+ resolve_span: Whether to resolve the span of the row.
380
+ combine_columns: Whether to combine columns into a single object. The object will be pruned
381
+ to only include sub-fields that correspond to the requested columns.
382
+ user: The authenticated user, if auth is enabled and the user is logged in. This is used to
383
+ apply ACL to the query, especially for concepts.
384
+
385
+ Returns
386
+ A SelectRowsResult iterator with rows of `Item`s.
387
+ """
388
+ pass
389
+
390
+ @abc.abstractmethod
391
+ def select_rows_schema(self,
392
+ columns: Optional[Sequence[ColumnId]] = None,
393
+ sort_by: Optional[Sequence[Path]] = None,
394
+ sort_order: Optional[SortOrder] = SortOrder.DESC,
395
+ searches: Optional[Sequence[Search]] = None,
396
+ combine_columns: bool = False) -> SelectRowsSchemaResult:
397
+ """Returns the schema of the result of `select_rows` above with the same arguments."""
398
+ pass
399
+
400
+ @abc.abstractmethod
401
+ def stats(self, leaf_path: Path) -> StatsResult:
402
+ """Compute stats for a leaf path.
403
+
404
+ Args:
405
+ leaf_path: The leaf path to compute stats for.
406
+
407
+ Returns
408
+ A StatsResult.
409
+ """
410
+ pass
411
+
412
+ @abc.abstractmethod
413
+ def media(self, item_id: str, leaf_path: Path) -> MediaResult:
414
+ """Return the media for a leaf path.
415
+
416
+ Args:
417
+ item_id: The item id to get media for.
418
+ leaf_path: The leaf path for the media.
419
+
420
+ Returns
421
+ A MediaResult.
422
+ """
423
+ pass
424
+
425
+ @abc.abstractmethod
426
+ def to_json(self, filepath: Union[str, pathlib.Path], jsonl: bool = True) -> None:
427
+ """Export the dataset to a JSON file.
428
+
429
+ Args:
430
+ filepath: The path to the file to export to.
431
+ jsonl: Whether to export to JSONL or JSON.
432
+ """
433
+ pass
434
+
435
+ @abc.abstractmethod
436
+ def to_pandas(self) -> pd.DataFrame:
437
+ """Export the dataset to a pandas DataFrame."""
438
+ pass
439
+
440
+ @abc.abstractmethod
441
+ def to_parquet(self, filepath: Union[str, pathlib.Path]) -> None:
442
+ """Export the dataset to a parquet file.
443
+
444
+ Args:
445
+ filepath: The path to the file to export to.
446
+ """
447
+ pass
448
+
449
+ @abc.abstractmethod
450
+ def to_csv(self, filepath: Union[str, pathlib.Path]) -> None:
451
+ """Export the dataset to a csv file.
452
+
453
+ Args:
454
+ filepath: The path to the file to export to.
455
+ """
456
+ pass
457
+
458
+
459
+ def default_settings(dataset: Dataset) -> DatasetSettings:
460
+ """Gets the default settings for a dataset."""
461
+ schema = dataset.manifest().data_schema
462
+ leaf_paths = [path for path, field in schema.leafs.items() if field.dtype == DataType.STRING]
463
+ pool = ThreadPoolExecutor()
464
+ stats: list[StatsResult] = list(pool.map(lambda leaf: dataset.stats(leaf), leaf_paths))
465
+ sorted_stats = sorted([stat for stat in stats if stat.avg_text_length],
466
+ key=lambda stat: stat.avg_text_length or -1.0)
467
+ media_paths: set[PathTuple] = set()
468
+ if sorted_stats:
469
+ media_paths = set([sorted_stats[-1].path])
470
+
471
+ return DatasetSettings(ui=DatasetUISettings(media_paths=media_paths))
472
+
473
+
474
+ def make_parquet_id(signal: Signal,
475
+ source_path: PathTuple,
476
+ is_computed_signal: Optional[bool] = False) -> str:
477
+ """Return a unique identifier for this parquet table."""
478
+ # Don't use the VALUE_KEY as part of the parquet id to reduce the size of paths.
479
+ path = source_path[:-1] if source_path[-1] == VALUE_KEY else source_path
480
+ column_alias = '.'.join(map(str, path))
481
+ if column_alias.endswith('.*'):
482
+ # Remove the trailing .* from the column name.
483
+ column_alias = column_alias[:-2]
484
+
485
+ return f'{signal.key(is_computed_signal=is_computed_signal)}({column_alias})'
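As a usage sketch of the interface above (not part of the diff), a concept search issued through `select_rows` might look like the following. It assumes the DuckDB implementation below, an existing 'local/movies' dataset with a 'text' field, and that a 'gte-small' embedding has already been computed for that path; these names are illustrative.

# Sketch only; the dataset, path and embedding names are assumptions.
from lilac.data.dataset import ConceptQuery, Search
from lilac.data.dataset_duckdb import DatasetDuckDB

dataset = DatasetDuckDB('local', 'movies')
search = Search(
  path='text',
  query=ConceptQuery(
    concept_namespace='local', concept_name='toxicity', embedding='gte-small'))

# Rows come back sorted by the concept score (the search UDF supplies the sort).
result = dataset.select_rows(columns=['text'], searches=[search], limit=10)
for row in result:
  print(row)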
lilac/data/dataset_duckdb.py ADDED
@@ -0,0 +1,1717 @@
1
+ """The DuckDB implementation of the dataset database."""
2
+ import functools
3
+ import gc
4
+ import glob
5
+ import math
6
+ import os
7
+ import pathlib
8
+ import re
9
+ import shutil
10
+ import threading
11
+ from typing import Any, Iterable, Iterator, Optional, Sequence, Union, cast
12
+
13
+ import duckdb
14
+ import numpy as np
15
+ import pandas as pd
16
+ from pandas.api.types import is_object_dtype
17
+ from pydantic import BaseModel, validator
18
+ from typing_extensions import override
19
+
20
+ from ..auth import UserInfo
21
+ from ..batch_utils import deep_flatten, deep_unflatten
22
+ from ..embeddings.vector_store import VectorDBIndex
23
+ from ..env import data_path, env
24
+ from ..schema import (
25
+ MANIFEST_FILENAME,
26
+ PATH_WILDCARD,
27
+ TEXT_SPAN_END_FEATURE,
28
+ TEXT_SPAN_START_FEATURE,
29
+ UUID_COLUMN,
30
+ VALUE_KEY,
31
+ Bin,
32
+ DataType,
33
+ Field,
34
+ Item,
35
+ Path,
36
+ PathKey,
37
+ PathTuple,
38
+ RichData,
39
+ Schema,
40
+ SourceManifest,
41
+ column_paths_match,
42
+ is_float,
43
+ is_integer,
44
+ is_ordinal,
45
+ is_temporal,
46
+ normalize_path,
47
+ signal_type_supports_dtype,
48
+ )
49
+ from ..signals.concept_labels import ConceptLabelsSignal
50
+ from ..signals.concept_scorer import ConceptScoreSignal
51
+ from ..signals.semantic_similarity import SemanticSimilaritySignal
52
+ from ..signals.signal import (
53
+ Signal,
54
+ TextEmbeddingSignal,
55
+ VectorSignal,
56
+ get_signal_by_type,
57
+ resolve_signal,
58
+ )
59
+ from ..signals.substring_search import SubstringSignal
60
+ from ..tasks import TaskStepId, progress
61
+ from ..utils import DebugTimer, get_dataset_output_dir, log, open_file
62
+ from . import dataset
63
+ from .dataset import (
64
+ BinaryOp,
65
+ Column,
66
+ ColumnId,
67
+ Dataset,
68
+ DatasetManifest,
69
+ DatasetSettings,
70
+ FeatureListValue,
71
+ FeatureValue,
72
+ Filter,
73
+ FilterLike,
74
+ GroupsSortBy,
75
+ ListOp,
76
+ MediaResult,
77
+ Search,
78
+ SearchResultInfo,
79
+ SelectGroupsResult,
80
+ SelectRowsResult,
81
+ SelectRowsSchemaResult,
82
+ SelectRowsSchemaUDF,
83
+ SortOrder,
84
+ SortResult,
85
+ StatsResult,
86
+ UnaryOp,
87
+ column_from_identifier,
88
+ default_settings,
89
+ make_parquet_id,
90
+ )
91
+ from .dataset_utils import (
92
+ count_primitives,
93
+ create_signal_schema,
94
+ flatten_keys,
95
+ merge_schemas,
96
+ schema_contains_path,
97
+ sparse_to_dense_compute,
98
+ wrap_in_dicts,
99
+ write_embeddings_to_disk,
100
+ write_items_to_parquet,
101
+ )
102
+
103
+ UUID_INDEX_FILENAME = 'uuids.npy'
104
+
105
+ SIGNAL_MANIFEST_FILENAME = 'signal_manifest.json'
106
+ DATASET_SETTINGS_FILENAME = 'settings.json'
107
+ SOURCE_VIEW_NAME = 'source'
108
+
109
+ # Sample size for approximating the distinct count of a column.
110
+ SAMPLE_SIZE_DISTINCT_COUNT = 100_000
111
+ NUM_AUTO_BINS = 15
112
+
113
+ BINARY_OP_TO_SQL: dict[BinaryOp, str] = {
114
+ BinaryOp.EQUALS: '=',
115
+ BinaryOp.NOT_EQUAL: '!=',
116
+ BinaryOp.GREATER: '>',
117
+ BinaryOp.GREATER_EQUAL: '>=',
118
+ BinaryOp.LESS: '<',
119
+ BinaryOp.LESS_EQUAL: '<='
120
+ }
121
+
122
+
123
+ class DuckDBSearchUDF(BaseModel):
124
+ """The transformation of searches to column UDFs."""
125
+ udf: Column
126
+ search_path: PathTuple
127
+ output_path: PathTuple
128
+ sort: Optional[tuple[PathTuple, SortOrder]] = None
129
+
130
+
131
+ class DuckDBSearchUDFs(BaseModel):
132
+ """The transformation of searches to column UDFs with sorts."""
133
+ udfs: list[Column]
134
+ output_paths: list[PathTuple]
135
+ sorts: list[tuple[PathTuple, SortOrder]]
136
+
137
+
138
+ class DatasetDuckDB(Dataset):
139
+ """The DuckDB implementation of the dataset database."""
140
+
141
+ def __init__(self, namespace: str, dataset_name: str, vector_store: str = 'hnsw'):
142
+ super().__init__(namespace, dataset_name)
143
+
144
+ self.dataset_path = get_dataset_output_dir(data_path(), namespace, dataset_name)
145
+
146
+ # TODO: Infer the manifest from the parquet files so this is lighter weight.
147
+ self._source_manifest = read_source_manifest(self.dataset_path)
148
+ self._signal_manifests: list[SignalManifest] = []
149
+ self.con = duckdb.connect(database=':memory:')
150
+
151
+ # Maps a path and embedding to the vector index. This is lazily generated as needed.
152
+ self._vector_indices: dict[tuple[PathKey, str], VectorDBIndex] = {}
153
+ self.vector_store = vector_store
154
+ self._manifest_lock = threading.Lock()
155
+
156
+ # Calling settings creates the default settings JSON file if it doesn't exist.
157
+ self.settings()
158
+
159
+ @override
160
+ def delete(self) -> None:
161
+ """Deletes the dataset."""
162
+ self.con.close()
163
+ shutil.rmtree(self.dataset_path, ignore_errors=True)
164
+
165
+ def _create_view(self, view_name: str, files: list[str]) -> None:
166
+ self.con.execute(f"""
167
+ CREATE OR REPLACE VIEW {_escape_col_name(view_name)} AS (SELECT * FROM read_parquet({files}));
168
+ """)
169
+
170
+ # NOTE: This is cached, but when the latest mtime of any file in the dataset directory changes
171
+ # the results are invalidated.
172
+ @functools.cache
173
+ def _recompute_joint_table(self, latest_mtime_micro_sec: int) -> DatasetManifest:
174
+ del latest_mtime_micro_sec # This is used as the cache key.
175
+ merged_schema = self._source_manifest.data_schema.copy(deep=True)
176
+ self._signal_manifests = []
177
+ # Make a joined view of all the column groups.
178
+ self._create_view(SOURCE_VIEW_NAME,
179
+ [os.path.join(self.dataset_path, f) for f in self._source_manifest.files])
180
+
181
+ # Add the signal column groups.
182
+ for root, _, files in os.walk(self.dataset_path):
183
+ for file in files:
184
+ if not file.endswith(SIGNAL_MANIFEST_FILENAME):
185
+ continue
186
+
187
+ with open_file(os.path.join(root, file)) as f:
188
+ signal_manifest = SignalManifest.parse_raw(f.read())
189
+ self._signal_manifests.append(signal_manifest)
190
+ signal_files = [os.path.join(root, f) for f in signal_manifest.files]
191
+ if signal_files:
192
+ self._create_view(signal_manifest.parquet_id, signal_files)
193
+
194
+ merged_schema = merge_schemas([self._source_manifest.data_schema] +
195
+ [m.data_schema for m in self._signal_manifests])
196
+
197
+ # The logic below generates the following example query:
198
+ # CREATE OR REPLACE VIEW t AS (
199
+ # SELECT
200
+ # source.*,
201
+ # "parquet_id1"."root_column" AS "parquet_id1",
202
+ # "parquet_id2"."root_column" AS "parquet_id2"
203
+ # FROM source JOIN "parquet_id1" USING (uuid,) JOIN "parquet_id2" USING (uuid,)
204
+ # );
205
+ # NOTE: "root_column" for each signal is defined as the top-level column.
206
+ select_sql = ', '.join([f'{SOURCE_VIEW_NAME}.*'] + [(
207
+ f'{_escape_col_name(manifest.parquet_id)}.{_escape_col_name(_root_column(manifest))} '
208
+ f'AS {_escape_col_name(manifest.parquet_id)}')
209
+ for manifest in self._signal_manifests
210
+ if manifest.files])
211
+ join_sql = ' '.join([SOURCE_VIEW_NAME] + [
212
+ f'join {_escape_col_name(manifest.parquet_id)} using ({UUID_COLUMN},)'
213
+ for manifest in self._signal_manifests
214
+ if manifest.files
215
+ ])
216
+ view_or_table = 'TABLE'
217
+ use_views = env('DUCKDB_USE_VIEWS', 0) or 0
218
+ if int(use_views):
219
+ view_or_table = 'VIEW'
220
+ sql_cmd = f"""CREATE OR REPLACE {view_or_table} t AS (SELECT {select_sql} FROM {join_sql})"""
221
+ self.con.execute(sql_cmd)
222
+
223
+ # Get the total size of the table.
224
+ size_query = 'SELECT COUNT() as count FROM t'
225
+ size_query_result = cast(Any, self._query(size_query)[0])
226
+ num_items = cast(int, size_query_result[0])
227
+
228
+ return DatasetManifest(
229
+ namespace=self.namespace,
230
+ dataset_name=self.dataset_name,
231
+ data_schema=merged_schema,
232
+ num_items=num_items)
233
+
234
+ @override
235
+ def manifest(self) -> DatasetManifest:
236
+ # Use the latest modification time of all files under the dataset path as the cache key for
237
+ # re-computing the manifest and the joined view.
238
+ with self._manifest_lock:
239
+ all_dataset_files = glob.iglob(os.path.join(self.dataset_path, '**'), recursive=True)
240
+ latest_mtime = max(map(os.path.getmtime, all_dataset_files))
241
+ latest_mtime_micro_sec = int(latest_mtime * 1e6)
242
+ return self._recompute_joint_table(latest_mtime_micro_sec)
243
+
244
+ @override
245
+ def settings(self) -> DatasetSettings:
246
+ # Read the settings file from disk.
247
+ settings_filepath = _settings_filepath(self.namespace, self.dataset_name)
248
+ if not os.path.exists(settings_filepath):
249
+ self.update_settings(default_settings(self))
250
+
251
+ with open(settings_filepath) as f:
252
+ return DatasetSettings.parse_raw(f.read())
253
+
254
+ @override
255
+ def update_settings(self, settings: DatasetSettings) -> None:
256
+ # Write the settings file to disk.
257
+ settings_filepath = _settings_filepath(self.namespace, self.dataset_name)
258
+ with open(settings_filepath, 'w') as f:
259
+ f.write(settings.json())
260
+
261
+ def count(self, filters: Optional[list[FilterLike]] = None) -> int:
262
+ """Count the number of rows."""
263
+ raise NotImplementedError('count is not yet implemented for DuckDB.')
264
+
265
+ def _get_vector_db_index(self, embedding: str, path: PathTuple) -> VectorDBIndex:
266
+ # Refresh the manifest to make sure we have the latest signal manifests.
267
+ self.manifest()
268
+ index_key = (path, embedding)
269
+ if index_key in self._vector_indices:
270
+ return self._vector_indices[index_key]
271
+
272
+ manifests = [
273
+ m for m in self._signal_manifests
274
+ if schema_contains_path(m.data_schema, path) and m.vector_store and m.signal.name == embedding
275
+ ]
276
+ if not manifests:
277
+ raise ValueError(f'No embedding found for path {path}.')
278
+ if len(manifests) > 1:
279
+ raise ValueError(f'Multiple embeddings found for path {path}. Got: {manifests}')
280
+ manifest = manifests[0]
281
+ if not manifest.vector_store:
282
+ raise ValueError(f'Signal manifest for path {path} is not an embedding. '
283
+ f'Got signal manifest: {manifest}')
284
+
285
+ base_path = os.path.join(self.dataset_path, _signal_dir(manifest.enriched_path),
286
+ manifest.signal.name)
287
+ with DebugTimer(f'Loading vector store "{manifest.vector_store}" for "{path}"'
288
+ f' with embedding "{embedding}"'):
289
+ vector_index = VectorDBIndex(manifest.vector_store)
290
+ vector_index.load(base_path)
291
+ # Cache the vector index.
292
+ self._vector_indices[index_key] = vector_index
293
+ return vector_index
294
+
295
+ @override
296
+ def compute_signal(self,
297
+ signal: Signal,
298
+ leaf_path: Path,
299
+ task_step_id: Optional[TaskStepId] = None) -> None:
300
+ if isinstance(signal, TextEmbeddingSignal):
301
+ return self.compute_embedding(signal.name, leaf_path, task_step_id)
302
+ source_path = normalize_path(leaf_path)
303
+ manifest = self.manifest()
304
+
305
+ if task_step_id is None:
306
+ # Make a dummy task step so we report progress via tqdm.
307
+ task_step_id = ('', 0)
308
+
309
+ # The manifest may have changed after computing the dependencies.
310
+ manifest = self.manifest()
311
+
312
+ signal_col = Column(path=source_path, alias='value', signal_udf=signal)
313
+ select_rows_result = self.select_rows([signal_col],
314
+ task_step_id=task_step_id,
315
+ resolve_span=True)
316
+ df = select_rows_result.df()
317
+ values = df['value']
318
+
319
+ enriched_path = _col_destination_path(signal_col, is_computed_signal=True)
320
+ spec = _split_path_into_subpaths_of_lists(enriched_path)
321
+ output_dir = os.path.join(self.dataset_path, _signal_dir(enriched_path))
322
+ signal_schema = create_signal_schema(signal, source_path, manifest.data_schema)
323
+ enriched_signal_items = cast(Iterable[Item], wrap_in_dicts(values, spec))
324
+ for uuid, item in zip(df[UUID_COLUMN], enriched_signal_items):
325
+ item[UUID_COLUMN] = uuid
326
+
327
+ enriched_signal_items = list(enriched_signal_items)
328
+ parquet_filename, _ = write_items_to_parquet(
329
+ items=enriched_signal_items,
330
+ output_dir=output_dir,
331
+ schema=signal_schema,
332
+ filename_prefix='data',
333
+ shard_index=0,
334
+ num_shards=1)
335
+
336
+ signal_manifest = SignalManifest(
337
+ files=[parquet_filename],
338
+ data_schema=signal_schema,
339
+ signal=signal,
340
+ enriched_path=source_path,
341
+ parquet_id=make_parquet_id(signal, source_path, is_computed_signal=True))
342
+ signal_manifest_filepath = os.path.join(output_dir, SIGNAL_MANIFEST_FILENAME)
343
+ with open_file(signal_manifest_filepath, 'w') as f:
344
+ f.write(signal_manifest.json(exclude_none=True, indent=2))
345
+ log(f'Wrote signal output to {output_dir}')
346
+
347
+ @override
348
+ def compute_embedding(self,
349
+ embedding: str,
350
+ path: Path,
351
+ task_step_id: Optional[TaskStepId] = None) -> None:
352
+ source_path = normalize_path(path)
353
+ manifest = self.manifest()
354
+
355
+ if task_step_id is None:
356
+ # Make a dummy task step so we report progress via tqdm.
357
+ task_step_id = ('', 0)
358
+
359
+ signal = get_signal_by_type(embedding, TextEmbeddingSignal)()
360
+ signal_col = Column(path=source_path, alias='value', signal_udf=signal)
361
+ select_rows_result = self.select_rows([signal_col],
362
+ task_step_id=task_step_id,
363
+ resolve_span=True)
364
+ df = select_rows_result.df()
365
+ values = df['value']
366
+
367
+ enriched_path = _col_destination_path(signal_col, is_computed_signal=True)
368
+ output_dir = os.path.join(self.dataset_path, _signal_dir(enriched_path))
369
+ signal_schema = create_signal_schema(signal, source_path, manifest.data_schema)
370
+
371
+ write_embeddings_to_disk(
372
+ vector_store=self.vector_store,
373
+ uuids=df[UUID_COLUMN],
374
+ signal_items=values,
375
+ output_dir=output_dir)
376
+
377
+ del select_rows_result, df, values
378
+ gc.collect()
379
+
380
+ signal_manifest = SignalManifest(
381
+ files=[],
382
+ data_schema=signal_schema,
383
+ signal=signal,
384
+ enriched_path=source_path,
385
+ parquet_id=make_parquet_id(signal, source_path, is_computed_signal=True),
386
+ vector_store=self.vector_store)
387
+ signal_manifest_filepath = os.path.join(output_dir, SIGNAL_MANIFEST_FILENAME)
388
+
389
+ with open_file(signal_manifest_filepath, 'w') as f:
390
+ f.write(signal_manifest.json(exclude_none=True, indent=2))
391
+ log(f'Wrote embedding index to {output_dir}')
392
+
393
+ @override
394
+ def delete_signal(self, signal_path: Path) -> None:
395
+ signal_path = normalize_path(signal_path)
396
+ manifest = self.manifest()
397
+ if not manifest.data_schema.has_field(signal_path):
398
+ raise ValueError(f'Unknown signal path: {signal_path}')
399
+
400
+ output_dir = os.path.join(self.dataset_path, _signal_dir(signal_path))
401
+ shutil.rmtree(output_dir, ignore_errors=True)
402
+
403
+ def _validate_filters(self, filters: Sequence[Filter], col_aliases: dict[str, PathTuple],
404
+ manifest: DatasetManifest) -> None:
405
+ for filter in filters:
406
+ if filter.path[0] in col_aliases:
407
+ # This is a filter on a column alias, which is always allowed.
408
+ continue
409
+
410
+ current_field = Field(fields=manifest.data_schema.fields)
411
+ for path_part in filter.path:
412
+ if path_part == VALUE_KEY:
413
+ if not current_field.dtype:
414
+ raise ValueError(f'Unable to filter on path {filter.path}. The field has no value.')
415
+ continue
416
+ if current_field.fields:
417
+ if path_part not in current_field.fields:
418
+ raise ValueError(f'Unable to filter on path {filter.path}. '
419
+ f'Path part "{path_part}" not found in the dataset.')
420
+ current_field = current_field.fields[str(path_part)]
421
+ continue
422
+ elif current_field.repeated_field:
423
+ current_field = current_field.repeated_field
424
+ continue
425
+ else:
426
+ raise ValueError(f'Unable to filter on path {filter.path}. '
427
+ f'Path part "{path_part}" is not defined on a primitive value.')
428
+
429
+ while current_field.repeated_field:
430
+ current_field = current_field.repeated_field
431
+ filter.path = (*filter.path, PATH_WILDCARD)
432
+
433
+ if not current_field.dtype:
434
+ raise ValueError(f'Unable to filter on path {filter.path}. The field has no value.')
435
+
436
+ def _validate_udfs(self, udf_cols: Sequence[Column], source_schema: Schema) -> None:
437
+ for col in udf_cols:
438
+ path = col.path
439
+
440
+ # Signal transforms must operate on a leaf field.
441
+ leaf = source_schema.leafs.get(path)
442
+ if not leaf or not leaf.dtype:
443
+ raise ValueError(f'Leaf "{path}" not found in dataset. '
444
+ 'Signal transforms must operate on a leaf field.')
445
+
446
+ # Signal transforms must have the same dtype as the leaf field.
447
+ signal = cast(Signal, col.signal_udf)
448
+ if not signal_type_supports_dtype(signal.input_type, leaf.dtype):
449
+ raise ValueError(f'Leaf "{path}" has dtype "{leaf.dtype}" which is not supported '
450
+ f'by "{signal.key()}" with signal input type "{signal.input_type}".')
451
+
452
+ def _validate_selection(self, columns: Sequence[Column], select_schema: Schema) -> None:
453
+ # Validate all the columns and make sure they exist in the `select_schema`.
454
+ for column in columns:
455
+ current_field = Field(fields=select_schema.fields)
456
+ path = column.path
457
+ for path_part in path:
458
+ if path_part == VALUE_KEY:
459
+ if not current_field.dtype:
460
+ raise ValueError(f'Unable to select path {path}. The field has no value.')
461
+ continue
462
+ if current_field.fields:
463
+ if path_part not in current_field.fields:
464
+ raise ValueError(f'Unable to select path {path}. '
465
+ f'Path part "{path_part}" not found in the dataset.')
466
+ current_field = current_field.fields[path_part]
467
+ continue
468
+ elif current_field.repeated_field:
469
+ if path_part.isdigit():
470
+ raise ValueError(f'Unable to select path {path}. Selecting a specific index of '
471
+ 'a repeated field is currently not supported.')
472
+ if path_part != PATH_WILDCARD:
473
+ raise ValueError(f'Unable to select path {path}. '
474
+ f'Path part "{path_part}" should be a wildcard.')
475
+ current_field = current_field.repeated_field
476
+ elif not current_field.dtype:
477
+ raise ValueError(f'Unable to select path {path}. '
478
+ f'Path part "{path_part}" is not defined on a primitive value.')
479
+
480
+ def _validate_columns(self, columns: Sequence[Column], source_schema: Schema,
481
+ select_schema: Schema) -> None:
482
+ udf_cols = [col for col in columns if col.signal_udf]
483
+ self._validate_udfs(udf_cols, source_schema)
484
+ self._validate_selection(columns, select_schema)
485
+
486
+ def _validate_sort_path(self, path: PathTuple, schema: Schema) -> None:
487
+ current_field = Field(fields=schema.fields)
488
+ for path_part in path:
489
+ if path_part == VALUE_KEY:
490
+ if not current_field.dtype:
491
+ raise ValueError(f'Unable to sort by path {path}. The field has no value.')
492
+ continue
493
+ if current_field.fields:
494
+ if path_part not in current_field.fields:
495
+ raise ValueError(f'Unable to sort by path {path}. '
496
+ f'Path part "{path_part}" not found in the dataset.')
497
+ current_field = current_field.fields[path_part]
498
+ continue
499
+ elif current_field.repeated_field:
500
+ if path_part.isdigit():
501
+ raise ValueError(f'Unable to sort by path {path}. Selecting a specific index of '
502
+ 'a repeated field is currently not supported.')
503
+ if path_part != PATH_WILDCARD:
504
+ raise ValueError(f'Unable to sort by path {path}. '
505
+ f'Path part "{path_part}" should be a wildcard.')
506
+ current_field = current_field.repeated_field
507
+ elif not current_field.dtype:
508
+ raise ValueError(f'Unable to sort by path {path}. '
509
+ f'Path part "{path_part}" is not defined on a primitive value.')
510
+ if not current_field.dtype:
511
+ raise ValueError(f'Unable to sort by path {path}. The field has no value.')
512
+
513
+ @override
514
+ def stats(self, leaf_path: Path) -> StatsResult:
515
+ if not leaf_path:
516
+ raise ValueError('leaf_path must be provided')
517
+ path = normalize_path(leaf_path)
518
+ manifest = self.manifest()
519
+ leaf = manifest.data_schema.get_field(path)
520
+ # Find the inner-most leaf in case this field is repeated.
521
+ while leaf.repeated_field:
522
+ leaf = leaf.repeated_field
523
+ path = (*path, PATH_WILDCARD)
524
+
525
+ if not leaf.dtype:
526
+ raise ValueError(f'Leaf "{path}" not found in dataset')
527
+
528
+ duckdb_path = self._leaf_path_to_duckdb_path(path, manifest.data_schema)
529
+ inner_select = _select_sql(
530
+ duckdb_path, flatten=True, unnest=True, span_from=self._get_span_from(path, manifest))
531
+
532
+ # Compute approximate count by sampling the data to avoid OOM.
533
+ sample_size = SAMPLE_SIZE_DISTINCT_COUNT
534
+ avg_length_query = ''
535
+ if leaf.dtype == DataType.STRING:
536
+ avg_length_query = ', avg(length(val)) as avgTextLength'
537
+
538
+ row: Optional[tuple[int, ...]] = None
539
+ if leaf.dtype == DataType.BOOLEAN:
540
+ approx_count_distinct = 2
541
+ else:
542
+ approx_count_query = f"""
543
+ SELECT approx_count_distinct(val) as approxCountDistinct {avg_length_query}
544
+ FROM (SELECT {inner_select} AS val FROM t LIMIT {sample_size});
545
+ """
546
+ row = self._query(approx_count_query)[0]
547
+ approx_count_distinct = row[0]
548
+
549
+ total_count_query = f'SELECT count(val) FROM (SELECT {inner_select} as val FROM t)'
550
+ total_count = self._query(total_count_query)[0][0]
551
+
552
+ if leaf.dtype != DataType.BOOLEAN:
553
+ # Adjust the counts for the sample size.
554
+ factor = max(1, total_count / sample_size)
555
+ approx_count_distinct = round(approx_count_distinct * factor)
556
+
557
+ result = StatsResult(
558
+ path=path, total_count=total_count, approx_count_distinct=approx_count_distinct)
559
+
560
+ if leaf.dtype == DataType.STRING and row:
561
+ result.avg_text_length = row[1]
562
+
563
+ # Compute min/max values for ordinal leafs, without sampling the data.
564
+ if is_ordinal(leaf.dtype):
565
+ min_max_query = f"""
566
+ SELECT MIN(val) AS minVal, MAX(val) AS maxVal
567
+ FROM (SELECT {inner_select} as val FROM t)
568
+ {'WHERE NOT isnan(val)' if is_float(leaf.dtype) else ''}
569
+ """
570
+ row = self._query(min_max_query)[0]
571
+ result.min_val, result.max_val = row
572
+
573
+ return result
574
+
575
+ @override
576
+ def select_groups(
577
+ self,
578
+ leaf_path: Path,
579
+ filters: Optional[Sequence[FilterLike]] = None,
580
+ sort_by: Optional[GroupsSortBy] = GroupsSortBy.COUNT,
581
+ sort_order: Optional[SortOrder] = SortOrder.DESC,
582
+ limit: Optional[int] = None,
583
+ bins: Optional[Union[Sequence[Bin], Sequence[float]]] = None) -> SelectGroupsResult:
584
+ if not leaf_path:
585
+ raise ValueError('leaf_path must be provided')
586
+ path = normalize_path(leaf_path)
587
+ manifest = self.manifest()
588
+ leaf = manifest.data_schema.get_field(path)
589
+ # Find the inner-most leaf in case this field is repeated.
590
+ while leaf.repeated_field:
591
+ leaf = leaf.repeated_field
592
+ path = (*path, PATH_WILDCARD)
593
+
594
+ if not leaf.dtype:
595
+ raise ValueError(f'Leaf "{path}" not found in dataset')
596
+
597
+ inner_val = 'inner_val'
598
+ outer_select = inner_val
599
+ # Normalize the bins to be `list[Bin]`.
600
+ named_bins = _normalize_bins(bins or leaf.bins)
601
+ stats = self.stats(leaf_path)
602
+
603
+ leaf_is_float = is_float(leaf.dtype)
604
+ leaf_is_integer = is_integer(leaf.dtype)
605
+ if not leaf.categorical and (leaf_is_float or leaf_is_integer):
606
+ if named_bins is None:
607
+ # Auto-bin.
608
+ named_bins = _auto_bins(stats, NUM_AUTO_BINS)
609
+
610
+ sql_bounds = []
611
+ for label, start, end in named_bins:
612
+ if start is None:
613
+ start = cast(float, "'-Infinity'")
614
+ if end is None:
615
+ end = cast(float, "'Infinity'")
616
+ sql_bounds.append(f"('{label}', {start}, {end})")
617
+
618
+ bin_index_col = 'col0'
619
+ bin_min_col = 'col1'
620
+ bin_max_col = 'col2'
621
+ is_nan_filter = f'NOT isnan({inner_val}) AND' if leaf_is_float else ''
622
+
623
+ # We cast the field to `double` so binning works for both `float` and `int` fields.
624
+ outer_select = f"""(
625
+ SELECT {bin_index_col} FROM (
626
+ VALUES {', '.join(sql_bounds)}
627
+ ) WHERE {is_nan_filter}
628
+ {inner_val}::DOUBLE >= {bin_min_col} AND {inner_val}::DOUBLE < {bin_max_col}
629
+ )"""
630
+ else:
631
+ if stats.approx_count_distinct >= dataset.TOO_MANY_DISTINCT:
632
+ return SelectGroupsResult(too_many_distinct=True, counts=[], bins=named_bins)
633
+
634
+ count_column = 'count'
635
+ value_column = 'value'
636
+
637
+ limit_query = f'LIMIT {limit}' if limit else ''
638
+ duckdb_path = self._leaf_path_to_duckdb_path(path, manifest.data_schema)
639
+ inner_select = _select_sql(
640
+ duckdb_path, flatten=True, unnest=True, span_from=self._get_span_from(path, manifest))
641
+
642
+ filters, _ = self._normalize_filters(filters, col_aliases={}, udf_aliases={}, manifest=manifest)
643
+ filter_queries = self._create_where(manifest, filters, searches=[])
644
+
645
+ where_query = ''
646
+ if filter_queries:
647
+ where_query = f"WHERE {' AND '.join(filter_queries)}"
648
+
649
+ query = f"""
650
+ SELECT {outer_select} AS {value_column}, COUNT() AS {count_column}
651
+ FROM (SELECT {inner_select} AS {inner_val} FROM t {where_query})
652
+ GROUP BY {value_column}
653
+ ORDER BY {sort_by} {sort_order}
654
+ {limit_query}
655
+ """
656
+ df = self._query_df(query)
657
+ counts = list(df.itertuples(index=False, name=None))
658
+ if is_temporal(leaf.dtype):
659
+ # Replace any NaT with None and convert pd.Timestamp values to native datetime objects.
660
+ counts = [(None if pd.isnull(val) else val.to_pydatetime(), count) for val, count in counts]
661
+ return SelectGroupsResult(too_many_distinct=False, counts=counts, bins=named_bins)
662
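# Usage sketch (editorial comment, not part of this file): computing a histogram over a
# leaf path with explicit bins; the dataset name and bin edges are illustrative.
#
#   ds = DatasetDuckDB('local', 'movies')
#   groups = ds.select_groups('rating', bins=[0.0, 2.5, 5.0, 7.5, 10.0])
#   for value, count in groups.counts:
#     print(value, count)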
+
663
+ def _topk_udf_to_sort_by(
664
+ self,
665
+ udf_columns: list[Column],
666
+ sort_by: list[PathTuple],
667
+ limit: Optional[int],
668
+ sort_order: Optional[SortOrder],
669
+ ) -> Optional[Column]:
670
+ if (sort_order != SortOrder.DESC) or (not limit) or (not sort_by):
671
+ return None
672
+ if len(sort_by) < 1:
673
+ return None
674
+ primary_sort_by = sort_by[0]
675
+ udf_cols_to_sort_by = [
676
+ udf_col for udf_col in udf_columns if udf_col.alias == primary_sort_by[0] or
677
+ _path_contains(_col_destination_path(udf_col), primary_sort_by)
678
+ ]
679
+ if not udf_cols_to_sort_by:
680
+ return None
681
+ udf_col = udf_cols_to_sort_by[0]
682
+ if udf_col.signal_udf and not isinstance(udf_col.signal_udf, VectorSignal):
683
+ return None
684
+ return udf_col
685
+
686
+ def _normalize_columns(self, columns: Optional[Sequence[ColumnId]],
687
+ schema: Schema) -> list[Column]:
688
+ """Normalizes the columns to a list of `Column` objects."""
689
+ cols = [column_from_identifier(col) for col in columns or []]
690
+ star_in_cols = any(col.path == ('*',) for col in cols)
691
+ if not cols or star_in_cols:
692
+ # Select all columns.
693
+ cols.extend([Column((name,)) for name in schema.fields.keys()])
694
+ if star_in_cols:
695
+ cols = [col for col in cols if col.path != ('*',)]
696
+ return cols
697
+
698
+ def _merge_sorts(self, search_udfs: list[DuckDBSearchUDF], sort_by: Optional[Sequence[Path]],
699
+ sort_order: Optional[SortOrder]) -> list[SortResult]:
700
+ # True when the user has explicitly sorted by the alias of a search UDF (e.g. in ASC order).
701
+ is_explicit_search_sort = False
702
+ for sort_by_path in sort_by or []:
703
+ for search_udf in search_udfs:
704
+ if column_paths_match(sort_by_path, search_udf.output_path):
705
+ is_explicit_search_sort = True
706
+ break
707
+
708
+ sort_results: list[SortResult] = []
709
+ if sort_by and not is_explicit_search_sort:
710
+ if not sort_order:
711
+ raise ValueError('`sort_order` is required when `sort_by` is specified.')
712
+ # If the user has explicitly set a sort by, and it's not a search UDF alias, override.
713
+ sort_results = [
714
+ SortResult(path=normalize_path(sort_by), order=sort_order) for sort_by in sort_by if sort_by
715
+ ]
716
+ else:
717
+ search_udfs_with_sort = [search_udf for search_udf in search_udfs if search_udf.sort]
718
+ if search_udfs_with_sort:
719
+ # Override the sort by the last search sort order when the user hasn't provided an
720
+ # explicit sort order.
721
+ last_search_udf = search_udfs_with_sort[-1]
722
+ assert last_search_udf.sort, 'Expected search UDFs with sort to have a sort.'
723
+ udf_sort_path, udf_sort_order = last_search_udf.sort
724
+ sort_results = [
725
+ SortResult(
726
+ path=udf_sort_path,
727
+ order=sort_order or udf_sort_order,
728
+ search_index=len(search_udfs_with_sort) - 1)
729
+ ]
730
+
731
+ return sort_results
732
+
733
+ @override
734
+ def select_rows(self,
735
+ columns: Optional[Sequence[ColumnId]] = None,
736
+ searches: Optional[Sequence[Search]] = None,
737
+ filters: Optional[Sequence[FilterLike]] = None,
738
+ sort_by: Optional[Sequence[Path]] = None,
739
+ sort_order: Optional[SortOrder] = SortOrder.DESC,
740
+ limit: Optional[int] = None,
741
+ offset: Optional[int] = 0,
742
+ task_step_id: Optional[TaskStepId] = None,
743
+ resolve_span: bool = False,
744
+ combine_columns: bool = False,
745
+ user: Optional[UserInfo] = None) -> SelectRowsResult:
746
+ manifest = self.manifest()
747
+ cols = self._normalize_columns(columns, manifest.data_schema)
748
+
749
+ # Always return the UUID column.
750
+ col_paths = [col.path for col in cols]
751
+ if (UUID_COLUMN,) not in col_paths:
752
+ cols.append(column_from_identifier(UUID_COLUMN))
753
+
754
+ schema = manifest.data_schema
755
+
756
+ if combine_columns:
757
+ schema = self.select_rows_schema(
758
+ columns, sort_by, sort_order, searches, combine_columns=True).data_schema
759
+
760
+ self._validate_columns(cols, manifest.data_schema, schema)
761
+ self._normalize_searches(searches, manifest)
762
+ search_udfs = self._search_udfs(searches, manifest)
763
+ cols.extend([search_udf.udf for search_udf in search_udfs])
764
+ udf_columns = [col for col in cols if col.signal_udf]
765
+
766
+ # Set extra information on any concept signals.
767
+ for udf_col in udf_columns:
768
+ if isinstance(udf_col.signal_udf, (ConceptScoreSignal, ConceptLabelsSignal)):
769
+ # Concepts are access controlled, so we tell the signal about the user.
770
+ udf_col.signal_udf.set_user(user)
771
+
772
+ # Decide on the exact sorting order.
773
+ sort_results = self._merge_sorts(search_udfs, sort_by, sort_order)
774
+ sort_by = cast(list[PathTuple],
775
+ [(sort.alias,) if sort.alias else sort.path for sort in sort_results])
776
+ # Choose the first sort order as we only support a single sort order for now.
777
+ sort_order = sort_results[0].order if sort_results else None
778
+
779
+ col_aliases: dict[str, PathTuple] = {col.alias: col.path for col in cols if col.alias}
780
+ udf_aliases: dict[str, PathTuple] = {
781
+ col.alias: col.path for col in cols if col.signal_udf and col.alias
782
+ }
783
+ path_to_udf_col_name: dict[PathTuple, str] = {}
784
+ for col in cols:
785
+ if col.signal_udf:
786
+ alias = col.alias or _unique_alias(col)
787
+ dest_path = _col_destination_path(col)
788
+ path_to_udf_col_name[dest_path] = alias
789
+
790
+ # Filtering and searching.
791
+ where_query = ''
792
+ filters, udf_filters = self._normalize_filters(filters, col_aliases, udf_aliases, manifest)
793
+ filter_queries = self._create_where(manifest, filters, searches)
794
+ if filter_queries:
795
+ where_query = f"WHERE {' AND '.join(filter_queries)}"
796
+
797
+ total_num_rows = manifest.num_items
798
+ con = self.con.cursor()
799
+
800
+ topk_udf_col = self._topk_udf_to_sort_by(udf_columns, sort_by, limit, sort_order)
801
+ if topk_udf_col:
802
+ path_keys: Optional[list[PathKey]] = None
803
+ if where_query:
804
+ # If there are filters, we need to send UUIDs to the top k query.
805
+ df = con.execute(f'SELECT {UUID_COLUMN} FROM t {where_query}').df()
806
+ total_num_rows = len(df)
807
+ # Convert UUIDs to path keys.
808
+ path_keys = [(uuid,) for uuid in df[UUID_COLUMN]]
809
+
810
+ if path_keys is not None and len(path_keys) == 0:
811
+ where_query = 'WHERE false'
812
+ else:
813
+ topk_signal = cast(VectorSignal, topk_udf_col.signal_udf)
814
+ # The input is an embedding.
815
+ vector_index = self._get_vector_db_index(topk_signal.embedding, topk_udf_col.path)
816
+ k = (limit or 0) + (offset or 0)
817
+ with DebugTimer(f'Compute topk on "{topk_udf_col.path}" using embedding '
818
+ f'"{topk_signal.embedding}" with vector store "{self.vector_store}"'):
819
+ topk = topk_signal.vector_compute_topk(k, vector_index, path_keys)
820
+ topk_uuids = list(dict.fromkeys([cast(str, uuid) for (uuid, *_), _ in topk]))
821
+ # Update the offset to account for the number of unique UUIDs.
822
+ offset = len(dict.fromkeys([cast(str, uuid) for (uuid, *_), _ in topk[:offset]]))
823
+
824
+ # Ignore all the other filters and filter DuckDB results only by the top k UUIDs.
825
+ uuid_filter = Filter(path=(UUID_COLUMN,), op=ListOp.IN, value=topk_uuids)
826
+ filter_query = self._create_where(manifest, [uuid_filter])[0]
827
+ where_query = f'WHERE {filter_query}'
828
+
829
+ # Map a final column name to a list of temporary namespaced column names that need to be merged.
830
+ columns_to_merge: dict[str, dict[str, Column]] = {}
831
+ temp_column_to_offset_column: dict[str, tuple[str, Field]] = {}
832
+ select_queries: list[str] = []
833
+
834
+ for column in cols:
835
+ path = column.path
836
+ # If the signal is vector-based, we don't need to select the actual data, just the uuids
837
+ # plus an arbitrarily nested array of `None`s.
838
+ empty = bool(column.signal_udf and schema.get_field(path).dtype == DataType.EMBEDDING)
839
+
840
+ select_sqls: list[str] = []
841
+ final_col_name = column.alias or _unique_alias(column)
842
+ if final_col_name not in columns_to_merge:
843
+ columns_to_merge[final_col_name] = {}
844
+
845
+ duckdb_paths = self._column_to_duckdb_paths(column, schema, combine_columns)
846
+ span_from = self._get_span_from(path, manifest) if resolve_span or column.signal_udf else None
847
+
848
+ for parquet_id, duckdb_path in duckdb_paths:
849
+ sql = _select_sql(
850
+ duckdb_path, flatten=False, unnest=False, empty=empty, span_from=span_from)
851
+ temp_column_name = (
852
+ final_col_name if len(duckdb_paths) == 1 else f'{final_col_name}/{parquet_id}')
853
+ select_sqls.append(f'{sql} AS {_escape_string_literal(temp_column_name)}')
854
+ columns_to_merge[final_col_name][temp_column_name] = column
855
+
856
+ if column.signal_udf and span_from and _schema_has_spans(column.signal_udf.fields()):
857
+ sql = _select_sql(duckdb_path, flatten=False, unnest=False, empty=empty, span_from=None)
858
+ temp_offset_column_name = f'{temp_column_name}/offset'
859
+ temp_offset_column_name = temp_offset_column_name.replace("'", "\\'")
860
+ select_sqls.append(f'{sql} AS {_escape_string_literal(temp_offset_column_name)}')
861
+ temp_column_to_offset_column[temp_column_name] = (temp_offset_column_name,
862
+ column.signal_udf.fields())
863
+
864
+ # `select_sqls` can be empty if this column points to a path that will be created by a UDF.
865
+ if select_sqls:
866
+ select_queries.append(', '.join(select_sqls))
867
+
868
+ sort_sql_before_udf: list[str] = []
869
+ sort_sql_after_udf: list[str] = []
870
+
871
+ for path in sort_by:
872
+ # We only allow sorting by nodes with a value.
873
+ sort_path = path
874
+ first_subpath = str(path[0])
875
+ rest_of_path = path[1:]
876
+ signal_alias = '.'.join(map(str, path))
877
+
878
+ udf_path = _path_to_udf_duckdb_path(path, path_to_udf_col_name)
879
+ if not udf_path:
880
+ # Re-route the path if it starts with an alias by pointing it to the actual path.
881
+ if first_subpath in col_aliases:
882
+ path = (*col_aliases[first_subpath], *rest_of_path)
883
+ self._validate_sort_path(path, schema)
884
+ path = self._leaf_path_to_duckdb_path(path, schema)
885
+ else:
886
+ path = udf_path
887
+
888
+ sort_sql = _select_sql(path, flatten=True, unnest=False)
889
+ has_repeated_field = any(subpath == PATH_WILDCARD for subpath in path)
890
+ if has_repeated_field:
891
+ sort_sql = (f'list_min({sort_sql})'
892
+ if sort_order == SortOrder.ASC else f'list_max({sort_sql})')
893
+
894
+ # Separate sort columns into two groups: those that need to be sorted before and after UDFs.
895
+ if udf_path:
896
+ sort_sql_after_udf.append(sort_sql)
897
+ else:
898
+ sort_sql_before_udf.append(sort_sql)
899
+
900
+ order_query = ''
901
+ if sort_sql_before_udf:
902
+ order_query = (f'ORDER BY {", ".join(sort_sql_before_udf)} '
903
+ f'{cast(SortOrder, sort_order).value}')
904
+
905
+ limit_query = ''
906
+ if limit:
907
+ if topk_udf_col:
908
+ limit_query = f'LIMIT {limit + (offset or 0)}'
909
+ elif sort_sql_after_udf:
910
+ limit_query = ''
911
+ else:
912
+ limit_query = f'LIMIT {limit} OFFSET {offset or 0}'
913
+
914
+ if not topk_udf_col and where_query:
915
+ total_num_rows = cast(tuple,
916
+ con.execute(f'SELECT COUNT(*) FROM t {where_query}').fetchone())[0]
917
+
918
+ # Fetch the data from DuckDB.
919
+ df = con.execute(f"""
920
+ SELECT {', '.join(select_queries)} FROM t
921
+ {where_query}
922
+ {order_query}
923
+ {limit_query}
924
+ """).df()
925
+ df = _replace_nan_with_none(df)
926
+
927
+ # Run UDFs on the transformed columns.
928
+ for udf_col in udf_columns:
929
+ signal = cast(Signal, udf_col.signal_udf)
930
+ signal_alias = udf_col.alias or _unique_alias(udf_col)
931
+ temp_signal_cols = columns_to_merge[signal_alias]
932
+ if len(temp_signal_cols) != 1:
933
+ raise ValueError(
934
+ f'Unable to compute signal {signal.name}. Signal UDFs only operate on leafs, but got '
935
+ f'{len(temp_signal_cols)} underlying columns that contain data related to {udf_col.path}.'
936
+ )
937
+ signal_column = list(temp_signal_cols.keys())[0]
938
+ input = df[signal_column]
939
+
940
+ with DebugTimer(f'Computing signal "{signal.name}"'):
941
+ signal.setup()
942
+
943
+ if isinstance(signal, VectorSignal):
944
+ embedding_signal = signal
945
+ vector_store = self._get_vector_db_index(embedding_signal.embedding, udf_col.path)
946
+ flat_keys = list(flatten_keys(df[UUID_COLUMN], input))
947
+ signal_out = sparse_to_dense_compute(
948
+ iter(flat_keys), lambda keys: embedding_signal.vector_compute(keys, vector_store))
949
+ # Add progress.
950
+ if task_step_id is not None:
951
+ signal_out = progress(
952
+ signal_out,
953
+ task_step_id=task_step_id,
954
+ estimated_len=len(flat_keys),
955
+ step_description=f'Computing {signal.key()}')
956
+ df[signal_column] = deep_unflatten(signal_out, input)
957
+ else:
958
+ num_rich_data = count_primitives(input)
959
+ flat_input = cast(Iterator[Optional[RichData]], deep_flatten(input))
960
+ signal_out = sparse_to_dense_compute(
961
+ flat_input, lambda x: signal.compute(cast(Iterable[RichData], x)))
962
+ # Add progress.
963
+ if task_step_id is not None:
964
+ signal_out = progress(
965
+ signal_out,
966
+ task_step_id=task_step_id,
967
+ estimated_len=num_rich_data,
968
+ step_description=f'Computing {signal.key()}')
969
+ signal_out_list = list(signal_out)
970
+ if signal_column in temp_column_to_offset_column:
971
+ offset_column_name, field = temp_column_to_offset_column[signal_column]
972
+ nested_spans: Iterable[Item] = df[offset_column_name]
973
+ flat_spans = deep_flatten(nested_spans)
974
+ for span, item in zip(flat_spans, signal_out_list):
975
+ _offset_any_span(cast(int, span[VALUE_KEY][TEXT_SPAN_START_FEATURE]), item, field)
976
+
977
+ if len(signal_out_list) != num_rich_data:
978
+ raise ValueError(
979
+ f'The signal generated {len(signal_out_list)} values but the input data had '
980
+ f"{num_rich_data} values. This means the signal either didn't generate a "
981
+ '"None" for a sparse output, or generated too many items.')
982
+
983
+ df[signal_column] = deep_unflatten(signal_out_list, input)
984
+
985
+ signal.teardown()
986
+ if not df.empty and (udf_filters or sort_sql_after_udf):
987
+ # Re-upload the udf outputs to duckdb so we can filter/sort on them.
988
+ rel = con.from_df(df)
989
+
990
+ if udf_filters:
991
+ udf_filter_queries = self._create_where(manifest, udf_filters)
992
+ if udf_filter_queries:
993
+ rel = rel.filter(' AND '.join(udf_filter_queries))
994
+ total_num_rows = cast(tuple, rel.count('*').fetchone())[0]
995
+
996
+ if sort_sql_after_udf:
997
+ if not sort_order:
998
+ raise ValueError('`sort_order` is required when `sort_by` is specified.')
999
+ rel = rel.order(f'{", ".join(sort_sql_after_udf)} {sort_order.value}')
1000
+
1001
+ if limit:
1002
+ rel = rel.limit(limit, offset or 0)
1003
+
1004
+ df = _replace_nan_with_none(rel.df())
1005
+
1006
+ if combine_columns:
1007
+ all_columns: dict[str, Column] = {}
1008
+ for col_dict in columns_to_merge.values():
1009
+ all_columns.update(col_dict)
1010
+ columns_to_merge = {'*': all_columns}
1011
+
1012
+ for offset_column, _ in temp_column_to_offset_column.values():
1013
+ del df[offset_column]
1014
+
1015
+ for final_col_name, temp_columns in columns_to_merge.items():
1016
+ for temp_col_name, column in temp_columns.items():
1017
+ if combine_columns:
1018
+ dest_path = _col_destination_path(column)
1019
+ spec = _split_path_into_subpaths_of_lists(dest_path)
1020
+ df[temp_col_name] = wrap_in_dicts(df[temp_col_name], spec)
1021
+
1022
+ # If the temp col name is the same as the final name, we can skip merging. This happens when
1023
+ # we select a source leaf column.
1024
+ if temp_col_name == final_col_name:
1025
+ continue
1026
+
1027
+ if final_col_name not in df:
1028
+ df[final_col_name] = df[temp_col_name]
1029
+ else:
1030
+ df[final_col_name] = merge_series(df[final_col_name], df[temp_col_name])
1031
+ del df[temp_col_name]
1032
+
1033
+ con.close()
1034
+
1035
+ if combine_columns:
1036
+ # Since we aliased every column to `*`, the object will have only '*' as the key. We need to
1037
+ # elevate all the columns under '*'.
1038
+ df = pd.DataFrame.from_records(df['*'])
1039
+
1040
+ return SelectRowsResult(df, total_num_rows)
1041
+
1042
+ @override
1043
+ def select_rows_schema(self,
1044
+ columns: Optional[Sequence[ColumnId]] = None,
1045
+ sort_by: Optional[Sequence[Path]] = None,
1046
+ sort_order: Optional[SortOrder] = None,
1047
+ searches: Optional[Sequence[Search]] = None,
1048
+ combine_columns: bool = False) -> SelectRowsSchemaResult:
1049
+ """Returns the schema of the result of `select_rows` above with the same arguments."""
1050
+ if not combine_columns:
1051
+ raise NotImplementedError(
1052
+ 'select_rows_schema with combine_columns=False is not yet supported.')
1053
+ manifest = self.manifest()
1054
+ cols = self._normalize_columns(columns, manifest.data_schema)
1055
+
1056
+ # Always return the UUID column.
1057
+ col_paths = [col.path for col in cols]
1058
+ if (UUID_COLUMN,) not in col_paths:
1059
+ cols.append(column_from_identifier(UUID_COLUMN))
1060
+
1061
+ self._normalize_searches(searches, manifest)
1062
+ search_udfs = self._search_udfs(searches, manifest)
1063
+ cols.extend([search_udf.udf for search_udf in search_udfs])
1064
+
1065
+ udfs: list[SelectRowsSchemaUDF] = []
1066
+ col_schemas: list[Schema] = []
1067
+ for col in cols:
1068
+ dest_path = _col_destination_path(col)
1069
+ if col.signal_udf:
1070
+ udfs.append(SelectRowsSchemaUDF(path=dest_path, alias=col.alias))
1071
+ field = col.signal_udf.fields()
1072
+ field.signal = col.signal_udf.dict()
1073
+ elif manifest.data_schema.has_field(dest_path):
1074
+ field = manifest.data_schema.get_field(dest_path)
1075
+ else:
1076
+ # This column might refer to an output of a udf. We postpone validation to later.
1077
+ continue
1078
+ col_schemas.append(_make_schema_from_path(dest_path, field))
1079
+
1080
+ sort_results = self._merge_sorts(search_udfs, sort_by, sort_order)
1081
+
1082
+ search_results = [
1083
+ SearchResultInfo(search_path=search_udf.search_path, result_path=search_udf.output_path)
1084
+ for search_udf in search_udfs
1085
+ ]
1086
+
1087
+ new_schema = merge_schemas(col_schemas)
1088
+
1089
+ # Now that we have the new schema, we can validate all the column selections.
1090
+ self._validate_columns(cols, manifest.data_schema, new_schema)
1091
+
1092
+ return SelectRowsSchemaResult(
1093
+ data_schema=new_schema, udfs=udfs, search_results=search_results, sorts=sort_results or None)
1094
+
1095
+ @override
1096
+ def media(self, item_id: str, leaf_path: Path) -> MediaResult:
1097
+ raise NotImplementedError('Media is not yet supported for the DuckDB implementation.')
1098
+
1099
+ def _get_span_from(self, path: PathTuple, manifest: DatasetManifest) -> Optional[PathTuple]:
1100
+ leafs = manifest.data_schema.leafs
1101
+ # Remove the value key so we can check the dtype from leafs.
1102
+ span_path = path[:-1] if path[-1] == VALUE_KEY else path
1103
+ is_span = (span_path in leafs and leafs[span_path].dtype == DataType.STRING_SPAN)
1104
+ return _derived_from_path(path, manifest.data_schema) if is_span else None
1105
+
1106
+ def _leaf_path_to_duckdb_path(self, leaf_path: PathTuple, schema: Schema) -> PathTuple:
1107
+ ((_, duckdb_path),) = self._column_to_duckdb_paths(
1108
+ Column(leaf_path), schema, combine_columns=False, select_leaf=True)
1109
+ return duckdb_path
1110
+
1111
+ def _column_to_duckdb_paths(self,
1112
+ column: Column,
1113
+ schema: Schema,
1114
+ combine_columns: bool,
1115
+ select_leaf: bool = False) -> list[tuple[str, PathTuple]]:
1116
+ path = column.path
1117
+ parquet_manifests: list[Union[SourceManifest, SignalManifest]] = [
1118
+ self._source_manifest, *self._signal_manifests
1119
+ ]
1120
+ duckdb_paths: list[tuple[str, PathTuple]] = []
1121
+ source_has_path = False
1122
+
1123
+ select_leaf = select_leaf or column.signal_udf is not None
1124
+
1125
+ for m in parquet_manifests:
1126
+ if not m.files:
1127
+ continue
1128
+ # Skip this parquet file if it doesn't contain the path.
1129
+ if not schema_contains_path(m.data_schema, path):
1130
+ continue
1131
+
1132
+ if isinstance(m, SourceManifest):
1133
+ source_has_path = True
1134
+
1135
+ if isinstance(m, SignalManifest) and source_has_path and not combine_columns:
1136
+ # Skip this signal if the source already has the path and we are not combining columns.
1137
+ continue
1138
+
1139
+ # Skip this parquet file if the path doesn't have a dtype.
1140
+ if select_leaf and not m.data_schema.get_field(path).dtype:
1141
+ continue
1142
+
1143
+ if isinstance(m, SignalManifest) and path == (UUID_COLUMN,):
1144
+ # Do not select UUID from the signal because it's already in the source.
1145
+ continue
1146
+
1147
+ duckdb_path = path
1148
+ parquet_id = 'source'
1149
+
1150
+ if isinstance(m, SignalManifest):
1151
+ duckdb_path = (m.parquet_id, *path[1:])
1152
+ parquet_id = m.parquet_id
1153
+
1154
+ duckdb_paths.append((parquet_id, duckdb_path))
1155
+
1156
+ if not duckdb_paths:
1157
+ # This path is probably a result of a udf. Make sure the result schema contains it.
1158
+ if not schema.has_field(path):
1159
+ raise ValueError(f'Invalid path "{path}": No manifest contains path. Valid paths: '
1160
+ f'{list(schema.leafs.keys())}')
1161
+
1162
+ return duckdb_paths
1163
+
1164
+ def _normalize_filters(self, filter_likes: Optional[Sequence[FilterLike]],
1165
+ col_aliases: dict[str, PathTuple], udf_aliases: dict[str, PathTuple],
1166
+ manifest: DatasetManifest) -> tuple[list[Filter], list[Filter]]:
1167
+ """Normalize `FilterLike` to `Filter` and split into filters on source and filters on UDFs."""
1168
+ filter_likes = filter_likes or []
1169
+ filters: list[Filter] = []
1170
+ udf_filters: list[Filter] = []
1171
+
1172
+ for filter in filter_likes:
1173
+ # Normalize `FilterLike` to `Filter`.
1174
+ if not isinstance(filter, Filter):
1175
+ if len(filter) == 3:
1176
+ path, op, value = filter # type: ignore
1177
+ elif len(filter) == 2:
1178
+ path, op = filter # type: ignore
1179
+ value = None
1180
+ else:
1181
+ raise ValueError(f'Invalid filter: {filter}. Must be a tuple with 2 or 3 elements.')
1182
+ filter = Filter(path=normalize_path(path), op=op, value=value)
1183
+
1184
+ if str(filter.path[0]) in udf_aliases:
1185
+ udf_filters.append(filter)
1186
+ else:
1187
+ filters.append(filter)
1188
+
1189
+ self._validate_filters(filters, col_aliases, manifest)
1190
+ return filters, udf_filters
1191
+
1192
+ def _normalize_searches(self, searches: Optional[Sequence[Search]],
1193
+ manifest: DatasetManifest) -> None:
1194
+ """Validate searches."""
1195
+ if not searches:
1196
+ return
1197
+
1198
+ for search in searches:
1199
+ search.path = normalize_path(search.path)
1200
+ field = manifest.data_schema.get_field(search.path)
1201
+ if field.dtype != DataType.STRING:
1202
+ raise ValueError(f'Invalid search path: {search.path}. '
1203
+ f'Must be a string field, got dtype {field.dtype}')
1204
+
1205
+ def _search_udfs(self, searches: Optional[Sequence[Search]],
1206
+ manifest: DatasetManifest) -> list[DuckDBSearchUDF]:
1207
+ """Create a UDF for each search for finding the location of the text with spans."""
1208
+ searches = searches or []
1209
+ search_udfs: list[DuckDBSearchUDF] = []
1210
+ for search in searches:
1211
+ search_path = normalize_path(search.path)
1212
+ if search.query.type == 'keyword':
1213
+ udf = Column(path=search_path, signal_udf=SubstringSignal(query=search.query.search))
1214
+ search_udfs.append(
1215
+ DuckDBSearchUDF(
1216
+ udf=udf,
1217
+ search_path=search_path,
1218
+ output_path=(*_col_destination_path(udf), PATH_WILDCARD)))
1219
+ elif search.query.type == 'semantic' or search.query.type == 'concept':
1220
+ embedding = search.query.embedding
1221
+ if not embedding:
1222
+ raise ValueError(f'Please provide an embedding for {search.query.type} search. Got search: {search}')
1223
+
1224
+ try:
1225
+ manifest.data_schema.get_field((*search_path, embedding))
1226
+ except Exception as e:
1227
+ raise ValueError(
1228
+ f'Embedding {embedding} has not been computed. '
1229
+ f'Please compute the embedding index before issuing a {search.query.type} query.'
1230
+ ) from e
1231
+
1232
+ search_signal: Optional[Signal] = None
1233
+ if search.query.type == 'semantic':
1234
+ search_signal = SemanticSimilaritySignal(
1235
+ query=search.query.search, embedding=search.query.embedding)
1236
+ elif search.query.type == 'concept':
1237
+ search_signal = ConceptScoreSignal(
1238
+ namespace=search.query.concept_namespace,
1239
+ concept_name=search.query.concept_name,
1240
+ embedding=search.query.embedding)
1241
+
1242
+ # Add the label UDF.
1243
+ concept_labels_signal = ConceptLabelsSignal(
1244
+ namespace=search.query.concept_namespace, concept_name=search.query.concept_name)
1245
+ concept_labels_udf = Column(path=search_path, signal_udf=concept_labels_signal)
1246
+ search_udfs.append(
1247
+ DuckDBSearchUDF(
1248
+ udf=concept_labels_udf,
1249
+ search_path=search_path,
1250
+ output_path=_col_destination_path(concept_labels_udf),
1251
+ sort=None))
1252
+
1253
+ udf = Column(path=search_path, signal_udf=search_signal)
1254
+
1255
+ output_path = _col_destination_path(udf)
1256
+ search_udfs.append(
1257
+ DuckDBSearchUDF(
1258
+ udf=udf,
1259
+ search_path=search_path,
1260
+ output_path=_col_destination_path(udf),
1261
+ sort=((*output_path, PATH_WILDCARD, 'score'), SortOrder.DESC)))
1262
+ else:
1263
+ raise ValueError(f'Unknown search operator {search.query.type}.')
1264
+
1265
+ return search_udfs
1266
+
1267
+ def _create_where(self,
1268
+ manifest: DatasetManifest,
1269
+ filters: list[Filter],
1270
+ searches: Optional[Sequence[Search]] = []) -> list[str]:
1271
+ if not filters and not searches:
1272
+ return []
1273
+ searches = searches or []
1274
+ sql_filter_queries: list[str] = []
1275
+
1276
+ # Add search where queries.
1277
+ for search in searches:
1278
+ duckdb_path = self._leaf_path_to_duckdb_path(
1279
+ normalize_path(search.path), manifest.data_schema)
1280
+ select_str = _select_sql(duckdb_path, flatten=False, unnest=False)
1281
+ if search.query.type == 'keyword':
1282
+ sql_op = 'ILIKE'
1283
+ query_val = _escape_like_value(search.query.search)
1284
+ elif search.query.type == 'semantic' or search.query.type == 'concept':
1285
+ # Semantic search and concepts don't yet filter.
1286
+ continue
1287
+ else:
1288
+ raise ValueError(f'Unknown search operator {search.query.type}.')
1289
+
1290
+ filter_query = f'{select_str} {sql_op} {query_val}'
1291
+
1292
+ sql_filter_queries.append(filter_query)
1293
+
1294
+ # Add filter where queries.
1295
+ binary_ops = set(BinaryOp)
1296
+ unary_ops = set(UnaryOp)
1297
+ list_ops = set(ListOp)
1298
+ for f in filters:
1299
+ duckdb_path = self._leaf_path_to_duckdb_path(f.path, manifest.data_schema)
1300
+ select_str = _select_sql(
1301
+ duckdb_path, flatten=True, unnest=False, span_from=self._get_span_from(f.path, manifest))
1302
+ is_array = any(subpath == PATH_WILDCARD for subpath in f.path)
1303
+
1304
+ nan_filter = ''
1305
+ field = manifest.data_schema.get_field(f.path)
1306
+ filter_nans = field.dtype and is_float(field.dtype)
1307
+
1308
+ if f.op in binary_ops:
1309
+ sql_op = BINARY_OP_TO_SQL[cast(BinaryOp, f.op)]
1310
+ filter_val = cast(FeatureValue, f.value)
1311
+ if isinstance(filter_val, str):
1312
+ filter_val = _escape_string_literal(filter_val)
1313
+ elif isinstance(filter_val, bytes):
1314
+ filter_val = _bytes_to_blob_literal(filter_val)
1315
+ else:
1316
+ filter_val = str(filter_val)
1317
+ if is_array:
1318
+ nan_filter = 'NOT isnan(x) AND' if filter_nans else ''
1319
+ filter_query = (f'len(list_filter({select_str}, '
1320
+ f'x -> {nan_filter} x {sql_op} {filter_val})) > 0')
1321
+ else:
1322
+ nan_filter = f'NOT isnan({select_str}) AND' if filter_nans else ''
1323
+ filter_query = f'{nan_filter} {select_str} {sql_op} {filter_val}'
1324
+ elif f.op in unary_ops:
1325
+ if f.op == UnaryOp.EXISTS:
1326
+ filter_query = f'len({select_str}) > 0' if is_array else f'{select_str} IS NOT NULL'
1327
+ else:
1328
+ raise ValueError(f'Unary op: {f.op} is not yet supported')
1329
+ elif f.op in list_ops:
1330
+ if f.op == ListOp.IN:
1331
+ filter_list_val = cast(FeatureListValue, f.value)
1332
+ if not isinstance(filter_list_val, list):
1333
+ raise ValueError('filter with array value can only use the IN comparison')
1334
+ wrapped_filter_val = [f"'{part}'" for part in filter_list_val]
1335
+ filter_val = f'({", ".join(wrapped_filter_val)})'
1336
+ filter_query = f'{select_str} IN {filter_val}'
1337
+ else:
1338
+ raise ValueError(f'List op: {f.op} is not yet supported')
1339
+ else:
1340
+ raise ValueError(f'Invalid filter op: {f.op}')
1341
+ sql_filter_queries.append(filter_query)
1342
+ return sql_filter_queries
1343
+
1344
+ def _execute(self, query: str) -> duckdb.DuckDBPyConnection:
1345
+ """Execute a query in duckdb."""
1346
+ # FastAPI is multi-threaded so we have to create a thread-specific connection cursor to allow
1347
+ # these queries to be thread-safe.
1348
+ local_con = self.con.cursor()
1349
+ if not env('DEBUG', False):
1350
+ return local_con.execute(query)
1351
+
1352
+ # Debug mode.
1353
+ log('Executing:')
1354
+ log(query)
1355
+ with DebugTimer('Query'):
1356
+ return local_con.execute(query)
1357
+
1358
+ def _query(self, query: str) -> list[tuple]:
1359
+ result = self._execute(query)
1360
+ rows = result.fetchall()
1361
+ result.close()
1362
+ return rows
1363
+
1364
+ def _query_df(self, query: str) -> pd.DataFrame:
1365
+ """Execute a query that returns a data frame."""
1366
+ result = self._execute(query)
1367
+ df = _replace_nan_with_none(result.df())
1368
+ result.close()
1369
+ return df
1370
+
1371
+ def _path_to_col(self, path: Path, quote_each_part: bool = True) -> str:
1372
+ """Convert a path to a column name."""
1373
+ if isinstance(path, str):
1374
+ path = (path,)
1375
+ return '.'.join([
1376
+ f'{_escape_col_name(path_comp)}' if quote_each_part else str(path_comp) for path_comp in path
1377
+ ])
1378
+
1379
+ @override
1380
+ def to_json(self, filepath: Union[str, pathlib.Path], jsonl: bool = True) -> None:
1381
+ self._execute(f"COPY t TO '{filepath}' (FORMAT JSON, ARRAY {'FALSE' if jsonl else 'TRUE'})")
1382
+ log(f'Dataset exported to {filepath}')
1383
+
1384
+ @override
1385
+ def to_pandas(self) -> pd.DataFrame:
1386
+ return self._query_df('SELECT * FROM t')
1387
+
1388
+ @override
1389
+ def to_csv(self, filepath: Union[str, pathlib.Path]) -> None:
1390
+ self._execute(f"COPY t TO '{filepath}' (FORMAT CSV, HEADER)")
1391
+ log(f'Dataset exported to {filepath}')
1392
+
1393
+ @override
1394
+ def to_parquet(self, filepath: Union[str, pathlib.Path]) -> None:
1395
+ self._execute(f"COPY t TO '{filepath}' (FORMAT PARQUET)")
1396
+ log(f'Dataset exported to {filepath}')
1397
+
1398
+
1399
+ def _escape_string_literal(string: str) -> str:
1400
+ string = string.replace("'", "''")
1401
+ return f"'{string}'"
1402
+
1403
+
1404
+ def _escape_col_name(col_name: str) -> str:
1405
+ col_name = col_name.replace('"', '""')
1406
+ return f'"{col_name}"'
1407
+
1408
+
1409
+ def _escape_like_value(value: str) -> str:
1410
+ value = value.replace('%', '\\%').replace('_', '\\_')
1411
+ return f"'%{value}%' ESCAPE '\\'"
1412
+
1413
+
1414
+ def _inner_select(sub_paths: list[PathTuple],
1415
+ inner_var: Optional[str] = None,
1416
+ empty: bool = False,
1417
+ span_from: Optional[PathTuple] = None) -> str:
1418
+ """Recursively generate the inner select statement for a list of sub paths."""
1419
+ current_sub_path = sub_paths[0]
1420
+ lambda_var = inner_var + 'x' if inner_var else 'x'
1421
+ if not inner_var:
1422
+ lambda_var = 'x'
1423
+ inner_var = _escape_col_name(current_sub_path[0])
1424
+ current_sub_path = current_sub_path[1:]
1425
+ # Select the path inside structs. E.g. x['a']['b']['c'] given current_sub_path = [a, b, c].
1426
+ path_key = inner_var + ''.join([f'[{_escape_string_literal(p)}]' for p in current_sub_path])
1427
+ if len(sub_paths) == 1:
1428
+ if span_from:
1429
+ derived_col = _select_sql(span_from, flatten=False, unnest=False)
1430
+ path_key = (f'{derived_col}[{path_key}.{VALUE_KEY}.{TEXT_SPAN_START_FEATURE}+1:'
1431
+ f'{path_key}.{VALUE_KEY}.{TEXT_SPAN_END_FEATURE}]')
1432
+ return 'NULL' if empty else path_key
1433
+ return (f'list_transform({path_key}, {lambda_var} -> '
1434
+ f'{_inner_select(sub_paths[1:], lambda_var, empty, span_from)})')
1435
+
1436
+
1437
+ def _split_path_into_subpaths_of_lists(leaf_path: PathTuple) -> list[PathTuple]:
1438
+ """Split a path into a subpath of lists.
1439
+
1440
+ E.g. [a, b, c, *, d, *, *] gets split into [[a, b, c], [d], [], []].
1441
+ """
1442
+ sub_paths: list[PathTuple] = []
1443
+ offset = 0
1444
+ while offset <= len(leaf_path):
1445
+ new_offset = leaf_path.index(PATH_WILDCARD,
1446
+ offset) if PATH_WILDCARD in leaf_path[offset:] else len(leaf_path)
1447
+ sub_path = leaf_path[offset:new_offset]
1448
+ sub_paths.append(sub_path)
1449
+ offset = new_offset + 1
1450
+ return sub_paths
1451
+
1452
+
1453
+ def _select_sql(path: PathTuple,
1454
+ flatten: bool,
1455
+ unnest: bool,
1456
+ empty: bool = False,
1457
+ span_from: Optional[PathTuple] = None) -> str:
1458
+ """Create a select column for a path.
1459
+
1460
+ Args:
1461
+ path: A path to a feature. E.g. ['a', 'b', 'c'].
1462
+ flatten: Whether to flatten the result.
1463
+ unnest: Whether to unnest the result.
1464
+ empty: Whether to return an empty list (used for embedding signals that don't need the data).
1465
+ span_from: The path this span is derived from. If specified, the span will be resolved
1466
+ to a substring of the original string.
1467
+ """
1468
+ sub_paths = _split_path_into_subpaths_of_lists(path)
1469
+ selection = _inner_select(sub_paths, None, empty, span_from)
1470
+ # We only flatten when the result is a nested list to avoid a segfault.
1471
+ is_result_nested_list = len(sub_paths) >= 3 # E.g. sub_paths = [[a, b, c], *, *].
1472
+ if flatten and is_result_nested_list:
1473
+ selection = f'flatten({selection})'
1474
+ # We only unnest when the result is a list. E.g. sub_paths = [[a, b, c], *].
1475
+ is_result_a_list = len(sub_paths) >= 2
1476
+ if unnest and is_result_a_list:
1477
+ selection = f'unnest({selection})'
1478
+ return selection
1479
+
1480
+
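As a quick illustration of the helpers above, here is a hedged sketch of the SQL fragment `_select_sql` would build for a path with one wildcard; the expected string is read off the code, not captured from a real run.

# Hypothetical path ('a', '*', 'b'): one list of structs with a 'b' field.
# _split_path_into_subpaths_of_lists(('a', '*', 'b')) -> [('a',), ('b',)]
sql = _select_sql(('a', '*', 'b'), flatten=True, unnest=True)
# Expected, by tracing _inner_select: unnest(list_transform("a", x -> x['b']))
# (no flatten() wrapper, since the result is only a single level of list)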
1481
+ def read_source_manifest(dataset_path: str) -> SourceManifest:
1482
+ """Read the manifest file."""
1483
+ with open_file(os.path.join(dataset_path, MANIFEST_FILENAME), 'r') as f:
1484
+ return SourceManifest.parse_raw(f.read())
1485
+
1486
+
1487
+ def _signal_dir(enriched_path: PathTuple) -> str:
1488
+ """Get the filename prefix for a signal parquet file."""
1489
+ path_without_wildcards = (p for p in enriched_path if p != PATH_WILDCARD)
1490
+ return os.path.join(*path_without_wildcards)
1491
+
1492
+
1493
+ def split_column_name(column: str, split_name: str) -> str:
1494
+ """Get the name of a split column."""
1495
+ return f'{column}.{split_name}'
1496
+
1497
+
1498
+ def split_parquet_prefix(column_name: str, splitter_name: str) -> str:
1499
+ """Get the filename prefix for a split parquet file."""
1500
+ return f'{column_name}.{splitter_name}'
1501
+
1502
+
1503
+ def _bytes_to_blob_literal(bytes: bytes) -> str:
1504
+ """Convert bytes to a blob literal."""
1505
+ escaped_hex = re.sub(r'(.{2})', r'\\x\1', bytes.hex())
1506
+ return f"'{escaped_hex}'::BLOB"
1507
+
1508
+
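A small sketch of the blob escaping above; the expected output is derived by reading the regex, not from a captured run.

_bytes_to_blob_literal(b'\x01\xff')
# hex() gives '01ff'; each byte pair is prefixed with a backslash-x, yielding
# '\x01\xff'::BLOB with literal backslashes, i.e. the Python string "'\\x01\\xff'::BLOB".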
1509
+ class SignalManifest(BaseModel):
1510
+ """The manifest that describes a signal computation including schema and parquet files."""
1511
+ # List of a parquet filepaths storing the data. The paths are relative to the manifest.
1512
+ files: list[str]
1513
+
1514
+ # An identifier for this parquet table. Will be used as the view name in SQL.
1515
+ parquet_id: str
1516
+
1517
+ data_schema: Schema
1518
+ signal: Signal
1519
+
1520
+ # The column path that this signal is derived from.
1521
+ enriched_path: PathTuple
1522
+
1523
+ # The name of the vector store. Present when the signal is an embedding.
1524
+ vector_store: Optional[str] = None
1525
+
1526
+ @validator('signal', pre=True)
1527
+ def parse_signal(cls, signal: dict) -> Signal:
1528
+ """Parse a signal to its specific subclass instance."""
1529
+ return resolve_signal(signal)
1530
+
1531
+
1532
+ def _merge_cells(dest_cell: Item, source_cell: Item) -> Item:
1533
+ if source_cell is None or isinstance(source_cell, float) and math.isnan(source_cell):
1534
+ # Nothing to merge here (missing value).
1535
+ return dest_cell
1536
+ if isinstance(dest_cell, dict):
1537
+ if isinstance(source_cell, list):
1538
+ raise ValueError(f'Failed to merge cells. Destination is a dict ({dest_cell!r}), '
1539
+ f'but source is a list ({source_cell!r}).')
1540
+ if isinstance(source_cell, dict):
1541
+ res = {**dest_cell}
1542
+ for key, value in source_cell.items():
1543
+ res[key] = (value if key not in dest_cell else _merge_cells(dest_cell[key], value))
1544
+ return res
1545
+ else:
1546
+ return {VALUE_KEY: source_cell, **dest_cell}
1547
+ elif isinstance(dest_cell, list):
1548
+ if not isinstance(source_cell, list):
1549
+ raise ValueError('Failed to merge cells. Destination is a list, but source is not.')
1550
+ return [
1551
+ _merge_cells(dest_subcell, source_subcell)
1552
+ for dest_subcell, source_subcell in zip(dest_cell, source_cell)
1553
+ ]
1554
+ else:
1555
+ # The destination is a primitive.
1556
+ if isinstance(source_cell, list):
1557
+ raise ValueError(f'Failed to merge cells. Destination is a primitive ({dest_cell!r}), '
1558
+ f'but source is a list ({source_cell!r}).')
1559
+ if isinstance(source_cell, dict):
1560
+ return {VALUE_KEY: dest_cell, **source_cell}
1561
+ else:
1562
+ # Primitives can be merged together if they are equal. This can happen if a user selects a
1563
+ # column that is the child of another.
1564
+ # NOTE: This can be removed if we fix https://github.com/lilacai/lilac/issues/166.
1565
+ if source_cell != dest_cell:
1566
+ raise ValueError(f'Cannot merge source "{source_cell!r}" into destination "{dest_cell!r}"')
1567
+ return dest_cell
1568
+
1569
+
1570
+ def merge_series(destination: pd.Series, source: pd.Series) -> list[Item]:
1571
+ """Merge two series of values recursively."""
1572
+ return _merge_cells(destination.tolist(), source.tolist())
1573
+
1574
+
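To make the merge rules above concrete, a minimal sketch with made-up field names; the expected value follows from the dict/primitive branches of `_merge_cells`.

import pandas as pd

dest = pd.Series([{'text': 'hello'}])       # cell coming from the source parquet
source = pd.Series([{'text': {'len': 5}}])  # cell coming from a (hypothetical) signal parquet
merge_series(dest, source)
# Expected: [{'text': {VALUE_KEY: 'hello', 'len': 5}}]
# The primitive 'hello' is lifted under VALUE_KEY so the signal metadata can sit next to it.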
1575
+ def _unique_alias(column: Column) -> str:
1576
+ """Get a unique alias for a selection column."""
1577
+ if column.signal_udf:
1578
+ return make_parquet_id(column.signal_udf, column.path)
1579
+ return '.'.join(map(str, column.path))
1580
+
1581
+
1582
+ def _path_contains(parent_path: PathTuple, child_path: PathTuple) -> bool:
1583
+ """Check if a path contains another path."""
1584
+ if len(parent_path) > len(child_path):
1585
+ return False
1586
+ return all(parent_path[i] == child_path[i] for i in range(len(parent_path)))
1587
+
1588
+
1589
+ def _path_to_udf_duckdb_path(path: PathTuple,
1590
+ path_to_udf_col_name: dict[PathTuple, str]) -> Optional[PathTuple]:
1591
+ first_subpath, *rest_of_path = path
1592
+ for parent_path, udf_col_name in path_to_udf_col_name.items():
1593
+ # If the user selected udf(document.*.text) as "udf" and wanted to sort by "udf.len", we need to
1594
+ # sort by "udf.*.len" where the "*" came from the fact that the udf was applied to a list of
1595
+ # "text" fields.
1596
+ wildcards = [x for x in parent_path if x == PATH_WILDCARD]
1597
+ if _path_contains(parent_path, path):
1598
+ return (udf_col_name, *wildcards, *path[len(parent_path):])
1599
+ elif first_subpath == udf_col_name:
1600
+ return (udf_col_name, *wildcards, *rest_of_path)
1601
+
1602
+ return None
1603
+
1604
+
1605
+ def _col_destination_path(column: Column, is_computed_signal: Optional[bool] = False) -> PathTuple:
1606
+ """Get the destination path where the output of this selection column will be stored."""
1607
+ source_path = column.path
1608
+
1609
+ if not column.signal_udf:
1610
+ return source_path
1611
+
1612
+ signal_key = column.signal_udf.key(is_computed_signal=is_computed_signal)
1613
+ # If we are enriching a value we should store the signal data in the value's parent.
1614
+ if source_path[-1] == VALUE_KEY:
1615
+ dest_path = (*source_path[:-1], signal_key)
1616
+ else:
1617
+ dest_path = (*source_path, signal_key)
1618
+
1619
+ return dest_path
1620
+
1621
+
1622
+ def _root_column(manifest: SignalManifest) -> str:
1623
+ """Returns the root column of a signal manifest."""
1624
+ field_keys = manifest.data_schema.fields.keys()
1625
+ if len(field_keys) != 2:
1626
+ raise ValueError('Expected exactly two fields in signal manifest, '
1627
+ f'the row UUID and the root field this signal is enriching. Got {field_keys}.')
1628
+ return next(filter(lambda field: field != UUID_COLUMN, manifest.data_schema.fields.keys()))
1629
+
1630
+
1631
+ def _derived_from_path(path: PathTuple, schema: Schema) -> PathTuple:
1632
+ # Find the closest parent of `path` that is a signal root.
1633
+ for i in reversed(range(len(path))):
1634
+ sub_path = path[:i]
1635
+ if schema.get_field(sub_path).signal is not None:
1636
+ # Skip the signal name at the end to get the source path that was enriched.
1637
+ return sub_path[:-1]
1638
+ raise ValueError(f'Cannot find the source path for the enriched path: {path}')
1639
+
1640
+
1641
+ def _make_schema_from_path(path: PathTuple, field: Field) -> Schema:
1642
+ """Returns a schema that contains only the given path."""
1643
+ for sub_path in reversed(path):
1644
+ if sub_path == PATH_WILDCARD:
1645
+ field = Field(repeated_field=field)
1646
+ else:
1647
+ field = Field(fields={sub_path: field})
1648
+ if not field.fields:
1649
+ raise ValueError(f'Invalid path: {path}. Must contain at least one field name.')
1650
+ return Schema(fields=field.fields)
1651
+
1652
+
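A brief sketch of how the wildcard handling above unrolls for a repeated text field (the field names are illustrative):

_make_schema_from_path(('doc', PATH_WILDCARD, 'text'), Field(dtype=DataType.STRING))
# Roughly equivalent to:
# Schema(fields={'doc': Field(repeated_field=Field(fields={'text': Field(dtype=DataType.STRING)}))})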
1653
+ def _replace_nan_with_none(df: pd.DataFrame) -> pd.DataFrame:
1654
+ """DuckDB returns np.nan for missing field in string column, replace with None for correctness."""
1655
+ # TODO(https://github.com/duckdb/duckdb/issues/4066): Remove this once duckdb fixes upstream.
1656
+ for col in df.columns:
1657
+ if is_object_dtype(df[col]):
1658
+ df[col].replace(np.nan, None, inplace=True)
1659
+ return df
1660
+
1661
+
1662
+ def _offset_any_span(offset: int, item: Item, schema: Field) -> None:
1663
+ """Offsets any spans inplace by the given parent offset."""
1664
+ if schema.dtype == DataType.STRING_SPAN:
1665
+ item = cast(dict, item)
1666
+ item[VALUE_KEY][TEXT_SPAN_START_FEATURE] += offset
1667
+ item[VALUE_KEY][TEXT_SPAN_END_FEATURE] += offset
1668
+ if schema.fields:
1669
+ item = cast(dict, item)
1670
+ for key, sub_schema in schema.fields.items():
1671
+ _offset_any_span(offset, item[key], sub_schema)
1672
+ if schema.repeated_field:
1673
+ item = cast(list, item)
1674
+ for sub_item in item:
1675
+ _offset_any_span(offset, sub_item, schema.repeated_field)
1676
+
1677
+
1678
+ def _schema_has_spans(field: Field) -> bool:
1679
+ if field.dtype and field.dtype == DataType.STRING_SPAN:
1680
+ return True
1681
+ if field.fields:
1682
+ children_have_spans = any(_schema_has_spans(sub_field) for sub_field in field.fields.values())
1683
+ if children_have_spans:
1684
+ return True
1685
+ if field.repeated_field:
1686
+ return _schema_has_spans(field.repeated_field)
1687
+ return False
1688
+
1689
+
1690
+ def _normalize_bins(bins: Optional[Union[Sequence[Bin], Sequence[float]]]) -> Optional[list[Bin]]:
1691
+ if bins is None:
1692
+ return None
1693
+ if not isinstance(bins[0], (float, int)):
1694
+ return cast(list[Bin], bins)
1695
+ named_bins: list[Bin] = []
1696
+ for i in range(len(bins) + 1):
1697
+ start = cast(float, bins[i - 1]) if i > 0 else None
1698
+ end = cast(float, bins[i]) if i < len(bins) else None
1699
+ named_bins.append((str(i), start, end))
1700
+ return named_bins
1701
+
1702
+
1703
+ def _auto_bins(stats: StatsResult, num_bins: int) -> list[Bin]:
1704
+ min_val = cast(float, stats.min_val)
1705
+ max_val = cast(float, stats.max_val)
1706
+ bin_width = (max_val - min_val) / num_bins
1707
+ bins: list[Bin] = []
1708
+ for i in range(num_bins):
1709
+ start = None if i == 0 else min_val + i * bin_width
1710
+ end = None if i == num_bins - 1 else min_val + (i + 1) * bin_width
1711
+ bins.append((str(i), start, end))
1712
+ return bins
1713
+
1714
+
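A short sketch of the two bin helpers above, with made-up numbers; the expected tuples are derived from the loops, not from a captured run.

_normalize_bins([0.0, 10.0])
# -> [('0', None, 0.0), ('1', 0.0, 10.0), ('2', 10.0, None)]
# Explicit edges become (label, start, end) tuples with open-ended outer bins.

# Assuming a StatsResult with min_val=0 and max_val=30:
_auto_bins(stats, num_bins=3)
# -> [('0', None, 10.0), ('1', 10.0, 20.0), ('2', 20.0, None)]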
1715
+ def _settings_filepath(namespace: str, dataset_name: str) -> str:
1716
+ return os.path.join(
1717
+ get_dataset_output_dir(data_path(), namespace, dataset_name), DATASET_SETTINGS_FILENAME)
lilac/data/dataset_test_utils.py ADDED
@@ -0,0 +1,127 @@
1
+ """Tests utils of for dataset_test."""
2
+ import os
3
+ import pathlib
4
+ from datetime import datetime
5
+ from typing import Optional, Type, cast
6
+
7
+ import numpy as np
8
+ from typing_extensions import Protocol
9
+
10
+ from ..embeddings.vector_store import VectorDBIndex
11
+ from ..schema import (
12
+ MANIFEST_FILENAME,
13
+ PARQUET_FILENAME_PREFIX,
14
+ VALUE_KEY,
15
+ DataType,
16
+ Field,
17
+ Item,
18
+ PathKey,
19
+ Schema,
20
+ SourceManifest,
21
+ )
22
+ from ..utils import get_dataset_output_dir, open_file
23
+ from .dataset import Dataset
24
+ from .dataset_utils import is_primitive, write_items_to_parquet
25
+
26
+ TEST_NAMESPACE = 'test_namespace'
27
+ TEST_DATASET_NAME = 'test_dataset'
28
+
29
+
30
+ def _infer_dtype(value: Item) -> DataType:
31
+ if isinstance(value, str):
32
+ return DataType.STRING
33
+ elif isinstance(value, bool):
34
+ return DataType.BOOLEAN
35
+ elif isinstance(value, bytes):
36
+ return DataType.BINARY
37
+ elif isinstance(value, float):
38
+ return DataType.FLOAT32
39
+ elif isinstance(value, int):
40
+ return DataType.INT32
41
+ elif isinstance(value, datetime):
42
+ return DataType.TIMESTAMP
43
+ else:
44
+ raise ValueError(f'Cannot infer dtype of primitive value: {value}')
45
+
46
+
47
+ def _infer_field(item: Item) -> Field:
48
+ """Infer the schema from the items."""
49
+ if isinstance(item, dict):
50
+ fields: dict[str, Field] = {}
51
+ for k, v in item.items():
52
+ fields[k] = _infer_field(cast(Item, v))
53
+ dtype = None
54
+ if VALUE_KEY in fields:
55
+ dtype = fields[VALUE_KEY].dtype
56
+ del fields[VALUE_KEY]
57
+ return Field(fields=fields, dtype=dtype)
58
+ elif is_primitive(item):
59
+ return Field(dtype=_infer_dtype(item))
60
+ elif isinstance(item, list):
61
+ return Field(repeated_field=_infer_field(item[0]))
62
+ else:
63
+ raise ValueError(f'Cannot infer schema of item: {item}')
64
+
65
+
66
+ def _infer_schema(items: list[Item]) -> Schema:
67
+ """Infer the schema from the items."""
68
+ schema = Schema(fields={})
69
+ for item in items:
70
+ field = _infer_field(item)
71
+ if not field.fields:
72
+ raise ValueError(f'Invalid schema of item. Expected an object, but got: {item}')
73
+ schema.fields = {**schema.fields, **field.fields}
74
+ return schema
75
+
76
+
77
+ class TestDataMaker(Protocol):
78
+ """A function that creates a test dataset."""
79
+
80
+ def __call__(self, items: list[Item], schema: Optional[Schema] = None) -> Dataset:
81
+ """Create a test dataset."""
82
+ ...
83
+
84
+
85
+ def make_dataset(dataset_cls: Type[Dataset],
86
+ tmp_path: pathlib.Path,
87
+ items: list[Item],
88
+ schema: Optional[Schema] = None) -> Dataset:
89
+ """Create a test dataset."""
90
+ schema = schema or _infer_schema(items)
91
+ _write_items(tmp_path, TEST_DATASET_NAME, items, schema)
92
+ return dataset_cls(TEST_NAMESPACE, TEST_DATASET_NAME)
93
+
94
+
95
+ def _write_items(tmpdir: pathlib.Path, dataset_name: str, items: list[Item],
96
+ schema: Schema) -> None:
97
+ """Write the items JSON to the dataset format: manifest.json and parquet files."""
98
+ source_dir = get_dataset_output_dir(str(tmpdir), TEST_NAMESPACE, dataset_name)
99
+ os.makedirs(source_dir)
100
+
101
+ simple_parquet_files, _ = write_items_to_parquet(
102
+ items, source_dir, schema, filename_prefix=PARQUET_FILENAME_PREFIX, shard_index=0, num_shards=1)
103
+ manifest = SourceManifest(files=[simple_parquet_files], data_schema=schema)
104
+ with open_file(os.path.join(source_dir, MANIFEST_FILENAME), 'w') as f:
105
+ f.write(manifest.json(indent=2, exclude_none=True))
106
+
107
+
108
+ def enriched_item(value: Optional[Item] = None, metadata: dict[str, Item] = {}) -> Item:
109
+ """Wrap a value in a dict with the value key."""
110
+ return {VALUE_KEY: value, **metadata}
111
+
112
+
113
+ def make_vector_index(vector_store: str, vector_dict: dict[PathKey,
114
+ list[list[float]]]) -> VectorDBIndex:
115
+ """Make a vector index from a dictionary of vector keys to vectors."""
116
+ embeddings: list[np.ndarray] = []
117
+ spans: list[tuple[PathKey, list[tuple[int, int]]]] = []
118
+ for path_key, vectors in vector_dict.items():
119
+ vector_spans: list[tuple[int, int]] = []
120
+ for i, vector in enumerate(vectors):
121
+ embeddings.append(np.array(vector))
122
+ vector_spans.append((0, 0))
123
+ spans.append((path_key, vector_spans))
124
+
125
+ vector_index = VectorDBIndex(vector_store)
126
+ vector_index.add(spans, np.array(embeddings))
127
+ return vector_index
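A hedged sketch of how these helpers might be combined in a pytest test; `DatasetDuckDB`, the `tmp_path` fixture and the 'numpy' vector store name are assumptions of this illustration, not taken from this diff.

# Hypothetical test body.
items = [{'text': enriched_item('hello', {'test_signal': {'len': 5}})}]
dataset = make_dataset(DatasetDuckDB, tmp_path, items)  # schema is inferred from the items

vector_index = make_vector_index('numpy', {('some_uuid', 0): [[1.0, 0.0], [0.0, 1.0]]})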
lilac/data/dataset_utils.py ADDED
@@ -0,0 +1,308 @@
1
+ """Utilities for working with datasets."""
2
+
3
+ import gc
4
+ import json
5
+ import math
6
+ import os
7
+ import pprint
8
+ import secrets
9
+ from collections.abc import Iterable
10
+ from typing import Any, Callable, Iterator, Optional, Sequence, TypeVar, Union, cast
11
+
12
+ import numpy as np
13
+ import pyarrow as pa
14
+
15
+ from ..batch_utils import deep_flatten
16
+ from ..embeddings.vector_store import VectorDBIndex
17
+ from ..env import env
18
+ from ..parquet_writer import ParquetWriter
19
+ from ..schema import (
20
+ EMBEDDING_KEY,
21
+ PATH_WILDCARD,
22
+ TEXT_SPAN_END_FEATURE,
23
+ TEXT_SPAN_START_FEATURE,
24
+ UUID_COLUMN,
25
+ VALUE_KEY,
26
+ Field,
27
+ Item,
28
+ PathKey,
29
+ PathTuple,
30
+ Schema,
31
+ VectorKey,
32
+ field,
33
+ schema,
34
+ schema_to_arrow_schema,
35
+ )
36
+ from ..signals.signal import Signal
37
+ from ..utils import is_primitive, log, open_file
38
+
39
+
40
+ def _replace_embeddings_with_none(input: Item) -> Item:
41
+ if isinstance(input, np.ndarray):
42
+ return None
43
+ if isinstance(input, dict):
44
+ return {k: _replace_embeddings_with_none(v) for k, v in input.items()}
45
+ if isinstance(input, list):
46
+ return [_replace_embeddings_with_none(v) for v in input]
47
+
48
+ return input
49
+
50
+
51
+ def replace_embeddings_with_none(input: Item) -> Item:
52
+ """Replaces all embeddings with None."""
53
+ return cast(Item, _replace_embeddings_with_none(input))
54
+
55
+
56
+ def count_primitives(input: Union[Iterable, Iterator]) -> int:
57
+ """Iterate through each element of the input, flattening each one, computing a count.
58
+
59
+ Sum the final set of counts. This is the important iterable not to exhaust.
60
+ """
61
+ return sum((len(list(deep_flatten(i))) for i in input))
62
+
63
+
64
+ def _wrap_value_in_dict(input: Union[object, dict], props: PathTuple) -> Union[object, dict]:
65
+ # If the signal produced no value, or nan, we should return None so the parquet value is sparse.
66
+ if isinstance(input, float) and math.isnan(input):
67
+ input = None
68
+ for prop in reversed(props):
69
+ input = {prop: input}
70
+ return input
71
+
72
+
73
+ def _wrap_in_dicts(input: Union[object, Iterable[object]],
74
+ spec: list[PathTuple]) -> Union[object, Iterable[object]]:
75
+ """Wraps an object or iterable in a dict according to the spec."""
76
+ props = spec[0] if spec else tuple()
77
+ if len(spec) == 1:
78
+ return _wrap_value_in_dict(input, props)
79
+ if input is None or isinstance(input, float) and math.isnan(input):
80
+ # Return empty dict for missing inputs.
81
+ return {}
82
+ res = [_wrap_in_dicts(elem, spec[1:]) for elem in cast(Iterable, input)]
83
+ return _wrap_value_in_dict(res, props)
84
+
85
+
86
+ def wrap_in_dicts(input: Iterable[object], spec: list[PathTuple]) -> Iterable[object]:
87
+ """Wraps an object or iterable in a dict according to the spec."""
88
+ return [_wrap_in_dicts(elem, spec) for elem in input]
89
+
90
+
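A minimal sketch of the wrapping behavior above, with illustrative paths; the expected values follow from `_wrap_value_in_dict`.

wrap_in_dicts([[1, 2]], spec=[('a',), ('b',)])
# -> [{'a': [{'b': 1}, {'b': 2}]}]

wrap_in_dicts([None], spec=[('a',), ('b',)])
# -> [{}]   (missing inputs collapse to an empty dict)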
91
+ def _merge_field_into(schema: Field, destination: Field) -> None:
92
+ if isinstance(schema, Field):
93
+ destination.signal = destination.signal or schema.signal
94
+ destination.dtype = destination.dtype or schema.dtype
95
+ if schema.fields:
96
+ destination.fields = destination.fields or {}
97
+ for field_name, subfield in schema.fields.items():
98
+ if field_name not in destination.fields:
99
+ destination.fields[field_name] = subfield.copy(deep=True)
100
+ else:
101
+ _merge_field_into(subfield, destination.fields[field_name])
102
+ elif schema.repeated_field:
103
+ if not destination.repeated_field:
104
+ raise ValueError('Failed to merge schemas. Origin schema is repeated, but destination is not')
105
+ _merge_field_into(schema.repeated_field, destination.repeated_field)
106
+ else:
107
+ if destination.dtype != schema.dtype:
108
+ raise ValueError(f'Failed to merge schemas. Origin schema has dtype {schema.dtype}, '
109
+ f'but destination has dtype {destination.dtype}')
110
+
111
+
112
+ def merge_schemas(schemas: Sequence[Union[Schema, Field]]) -> Schema:
113
+ """Merge a list of schemas."""
114
+ merged_schema = Schema(fields={})
115
+ for s in schemas:
116
+ _merge_field_into(cast(Field, s), cast(Field, merged_schema))
117
+ return merged_schema
118
+
119
+
120
+ def schema_contains_path(schema: Schema, path: PathTuple) -> bool:
121
+ """Check if a schema contains a path."""
122
+ current_field = cast(Field, schema)
123
+ for path_part in path:
124
+ if path_part == PATH_WILDCARD:
125
+ if current_field.repeated_field is None:
126
+ return False
127
+ current_field = current_field.repeated_field
128
+ else:
129
+ if current_field.fields is None or path_part not in current_field.fields:
130
+ return False
131
+ current_field = current_field.fields[str(path_part)]
132
+ return True
133
+
134
+
135
+ def create_signal_schema(signal: Signal, source_path: PathTuple, current_schema: Schema) -> Schema:
136
+ """Create a schema describing the enriched fields added an enrichment."""
137
+ leafs = current_schema.leafs
138
+ # Validate that the enrich fields are actually a valid leaf path.
139
+ if source_path not in leafs:
140
+ raise ValueError(f'"{source_path}" is not a valid leaf path. Leaf paths: {leafs.keys()}')
141
+
142
+ signal_schema = signal.fields()
143
+ signal_schema.signal = signal.dict()
144
+
145
+ enriched_schema = field(fields={signal.key(is_computed_signal=True): signal_schema})
146
+
147
+ for path_part in reversed(source_path):
148
+ if path_part == PATH_WILDCARD:
149
+ enriched_schema = Field(repeated_field=enriched_schema)
150
+ else:
151
+ enriched_schema = Field(fields={path_part: enriched_schema})
152
+
153
+ if not enriched_schema.fields:
154
+ raise ValueError('This should not happen since enriched_schema always has fields (see above)')
155
+
156
+ return schema({UUID_COLUMN: 'string', **cast(dict, enriched_schema.fields)})
157
+
158
+
159
+ def write_embeddings_to_disk(vector_store: str, uuids: Iterable[str], signal_items: Iterable[Item],
160
+ output_dir: str) -> None:
161
+ """Write a set of embeddings to disk."""
162
+
163
+ def embedding_predicate(input: Any) -> bool:
164
+ return (isinstance(input, list) and len(input) > 0 and isinstance(input[0], dict) and
165
+ EMBEDDING_KEY in input[0])
166
+
167
+ path_keys = flatten_keys(uuids, signal_items, is_primitive_predicate=embedding_predicate)
168
+ all_embeddings = cast(Iterable[Item],
169
+ deep_flatten(signal_items, is_primitive_predicate=embedding_predicate))
170
+
171
+ embedding_vectors: list[np.ndarray] = []
172
+ all_spans: list[tuple[PathKey, list[tuple[int, int]]]] = []
173
+ for path_key, embeddings in zip(path_keys, all_embeddings):
174
+ if not path_key or not embeddings:
175
+ # Sparse embeddings may not have an embedding for every key.
176
+ continue
177
+
178
+ spans: list[tuple[int, int]] = []
179
+ for e in embeddings:
180
+ span = e[VALUE_KEY]
181
+ vector = e[EMBEDDING_KEY]
182
+ # We flatten with reshape(-1) because embedding functions can return outer dimensions of 1.
183
+ embedding_vectors.append(vector.reshape(-1))
184
+ spans.append((span[TEXT_SPAN_START_FEATURE], span[TEXT_SPAN_END_FEATURE]))
185
+ all_spans.append((path_key, spans))
186
+ embedding_matrix = np.array(embedding_vectors)
187
+ del path_keys, all_embeddings, embedding_vectors
188
+ gc.collect()
189
+
190
+ # Write to disk.
191
+ vector_index = VectorDBIndex(vector_store)
192
+ vector_index.add(all_spans, embedding_matrix)
193
+ vector_index.save(output_dir)
194
+
195
+ del vector_index
196
+ gc.collect()
197
+
198
+
199
+ def write_items_to_parquet(items: Iterable[Item], output_dir: str, schema: Schema,
200
+ filename_prefix: str, shard_index: int,
201
+ num_shards: int) -> tuple[str, int]:
202
+ """Write a set of items to a parquet file, in columnar format."""
203
+ arrow_schema = schema_to_arrow_schema(schema)
204
+ out_filename = parquet_filename(filename_prefix, shard_index, num_shards)
205
+ filepath = os.path.join(output_dir, out_filename)
206
+ f = open_file(filepath, mode='wb')
207
+ writer = ParquetWriter(schema)
208
+ writer.open(f)
209
+ debug = env('DEBUG', False)
210
+ num_items = 0
211
+ for item in items:
212
+ # Add a UUID column.
213
+ if UUID_COLUMN not in item:
214
+ item[UUID_COLUMN] = secrets.token_urlsafe(nbytes=12) # 16 base64 characters.
215
+ if debug:
216
+ try:
217
+ _validate(item, arrow_schema)
218
+ except Exception as e:
219
+ raise ValueError(f'Error validating item: {json.dumps(item)}') from e
220
+ writer.write(item)
221
+ num_items += 1
222
+ writer.close()
223
+ f.close()
224
+ return out_filename, num_items
225
+
226
+
227
+ def _validate(item: Item, schema: pa.Schema) -> None:
228
+ # Try to parse the item using the inferred schema.
229
+ try:
230
+ pa.RecordBatch.from_pylist([item], schema=schema)
231
+ except pa.ArrowTypeError:
232
+ log('Failed to parse arrow item using the arrow schema.')
233
+ log('Item:')
234
+ log(pprint.pformat(item, indent=2))
235
+ log('Arrow schema:')
236
+ log(schema)
237
+ raise # Re-raise the same exception, same stacktrace.
238
+
239
+
240
+ def parquet_filename(prefix: str, shard_index: int, num_shards: int) -> str:
241
+ """Return the filename for a parquet file."""
242
+ return f'{prefix}-{shard_index:05d}-of-{num_shards:05d}.parquet'
243
+
244
+
245
+ def _flatten_keys(uuid: str, nested_input: Iterable, location: list[int],
246
+ is_primitive_predicate: Callable[[object], bool]) -> Iterator[VectorKey]:
247
+ if is_primitive_predicate(nested_input) or is_primitive(nested_input) or isinstance(
248
+ nested_input, dict):
249
+ yield (uuid, *location)
250
+ return
251
+
252
+ for i, input in enumerate(nested_input):
253
+ yield from _flatten_keys(uuid, input, [*location, i], is_primitive_predicate)
254
+
255
+
256
+ def flatten_keys(
257
+ uuids: Iterable[str],
258
+ nested_input: Iterable,
259
+ is_primitive_predicate: Callable[[object],
260
+ bool] = is_primitive) -> Iterator[Optional[VectorKey]]:
261
+ """Flatten the uuid keys of a nested input."""
262
+ for uuid, input in zip(uuids, nested_input):
263
+ if input is None:
264
+ yield None
265
+ continue
266
+ yield from _flatten_keys(uuid, input, [], is_primitive_predicate)
267
+
268
+
269
+ Tin = TypeVar('Tin')
270
+ Tout = TypeVar('Tout')
271
+
272
+
273
+ def sparse_to_dense_compute(
274
+ sparse_input: Iterator[Optional[Tin]],
275
+ func: Callable[[Iterable[Tin]], Iterable[Tout]]) -> Iterator[Optional[Tout]]:
276
+ """Densifies the input before calling the provided `func` and sparsifies the output."""
277
+ locations: list[int] = []
278
+ total_size: int = 0
279
+
280
+ def densify(x: Iterator[Optional[Tin]]) -> Iterator[Tin]:
281
+ nonlocal locations, total_size
282
+ for i, value in enumerate(x):
283
+ total_size += 1
284
+ if value is not None:
285
+ locations.append(i)
286
+ yield value
287
+
288
+ dense_input = densify(sparse_input)
289
+ dense_output = iter(func(dense_input))
290
+ index = 0
291
+
292
+ location_index = 0
293
+
294
+ while True:
295
+ try:
296
+ out = next(dense_output)
297
+ out_index = locations[location_index]
298
+ while index < out_index:
299
+ yield None
300
+ index += 1
301
+ yield out
302
+ location_index += 1
303
+ index += 1
304
+ except StopIteration:
305
+ while index < total_size:
306
+ yield None
307
+ index += 1
308
+ return
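A minimal sketch of how `sparse_to_dense_compute` behaves (illustrative only; `double_all` is a hypothetical batch function, not part of this change): `None` gaps in the input are preserved in the output, while `func` only ever sees the dense values.

from typing import Iterable, Iterator

from lilac.data.dataset_utils import sparse_to_dense_compute

def double_all(values: Iterable[int]) -> Iterator[int]:
  # Stand-in for an expensive batched computation that never sees None.
  for v in values:
    yield v * 2

sparse = iter([1, None, 2, None, 3])
# Only 1, 2 and 3 are passed to double_all; the None positions come back untouched.
assert list(sparse_to_dense_compute(sparse, double_all)) == [2, None, 4, None, 6]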
lilac/data/duckdb_utils.py ADDED
@@ -0,0 +1,25 @@
1
+ """Utils for duckdb."""
2
+ import os
3
+
4
+ import duckdb
5
+
6
+ from ..env import data_path, env
7
+
8
+
9
+ def duckdb_setup(con: duckdb.DuckDBPyConnection) -> str:
10
+ """Set up DuckDB. This includes setting up the extensions directory and GCS access."""
11
+ con.execute(f"""
12
+ SET extension_directory='{os.path.join(data_path(), '.duckdb')}';
13
+ """)
14
+
15
+ con.install_extension('httpfs')
16
+ con.load_extension('httpfs')
17
+
18
+ if env('GCS_REGION'):
19
+ return f"""
20
+ SET s3_region='{env('GCS_REGION')}';
21
+ SET s3_access_key_id='{env('GCS_ACCESS_KEY')}';
22
+ SET s3_secret_access_key='{env('GCS_SECRET_KEY')}';
23
+ SET s3_endpoint='storage.googleapis.com';
24
+ """
25
+ return ''
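A short usage sketch (illustrative): `duckdb_setup` installs the httpfs extension and returns the GCS SQL for the caller to execute, so a typical call site looks like this.

import duckdb

from lilac.data.duckdb_utils import duckdb_setup

con = duckdb.connect()
gcs_sql = duckdb_setup(con)  # Sets the extension directory and loads httpfs.
if gcs_sql:
  # Non-empty only when GCS_REGION / GCS_ACCESS_KEY / GCS_SECRET_KEY are configured.
  con.execute(gcs_sql)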
lilac/data_loader.py ADDED
@@ -0,0 +1,110 @@
1
+ """A data loader standalone binary. This should only be run as a script to load a dataset.
2
+
3
+ To run the source loader as a binary directly:
4
+
5
+ poetry run python -m lilac.data_loader \
6
+ --dataset_name=movies_dataset \
7
+ --output_dir=./data/ \
8
+ --config_path=./datasets/the_movies_dataset.json
9
+ """
10
+ import os
11
+ import pathlib
12
+ import uuid
13
+ from typing import Iterable, Optional, Union
14
+
15
+ import pandas as pd
16
+
17
+ from .data.dataset import Dataset
18
+ from .data.dataset_utils import write_items_to_parquet
19
+ from .db_manager import get_dataset
20
+ from .env import data_path
21
+ from .schema import (
22
+ MANIFEST_FILENAME,
23
+ PARQUET_FILENAME_PREFIX,
24
+ UUID_COLUMN,
25
+ Field,
26
+ Item,
27
+ Schema,
28
+ SourceManifest,
29
+ field,
30
+ is_float,
31
+ )
32
+ from .sources.source import Source
33
+ from .tasks import TaskStepId, progress
34
+ from .utils import get_dataset_output_dir, log, open_file
35
+
36
+
37
+ def create_dataset(
38
+ namespace: str,
39
+ dataset_name: str,
40
+ source_config: Source,
41
+ ) -> Dataset:
42
+ """Load a dataset from a given source configuration."""
43
+ process_source(data_path(), namespace, dataset_name, source_config)
44
+ return get_dataset(namespace, dataset_name)
45
+
46
+
47
+ def process_source(base_dir: Union[str, pathlib.Path],
48
+ namespace: str,
49
+ dataset_name: str,
50
+ source: Source,
51
+ task_step_id: Optional[TaskStepId] = None) -> tuple[str, int]:
52
+ """Process a source."""
53
+ output_dir = get_dataset_output_dir(base_dir, namespace, dataset_name)
54
+
55
+ source.setup()
56
+ source_schema = source.source_schema()
57
+ items = source.process()
58
+
59
+ # Add UUIDs and fix NaN in string columns.
60
+ items = normalize_items(items, source_schema.fields)
61
+
62
+ # Add progress.
63
+ items = progress(
64
+ items,
65
+ task_step_id=task_step_id,
66
+ estimated_len=source_schema.num_items,
67
+ step_description=f'Reading from source {source.name}...')
68
+
69
+ # Filter out the `None`s after progress.
70
+ items = (item for item in items if item is not None)
71
+
72
+ data_schema = Schema(fields={**source_schema.fields, UUID_COLUMN: field('string')})
73
+ filepath, num_items = write_items_to_parquet(
74
+ items=items,
75
+ output_dir=output_dir,
76
+ schema=data_schema,
77
+ filename_prefix=PARQUET_FILENAME_PREFIX,
78
+ shard_index=0,
79
+ num_shards=1)
80
+
81
+ filenames = [os.path.basename(filepath)]
82
+ manifest = SourceManifest(files=filenames, data_schema=data_schema, images=None)
83
+ with open_file(os.path.join(output_dir, MANIFEST_FILENAME), 'w') as f:
84
+ f.write(manifest.json(indent=2, exclude_none=True))
85
+ log(f'Dataset "{dataset_name}" written to {output_dir}')
86
+
87
+ return output_dir, num_items
88
+
89
+
90
+ def normalize_items(items: Iterable[Item], fields: dict[str, Field]) -> Iterable[Item]:
91
+ """Sanitize items by removing NaNs and NaTs."""
92
+ replace_nan_fields = [
93
+ field_name for field_name, field in fields.items() if field.dtype and not is_float(field.dtype)
94
+ ]
95
+ for item in items:
96
+ if item is None:
97
+ yield item
98
+ continue
99
+
100
+ # Add row uuid if it doesn't exist.
101
+ if UUID_COLUMN not in item:
102
+ item[UUID_COLUMN] = uuid.uuid4().hex
103
+
104
+ # Fix NaN values.
105
+ for field_name in replace_nan_fields:
106
+ item_value = item.get(field_name)
107
+ if item_value and pd.isna(item_value):
108
+ item[field_name] = None
109
+
110
+ yield item
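A small sketch of `normalize_items` on its own (the `field` and `UUID_COLUMN` helpers are the ones imported from `lilac.schema` above):

import numpy as np

from lilac.data_loader import normalize_items
from lilac.schema import UUID_COLUMN, field

items = [{'text': np.nan}, {'text': 'hello'}]
cleaned = list(normalize_items(items, fields={'text': field('string')}))

assert all(UUID_COLUMN in item for item in cleaned)  # Every row gets a uuid.
assert cleaned[0]['text'] is None  # NaN in a non-float column is replaced with None.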
lilac/db_manager.py ADDED
@@ -0,0 +1,42 @@
1
+ """Manages the mapping from dataset name to database instance."""
2
+ import os
3
+ import threading
4
+ from typing import Type
5
+
6
+ from .data.dataset import Dataset
7
+
8
+ _DEFAULT_DATASET_CLS: Type[Dataset]
9
+
10
+ _CACHED_DATASETS: dict[str, Dataset] = {}
11
+
12
+ _db_lock = threading.Lock()
13
+
14
+
15
+ def get_dataset(namespace: str, dataset_name: str) -> Dataset:
16
+ """Get the dataset instance."""
17
+ if not _DEFAULT_DATASET_CLS:
18
+ raise ValueError('Default dataset class not set.')
19
+ cache_key = f'{namespace}/{dataset_name}'
20
+ # https://docs.pytest.org/en/latest/example/simple.html#pytest-current-test-environment-variable
21
+ inside_test = 'PYTEST_CURRENT_TEST' in os.environ
22
+ with _db_lock:
23
+ if cache_key not in _CACHED_DATASETS or inside_test:
24
+ _CACHED_DATASETS[cache_key] = _DEFAULT_DATASET_CLS(
25
+ namespace=namespace, dataset_name=dataset_name)
26
+ return _CACHED_DATASETS[cache_key]
27
+
28
+
29
+ def remove_dataset_from_cache(namespace: str, dataset_name: str) -> None:
30
+ """Remove the dataset from the db manager cache."""
31
+ cache_key = f'{namespace}/{dataset_name}'
32
+ with _db_lock:
33
+ if cache_key in _CACHED_DATASETS:
34
+ del _CACHED_DATASETS[cache_key]
35
+
36
+
37
+ # TODO(nsthorat): Make this a registry once we have multiple dataset implementations. This breaks a
38
+ # circular dependency.
39
+ def set_default_dataset_cls(dataset_cls: Type[Dataset]) -> None:
40
+ """Set the default dataset class."""
41
+ global _DEFAULT_DATASET_CLS
42
+ _DEFAULT_DATASET_CLS = dataset_cls
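A sketch of the intended wiring (the `DatasetDuckDB` import path and class name are assumptions based on the file list in this change, and the example assumes a `local/movies_dataset` dataset already exists on disk):

from lilac.data.dataset_duckdb import DatasetDuckDB  # Class name assumed.
from lilac.db_manager import get_dataset, set_default_dataset_cls

set_default_dataset_cls(DatasetDuckDB)  # Must run once before get_dataset().

ds = get_dataset('local', 'movies_dataset')
# Outside of pytest, repeated lookups return the same cached instance.
assert ds is get_dataset('local', 'movies_dataset')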
lilac/embeddings/__init__.py ADDED
File without changes
lilac/embeddings/cohere.py ADDED
@@ -0,0 +1,59 @@
1
+ """Cohere embeddings."""
2
+ from typing import TYPE_CHECKING, Iterable, cast
3
+
4
+ import numpy as np
5
+ from typing_extensions import override
6
+
7
+ from ..env import env
8
+ from ..schema import Item, RichData
9
+ from ..signals.signal import TextEmbeddingSignal
10
+ from ..signals.splitters.chunk_splitter import split_text
11
+ from .embedding import compute_split_embeddings
12
+
13
+ if TYPE_CHECKING:
14
+ from cohere import Client
15
+
16
+ NUM_PARALLEL_REQUESTS = 10
17
+ COHERE_BATCH_SIZE = 96
18
+
19
+
20
+ class Cohere(TextEmbeddingSignal):
21
+ """Computes embeddings using Cohere's embedding API.
22
+
23
+ <br>**Important**: This will send data to an external server!
24
+
25
+ <br>To use this signal, you must get a Cohere API key from
26
+ [cohere.com/embed](https://cohere.com/embed) and add it to your .env.local.
27
+
28
+ <br>For details on pricing, see: https://cohere.com/pricing.
29
+ """
30
+
31
+ name = 'cohere'
32
+ display_name = 'Cohere Embeddings'
33
+
34
+ _model: 'Client'
35
+
36
+ @override
37
+ def setup(self) -> None:
38
+ """Validate that the api key and python package exists in environment."""
39
+ api_key = env('COHERE_API_KEY')
40
+ if not api_key:
41
+ raise ValueError('`COHERE_API_KEY` environment variable not set.')
42
+ try:
43
+ import cohere
44
+ self._model = cohere.Client(api_key, max_retries=10)
45
+ except ImportError:
46
+ raise ImportError('Could not import the "cohere" python package. '
47
+ 'Please install it with `pip install cohere`.')
48
+
49
+ @override
50
+ def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
51
+ """Compute embeddings for the given documents."""
52
+
53
+ def embed_fn(texts: list[str]) -> list[np.ndarray]:
54
+ return self._model.embed(texts, truncate='END').embeddings
55
+
56
+ docs = cast(Iterable[str], docs)
57
+ split_fn = split_text if self._split else None
58
+ yield from compute_split_embeddings(
59
+ docs, COHERE_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)
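An illustrative way to run the signal standalone (requires `COHERE_API_KEY` in `.env.local` and the `cohere` package; this is a hedged sketch, not the canonical entry point):

from lilac.embeddings.cohere import Cohere

signal = Cohere(split=True)  # Split long documents into chunks before embedding.
signal.setup()  # Validates the API key and constructs the client.
for item in signal.compute(['A short document to embed.']):
  print(item)  # A list of embedded spans per document (None for empty documents).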
lilac/embeddings/default_vector_stores.py ADDED
@@ -0,0 +1,10 @@
1
+ """Registers all vector stores."""
2
+ from .vector_store import register_vector_store
3
+ from .vector_store_hnsw import HNSWVectorStore
4
+ from .vector_store_numpy import NumpyVectorStore
5
+
6
+
7
+ def register_default_vector_stores() -> None:
8
+ """Register all the default vector stores."""
9
+ register_vector_store(HNSWVectorStore)
10
+ register_vector_store(NumpyVectorStore)
lilac/embeddings/embedding.py ADDED
@@ -0,0 +1,110 @@
1
+ """Embedding registry."""
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ from typing import Callable, Generator, Iterable, Iterator, Optional, Union, cast
4
+
5
+ import numpy as np
6
+ from pydantic import StrictStr
7
+ from sklearn.preprocessing import normalize
8
+
9
+ from ..schema import (
10
+ EMBEDDING_KEY,
11
+ TEXT_SPAN_END_FEATURE,
12
+ TEXT_SPAN_START_FEATURE,
13
+ VALUE_KEY,
14
+ Item,
15
+ RichData,
16
+ SpanVector,
17
+ lilac_embedding,
18
+ )
19
+ from ..signals.signal import TextEmbeddingSignal, get_signal_by_type
20
+ from ..signals.splitters.chunk_splitter import TextChunk
21
+ from ..utils import chunks
22
+
23
+ EmbeddingId = Union[StrictStr, TextEmbeddingSignal]
24
+
25
+ EmbedFn = Callable[[Iterable[RichData]], Iterator[list[SpanVector]]]
26
+
27
+
28
+ def get_embed_fn(embedding_name: str, split: bool) -> EmbedFn:
29
+ """Return a function that returns the embedding matrix for the given embedding signal."""
30
+ embedding_cls = get_signal_by_type(embedding_name, TextEmbeddingSignal)
31
+ embedding = embedding_cls(split=split)
32
+ embedding.setup()
33
+
34
+ def _embed_fn(data: Iterable[RichData]) -> Iterator[list[SpanVector]]:
35
+ items = embedding.compute(data)
36
+
37
+ for item in items:
38
+ if not item:
39
+ raise ValueError('Embedding signal returned None.')
40
+
41
+ yield [{
42
+ 'vector': item_val[EMBEDDING_KEY].reshape(-1),
43
+ 'span':
44
+ (item_val[VALUE_KEY][TEXT_SPAN_START_FEATURE], item_val[VALUE_KEY][TEXT_SPAN_END_FEATURE])
45
+ } for item_val in item]
46
+
47
+ return _embed_fn
48
+
49
+
50
+ def compute_split_embeddings(docs: Iterable[str],
51
+ batch_size: int,
52
+ embed_fn: Callable[[list[str]], list[np.ndarray]],
53
+ split_fn: Optional[Callable[[str], list[TextChunk]]] = None,
54
+ num_parallel_requests: int = 1) -> Generator[Item, None, None]:
55
+ """Compute text embeddings in batches of chunks, using the provided splitter and embedding fn."""
56
+ pool = ThreadPoolExecutor()
57
+
58
+ def _splitter(doc: str) -> list[TextChunk]:
59
+ if not doc:
60
+ return []
61
+ if split_fn:
62
+ return split_fn(doc)
63
+ else:
64
+ # Return a single chunk that spans the entire document.
65
+ return [(doc, (0, len(doc)))]
66
+
67
+ num_docs = 0
68
+
69
+ def _flat_split_batch_docs(docs: Iterable[str]) -> Generator[tuple[int, TextChunk], None, None]:
70
+ """Split a batch of documents into chunks and yield them."""
71
+ nonlocal num_docs
72
+ for i, doc in enumerate(docs):
73
+ num_docs += 1
74
+ chunks = _splitter(doc)
75
+ for chunk in chunks:
76
+ yield (i, chunk)
77
+
78
+ doc_chunks = _flat_split_batch_docs(docs)
79
+ items_to_yield: Optional[list[Item]] = None
80
+ current_index = 0
81
+
82
+ mega_batch_size = batch_size * num_parallel_requests
83
+
84
+ for batch in chunks(doc_chunks, mega_batch_size):
85
+ texts = [text for _, (text, _) in batch]
86
+ embeddings: list[np.ndarray] = []
87
+
88
+ for x in list(pool.map(lambda x: embed_fn(x), chunks(texts, batch_size))):
89
+ embeddings.extend(x)
90
+ matrix = cast(np.ndarray, normalize(np.array(embeddings, dtype=np.float32)))
91
+ # np.split returns a shallow copy of each embedding so we don't increase the mem footprint.
92
+ embeddings_batch = cast(list[np.ndarray], np.split(matrix, matrix.shape[0]))
93
+ for (index, (_, (start, end))), embedding in zip(batch, embeddings_batch):
94
+ embedding = embedding.reshape(-1)
95
+ if index == current_index:
96
+ if items_to_yield is None:
97
+ items_to_yield = []
98
+ items_to_yield.append(lilac_embedding(start, end, embedding))
99
+ else:
100
+ yield items_to_yield
101
+ current_index += 1
102
+ while current_index < index:
103
+ yield None
104
+ current_index += 1
105
+ items_to_yield = [lilac_embedding(start, end, embedding)]
106
+
107
+ while current_index < num_docs:
108
+ yield items_to_yield
109
+ items_to_yield = None
110
+ current_index += 1
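A toy walk-through of `compute_split_embeddings` (purely illustrative; `fake_embed` stands in for a real embedding API):

import numpy as np

from lilac.embeddings.embedding import compute_split_embeddings

def fake_embed(texts: list[str]) -> list[np.ndarray]:
  # Pretend every chunk embeds to the same 4-d vector.
  return [np.ones(4, dtype=np.float32) for _ in texts]

docs = ['hello world', '', 'another doc']
items = list(compute_split_embeddings(docs, batch_size=2, embed_fn=fake_embed, split_fn=None))

assert len(items) == len(docs)  # One entry per input document.
assert items[1] is None  # Empty documents yield None.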
lilac/embeddings/gte.py ADDED
@@ -0,0 +1,63 @@
1
+ """General Text Embeddings (GTE) model. Open-source model, designed to run on device."""
2
+ from typing import TYPE_CHECKING, Iterable, cast
3
+
4
+ from typing_extensions import override
5
+
6
+ from ..schema import Item, RichData
7
+ from ..signals.signal import TextEmbeddingSignal
8
+ from ..signals.splitters.chunk_splitter import split_text
9
+ from .embedding import compute_split_embeddings
10
+ from .transformer_utils import get_model
11
+
12
+ if TYPE_CHECKING:
13
+ pass
14
+
15
+ # See https://huggingface.co/spaces/mteb/leaderboard for leaderboard of models.
16
+ GTE_SMALL = 'thenlper/gte-small'
17
+ GTE_BASE = 'thenlper/gte-base'
18
+
19
+ # Maps a tuple of model name and device to the optimal batch size, found empirically.
20
+ _OPTIMAL_BATCH_SIZES: dict[str, dict[str, int]] = {
21
+ GTE_SMALL: {
22
+ '': 64, # Default batch size.
23
+ 'mps': 256,
24
+ },
25
+ GTE_BASE: {
26
+ '': 64, # Default batch size.
27
+ 'mps': 128,
28
+ }
29
+ }
30
+
31
+
32
+ class GTESmall(TextEmbeddingSignal):
33
+ """Computes General Text Embeddings (GTE).
34
+
35
+ <br>This embedding runs on-device. See the [model card](https://huggingface.co/thenlper/gte-small)
36
+ for details.
37
+ """
38
+
39
+ name = 'gte-small'
40
+ display_name = 'General Text Embeddings (small)'
41
+
42
+ _model_name = GTE_SMALL
43
+
44
+ @override
45
+ def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
46
+ """Call the embedding function."""
47
+ batch_size, model = get_model(self._model_name, _OPTIMAL_BATCH_SIZES[self._model_name])
48
+ embed_fn = model.encode
49
+ split_fn = split_text if self._split else None
50
+ docs = cast(Iterable[str], docs)
51
+ yield from compute_split_embeddings(docs, batch_size, embed_fn=embed_fn, split_fn=split_fn)
52
+
53
+
54
+ """Computes General Text Embeddings (GTE).
55
+ """Computes Gegeral Text Embeddings (GTE).
56
+
57
+ <br>This embedding runs on-device. See the [model card](https://huggingface.co/thenlper/gte-base)
58
+ for details.
59
+ """
60
+ name = 'gte-base'
61
+ display_name = 'General Text Embeddings (base)'
62
+
63
+ _model_name = GTE_BASE
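A hedged example of running the small GTE model directly (downloads the model on first use and requires `sentence-transformers` to be installed):

from lilac.embeddings.gte import GTESmall

gte = GTESmall(split=True)
gte.setup()
for item in gte.compute(['General Text Embeddings run fully on-device.']):
  print(item)  # A list of embedded spans for this document.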
lilac/embeddings/openai.py ADDED
@@ -0,0 +1,68 @@
1
+ """OpenAI embeddings."""
2
+ from typing import TYPE_CHECKING, Any, Iterable, cast
3
+
4
+ import numpy as np
5
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
6
+ from typing_extensions import override
7
+
8
+ from ..env import env
9
+ from ..schema import Item, RichData
10
+ from ..signals.signal import TextEmbeddingSignal
11
+ from ..signals.splitters.chunk_splitter import split_text
12
+ from .embedding import compute_split_embeddings
13
+
14
+ if TYPE_CHECKING:
15
+ import openai
16
+
17
+ NUM_PARALLEL_REQUESTS = 10
18
+ OPENAI_BATCH_SIZE = 128
19
+ EMBEDDING_MODEL = 'text-embedding-ada-002'
20
+
21
+
22
+ class OpenAI(TextEmbeddingSignal):
23
+ """Computes embeddings using OpenAI's embedding API.
24
+
25
+ <br>**Important**: This will send data to an external server!
26
+
27
+ <br>To use this signal, you must get an OpenAI API key from
28
+ [platform.openai.com](https://platform.openai.com/) and add it to your .env.local.
29
+
30
+ <br>For details on pricing, see: https://openai.com/pricing.
31
+ """
32
+
33
+ name = 'openai'
34
+ display_name = 'OpenAI Embeddings'
35
+
36
+ _model: type['openai.Embedding']
37
+
38
+ @override
39
+ def setup(self) -> None:
40
+ api_key = env('OPENAI_API_KEY')
41
+ if not api_key:
42
+ raise ValueError('`OPENAI_API_KEY` environment variable not set.')
43
+ try:
44
+ import openai
45
+ openai.api_key = api_key
46
+ self._model = openai.Embedding
47
+ except ImportError:
48
+ raise ImportError('Could not import the "openai" python package. '
49
+ 'Please install it with `pip install openai`.')
50
+
51
+ @override
52
+ def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
53
+ """Compute embeddings for the given documents."""
54
+
55
+ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(10))
56
+ def embed_fn(texts: list[str]) -> list[np.ndarray]:
57
+
58
+ # Replace newlines, which can negatively affect performance.
59
+ # See https://github.com/search?q=repo%3Aopenai%2Fopenai-python+replace+newlines&type=code
60
+ texts = [text.replace('\n', ' ') for text in texts]
61
+
62
+ response: Any = self._model.create(input=texts, model=EMBEDDING_MODEL)
63
+ return [np.array(embedding['embedding'], dtype=np.float32) for embedding in response['data']]
64
+
65
+ docs = cast(Iterable[str], docs)
66
+ split_fn = split_text if self._split else None
67
+ yield from compute_split_embeddings(
68
+ docs, OPENAI_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)
lilac/embeddings/palm.py ADDED
@@ -0,0 +1,62 @@
1
+ """PaLM embeddings."""
2
+ from typing import TYPE_CHECKING, Iterable, cast
3
+
4
+ import numpy as np
5
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
6
+ from typing_extensions import override
7
+
8
+ from ..env import env
9
+ from ..schema import Item, RichData
10
+ from ..signals.signal import TextEmbeddingSignal
11
+ from ..signals.splitters.chunk_splitter import split_text
12
+ from .embedding import compute_split_embeddings
13
+
14
+ if TYPE_CHECKING:
15
+ import google.generativeai as palm
16
+
17
+ PALM_BATCH_SIZE = 1 # PaLM API only supports batch size 1.
18
+ NUM_PARALLEL_REQUESTS = 256 # Because batch size is 1, we can send many requests in parallel.
19
+ EMBEDDING_MODEL = 'models/embedding-gecko-001'
20
+
21
+
22
+ class PaLM(TextEmbeddingSignal):
23
+ """Computes embeddings using PaLM's embedding API.
24
+
25
+ <br>**Important**: This will send data to an external server!
26
+
27
+ <br>To use this signal, you must get a PaLM API key from
28
+ [makersuite.google.com](https://makersuite.google.com/app/apikey) and add it to your .env.local.
29
+ """
30
+
31
+ name = 'palm'
32
+ display_name = 'PaLM Embeddings'
33
+
34
+ _model: 'palm.generate_embeddings'
35
+
36
+ @override
37
+ def setup(self) -> None:
38
+ api_key = env('PALM_API_KEY')
39
+ if not api_key:
40
+ raise ValueError('`PALM_API_KEY` environment variable not set.')
41
+ try:
42
+ import google.generativeai as palm
43
+ palm.configure(api_key=api_key)
44
+ self._model = palm.generate_embeddings
45
+ except ImportError:
46
+ raise ImportError('Could not import the "google.generativeai" python package. '
47
+ 'Please install it with `pip install google-generativeai`.')
48
+
49
+ @override
50
+ def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
51
+ """Compute embeddings for the given documents."""
52
+
53
+ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(10))
54
+ def embed_fn(texts: list[str]) -> list[np.ndarray]:
55
+ assert len(texts) == 1, 'PaLM API only supports batch size 1.'
56
+ response = self._model(model=EMBEDDING_MODEL, text=texts[0])
57
+ return [np.array(response['embedding'], dtype=np.float32)]
58
+
59
+ docs = cast(Iterable[str], docs)
60
+ split_fn = split_text if self._split else None
61
+ yield from compute_split_embeddings(
62
+ docs, PALM_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)
lilac/embeddings/sbert.py ADDED
@@ -0,0 +1,38 @@
1
+ """Sentence-BERT embeddings. Open-source models, designed to run on device."""
2
+ from typing import Iterable, cast
3
+
4
+ from typing_extensions import override
5
+
6
+ from ..schema import Item, RichData
7
+ from ..signals.signal import TextEmbeddingSignal
8
+ from ..signals.splitters.chunk_splitter import split_text
9
+ from .embedding import compute_split_embeddings
10
+ from .transformer_utils import get_model
11
+
12
+ # The `all-mpnet-base-v2` model provides the best quality, while `all-MiniLM-L6-v2` is 5 times
13
+ # faster and still offers good quality. See https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models/
14
+ MINI_LM_MODEL = 'all-MiniLM-L6-v2'
15
+
16
+ # Maps a tuple of model name and device to the optimal batch size, found empirically.
17
+ _OPTIMAL_BATCH_SIZES: dict[str, dict[str, int]] = {
18
+ MINI_LM_MODEL: {
19
+ '': 64, # Default batch size.
20
+ 'mps': 256,
21
+ }
22
+ }
23
+
24
+
25
+ class SBERT(TextEmbeddingSignal):
26
+ """Computes embeddings using the Sentence-BERT library."""
27
+
28
+ name = 'sbert'
29
+ display_name = 'SBERT Embeddings'
30
+
31
+ @override
32
+ def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
33
+ """Call the embedding function."""
34
+ batch_size, model = get_model(MINI_LM_MODEL, _OPTIMAL_BATCH_SIZES[MINI_LM_MODEL])
35
+ embed_fn = model.encode
36
+ split_fn = split_text if self._split else None
37
+ docs = cast(Iterable[str], docs)
38
+ yield from compute_split_embeddings(docs, batch_size, embed_fn=embed_fn, split_fn=split_fn)
lilac/embeddings/transformer_utils.py ADDED
@@ -0,0 +1,35 @@
1
+ """Utils for transformer embeddings."""
2
+
3
+ import functools
4
+ import os
5
+ from typing import TYPE_CHECKING, Optional
6
+
7
+ from ..env import data_path
8
+ from ..utils import log
9
+
10
+ if TYPE_CHECKING:
11
+ from sentence_transformers import SentenceTransformer
12
+
13
+
14
+ def get_model(model_name: str,
15
+ optimal_batch_sizes: dict[str, int] = {}) -> tuple[int, 'SentenceTransformer']:
16
+ """Get a transformer model and the optimal batch size for it."""
17
+ try:
18
+ import torch.backends.mps
19
+ from sentence_transformers import SentenceTransformer
20
+ except ImportError:
21
+ raise ImportError('Could not import the "sentence_transformers" python package. '
22
+ 'Please install it with `pip install sentence-transformers`.')
23
+ preferred_device: Optional[str] = None
24
+ if torch.backends.mps.is_available():
25
+ preferred_device = 'mps'
26
+ elif not torch.backends.mps.is_built():
27
+ log('MPS not available because the current PyTorch install was not built with MPS enabled.')
28
+
29
+ @functools.cache
30
+ def _get_model(model_name: str) -> 'SentenceTransformer':
31
+ return SentenceTransformer(
32
+ model_name, device=preferred_device, cache_folder=os.path.join(data_path(), '.cache'))
33
+
34
+ batch_size = optimal_batch_sizes[preferred_device or '']
35
+ return batch_size, _get_model(model_name)
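For reference, a minimal sketch of calling `get_model` directly, using the same MiniLM model and batch-size table as `sbert.py` above:

from lilac.embeddings.transformer_utils import get_model

# The per-device batch sizes are empirical; '' is the default (CPU/CUDA) entry.
batch_size, model = get_model('all-MiniLM-L6-v2', {'': 64, 'mps': 256})
vectors = model.encode(['hello world'], batch_size=batch_size)
print(vectors.shape)  # (1, 384) for this model.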
lilac/embeddings/vector_store.py ADDED
@@ -0,0 +1,200 @@
1
+ """Interface for storing vectors."""
2
+
3
+ import abc
4
+ import os
5
+ import pickle
6
+ from typing import Iterable, Optional, Type
7
+
8
+ import numpy as np
9
+
10
+ from ..schema import SpanVector, VectorKey
11
+ from ..utils import open_file
12
+
13
+
14
+ class VectorStore(abc.ABC):
15
+ """Interface for storing and retrieving vectors."""
16
+
17
+ # The global name of the vector store.
18
+ name: str
19
+
20
+ @abc.abstractmethod
21
+ def save(self, base_path: str) -> None:
22
+ """Save the store to disk."""
23
+ pass
24
+
25
+ @abc.abstractmethod
26
+ def load(self, base_path: str) -> None:
27
+ """Load the store from disk."""
28
+ pass
29
+
30
+ @abc.abstractmethod
31
+ def size(self) -> int:
32
+ """Return the number of vectors in the store."""
33
+ pass
34
+
35
+ @abc.abstractmethod
36
+ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
37
+ """Add or edit the given keyed embeddings to the store.
38
+
39
+ If the keys already exist they will be overwritten, acting as an "upsert".
40
+
41
+ Args:
42
+ keys: The keys to add the embeddings for.
43
+ embeddings: The embeddings to add. This should be a 2D matrix with the same length as keys.
44
+ """
45
+ pass
46
+
47
+ @abc.abstractmethod
48
+ def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
49
+ """Return the embeddings for given keys.
50
+
51
+ Args:
52
+ keys: The keys to return the embeddings for. If None, return all embeddings.
53
+
54
+ Returns:
55
+ The embeddings for the given keys.
56
+ """
57
+ pass
58
+
59
+ def topk(self,
60
+ query: np.ndarray,
61
+ k: int,
62
+ keys: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]:
63
+ """Return the top k most similar vectors.
64
+
65
+ Args:
66
+ query: The query vector.
67
+ k: The number of results to return.
68
+ keys: Optional keys to restrict the search to.
69
+
70
+ Returns:
71
+ A list of (key, score) tuples.
72
+ """
73
+ raise NotImplementedError
74
+
75
+
76
+ PathKey = VectorKey
77
+
78
+ _SPANS_PICKLE_NAME = 'spans.pkl'
79
+
80
+
81
+ class VectorDBIndex:
82
+ """Stores and retrieves span vectors.
83
+
84
+ This wraps a regular vector store by adding a mapping from path keys, such as (uuid1, 0),
85
+ to span keys, such as (uuid1, 0, 0), which denotes the first span in the (uuid1, 0) text document.
86
+ """
87
+
88
+ def __init__(self, vector_store: str) -> None:
89
+ self._vector_store: VectorStore = get_vector_store_cls(vector_store)()
90
+ # Map a path key to spans for that path.
91
+ self._id_to_spans: dict[PathKey, list[tuple[int, int]]] = {}
92
+
93
+ def load(self, base_path: str) -> None:
94
+ """Load the vector index from disk."""
95
+ assert not self._id_to_spans, 'Cannot load into a non-empty index.'
96
+ with open_file(os.path.join(base_path, _SPANS_PICKLE_NAME), 'rb') as f:
97
+ self._id_to_spans.update(pickle.load(f))
98
+ self._vector_store.load(os.path.join(base_path, self._vector_store.name))
99
+
100
+ def save(self, base_path: str) -> None:
101
+ """Save the vector index to disk."""
102
+ assert self._id_to_spans, 'Cannot save an empty index.'
103
+ with open_file(os.path.join(base_path, _SPANS_PICKLE_NAME), 'wb') as f:
104
+ pickle.dump(list(self._id_to_spans.items()), f)
105
+ self._vector_store.save(os.path.join(base_path, self._vector_store.name))
106
+
107
+ def add(self, spans: list[tuple[PathKey, list[tuple[int, int]]]], embeddings: np.ndarray) -> None:
108
+ """Add the given spans and embeddings.
109
+
110
+ Args:
111
+ spans: The spans to initialize the index with.
112
+ embeddings: The embeddings to initialize the index with.
113
+ """
114
+ assert not self._id_to_spans, 'Cannot add to a non-empty index.'
115
+ self._id_to_spans.update(spans)
116
+ vector_keys = [(*path_key, i) for path_key, spans in spans for i in range(len(spans))]
117
+ assert len(vector_keys) == len(embeddings), (
118
+ f'Number of spans ({len(vector_keys)}) and embeddings ({len(embeddings)}) must match.')
119
+ self._vector_store.add(vector_keys, embeddings)
120
+
121
+ def get_vector_store(self) -> VectorStore:
122
+ """Return the underlying vector store."""
123
+ return self._vector_store
124
+
125
+ def get(self, keys: Iterable[PathKey]) -> Iterable[list[SpanVector]]:
126
+ """Return the spans with vectors for each key in `keys`.
127
+
128
+ Args:
129
+ keys: The keys to return the vectors for.
130
+
131
+ Returns:
132
+ The span vectors for the given keys.
133
+ """
134
+ all_spans: list[list[tuple[int, int]]] = []
135
+ vector_keys: list[VectorKey] = []
136
+ for path_key in keys:
137
+ spans = self._id_to_spans[path_key]
138
+ all_spans.append(spans)
139
+ vector_keys.extend([(*path_key, i) for i in range(len(spans))])
140
+
141
+ all_vectors = self._vector_store.get(vector_keys)
142
+ offset = 0
143
+ for spans in all_spans:
144
+ vectors = all_vectors[offset:offset + len(spans)]
145
+ yield [{'span': span, 'vector': vector} for span, vector in zip(spans, vectors)]
146
+ offset += len(spans)
147
+
148
+ def topk(self,
149
+ query: np.ndarray,
150
+ k: int,
151
+ path_keys: Optional[Iterable[PathKey]] = None) -> list[tuple[PathKey, float]]:
152
+ """Return the top k most similar vectors.
153
+
154
+ Args:
155
+ query: The query vector.
156
+ k: The number of results to return.
157
+ path_keys: Optional key prefixes to restrict the search to.
158
+
159
+ Returns:
160
+ A list of (key, score) tuples.
161
+ """
162
+ span_keys: Optional[list[VectorKey]] = None
163
+ if path_keys is not None:
164
+ span_keys = [
165
+ (*path_key, i) for path_key in path_keys for i in range(len(self._id_to_spans[path_key]))
166
+ ]
167
+ span_k = k
168
+ path_key_scores: dict[PathKey, float] = {}
169
+ total_num_span_keys = self._vector_store.size()
170
+ while (len(path_key_scores) < k and span_k < total_num_span_keys and
171
+ (not span_keys or span_k < len(span_keys))):
172
+ span_k += k
173
+ vector_key_scores = self._vector_store.topk(query, span_k, span_keys)
174
+ for (*path_key_list, _), score in vector_key_scores:
175
+ path_key = tuple(path_key_list)
176
+ if path_key not in path_key_scores:
177
+ path_key_scores[path_key] = score
178
+
179
+ return list(path_key_scores.items())[:k]
180
+
181
+
182
+ VECTOR_STORE_REGISTRY: dict[str, Type[VectorStore]] = {}
183
+
184
+
185
+ def register_vector_store(vector_store_cls: Type[VectorStore]) -> None:
186
+ """Register a vector store in the global registry."""
187
+ if vector_store_cls.name in VECTOR_STORE_REGISTRY:
188
+ raise ValueError(f'Vector store "{vector_store_cls.name}" has already been registered!')
189
+
190
+ VECTOR_STORE_REGISTRY[vector_store_cls.name] = vector_store_cls
191
+
192
+
193
+ def get_vector_store_cls(vector_store_name: str) -> Type[VectorStore]:
194
+ """Return a registered vector store given the name in the registry."""
195
+ return VECTOR_STORE_REGISTRY[vector_store_name]
196
+
197
+
198
+ def clear_vector_store_registry() -> None:
199
+ """Clear the vector store registry."""
200
+ VECTOR_STORE_REGISTRY.clear()
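A small end-to-end sketch of `VectorDBIndex` with the numpy store (illustrative; the registration guard is only needed if the default stores were already registered at import time):

import numpy as np

from lilac.embeddings.default_vector_stores import register_default_vector_stores
from lilac.embeddings.vector_store import VectorDBIndex

try:
  register_default_vector_stores()
except ValueError:
  pass  # The default stores were already registered.

index = VectorDBIndex('numpy')
# One document (path key ('uuid1', 0)) with two text spans and one 2-d vector per span.
index.add([(('uuid1', 0), [(0, 5), (6, 11)])], np.eye(2, dtype=np.float32))

top = index.topk(np.array([1.0, 0.0], dtype=np.float32), k=1)
# -> [(('uuid1', 0), 1.0)]: span-level scores roll up to the path key.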
lilac/embeddings/vector_store_hnsw.py ADDED
@@ -0,0 +1,106 @@
1
+ """HNSW vector store."""
2
+
3
+ import multiprocessing
4
+ from typing import Iterable, Optional, Set, cast
5
+
6
+ import hnswlib
7
+ import numpy as np
8
+ import pandas as pd
9
+ from typing_extensions import override
10
+
11
+ from ..schema import VectorKey
12
+ from ..utils import DebugTimer
13
+ from .vector_store import VectorStore
14
+
15
+ _HNSW_SUFFIX = '.hnswlib.bin'
16
+ _LOOKUP_SUFFIX = '.lookup.pkl'
17
+
18
+
19
+ class HNSWVectorStore(VectorStore):
20
+ """HNSW-backed vector store."""
21
+
22
+ name = 'hnsw'
23
+
24
+ def __init__(self) -> None:
25
+ # Maps a `VectorKey` to a label in the hnswlib index.
26
+ self._key_to_label: Optional[pd.Series] = None
27
+ self._index: Optional[hnswlib.Index] = None
28
+
29
+ @override
30
+ def save(self, base_path: str) -> None:
31
+ assert self._key_to_label is not None and self._index is not None, (
32
+ 'The vector store has no embeddings. Call load() or add() first.')
33
+ self._index.save_index(base_path + _HNSW_SUFFIX)
34
+ self._key_to_label.to_pickle(base_path + _LOOKUP_SUFFIX)
35
+
36
+ @override
37
+ def load(self, base_path: str) -> None:
38
+ self._key_to_label = pd.read_pickle(base_path + _LOOKUP_SUFFIX)
39
+ dim = int(self._key_to_label.name)
40
+ index = hnswlib.Index(space='ip', dim=dim)
41
+ index.set_ef(10)
42
+ index.set_num_threads(multiprocessing.cpu_count())
43
+ index.load_index(base_path + _HNSW_SUFFIX)
44
+ self._index = index
45
+
46
+ @override
47
+ def size(self) -> int:
48
+ assert self._index is not None, (
49
+ 'The vector store has no embeddings. Call load() or add() first.')
50
+ return self._index.get_current_count()
51
+
52
+ @override
53
+ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
54
+ assert self._index is None, (
55
+ 'Embeddings already exist in this store. Upsert is not yet supported.')
56
+
57
+ if len(keys) != embeddings.shape[0]:
58
+ raise ValueError(
59
+ f'Length of keys ({len(keys)}) does not match number of embeddings {embeddings.shape[0]}.')
60
+
61
+ dim = embeddings.shape[1]
62
+ with DebugTimer('hnswlib index creation'):
63
+ index = hnswlib.Index(space='ip', dim=dim)
64
+ index.set_ef(10)
65
+ index.set_num_threads(multiprocessing.cpu_count())
66
+ index.init_index(max_elements=len(keys), ef_construction=50, M=16)
67
+
68
+ # Cast to float32 since dot product with float32 is 40-50x faster than float16 and 2.5x faster
69
+ # than float64.
70
+ embeddings = embeddings.astype(np.float32)
71
+ row_indices = np.arange(len(keys), dtype=np.int32)
72
+ self._key_to_label = pd.Series(row_indices, index=keys, dtype=np.int32)
73
+ self._key_to_label.name = str(dim)
74
+ index.add_items(embeddings, row_indices)
75
+ self._index = index
76
+
77
+ @override
78
+ def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
79
+ assert self._index is not None and self._key_to_label is not None, (
80
+ 'No embeddings exist in this store.')
81
+ if not keys:
82
+ return np.array(self._index.get_items(self._key_to_label.values), dtype=np.float32)
83
+ locs = self._key_to_label.loc[cast(list[str], keys)].values
84
+ return np.array(self._index.get_items(locs), dtype=np.float32)
85
+
86
+ @override
87
+ def topk(self,
88
+ query: np.ndarray,
89
+ k: int,
90
+ keys: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]:
91
+ assert self._index is not None and self._key_to_label is not None, (
92
+ 'No embeddings exist in this store.')
93
+ labels: Set[int] = set()
94
+ if keys is not None:
95
+ labels = set(self._key_to_label.loc[cast(list[str], keys)].tolist())
96
+ k = min(k, len(labels))
97
+
98
+ def filter_func(label: int) -> bool:
99
+ return label in labels
100
+
101
+ query = np.expand_dims(query.astype(np.float32), axis=0)
102
+ locs, dists = self._index.knn_query(query, k=k, filter=filter_func if labels else None)
103
+ locs = locs[0]
104
+ dists = dists[0]
105
+ topk_keys = self._key_to_label.index.values[locs]
106
+ return [(key, 1 - dist) for key, dist in zip(topk_keys, dists)]
lilac/embeddings/vector_store_numpy.py ADDED
@@ -0,0 +1,92 @@
1
+ """NumpyVectorStore class for storing vectors in numpy arrays."""
2
+
3
+ from typing import Iterable, Optional, cast
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from typing_extensions import override
8
+
9
+ from ..schema import VectorKey
10
+ from .vector_store import VectorStore
11
+
12
+ _EMBEDDINGS_SUFFIX = '.matrix.npy'
13
+ _LOOKUP_SUFFIX = '.lookup.pkl'
14
+
15
+
16
+ class NumpyVectorStore(VectorStore):
17
+ """Stores vectors as in-memory np arrays."""
18
+ name = 'numpy'
19
+
20
+ def __init__(self) -> None:
21
+ self._embeddings: Optional[np.ndarray] = None
22
+ # Maps a `VectorKey` to a row index in `_embeddings`.
23
+ self._key_to_index: Optional[pd.Series] = None
24
+
25
+ @override
26
+ def size(self) -> int:
27
+ assert self._embeddings is not None, (
28
+ 'The vector store has no embeddings. Call load() or add() first.')
29
+ return len(self._embeddings)
30
+
31
+ @override
32
+ def save(self, base_path: str) -> None:
33
+ assert self._embeddings is not None and self._key_to_index is not None, (
34
+ 'The vector store has no embeddings. Call load() or add() first.')
35
+ np.save(base_path + _EMBEDDINGS_SUFFIX, self._embeddings, allow_pickle=False)
36
+ self._key_to_index.to_pickle(base_path + _LOOKUP_SUFFIX)
37
+
38
+ @override
39
+ def load(self, base_path: str) -> None:
40
+ self._embeddings = np.load(base_path + _EMBEDDINGS_SUFFIX, allow_pickle=False)
41
+ self._key_to_index = pd.read_pickle(base_path + _LOOKUP_SUFFIX)
42
+
43
+ @override
44
+ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
45
+ if self._embeddings or self._key_to_index:
46
+ raise ValueError('Embeddings already exist in this store. Upsert is not yet supported.')
47
+
48
+ if len(keys) != embeddings.shape[0]:
49
+ raise ValueError(
50
+ f'Length of keys ({len(keys)}) does not match number of embeddings {embeddings.shape[0]}.')
51
+
52
+ # Cast to float32 since dot product with float32 is 40-50x faster than float16 and 2.5x faster
53
+ # than float64.
54
+ self._embeddings = embeddings.astype(np.float32)
55
+ row_indices = np.arange(len(embeddings), dtype=np.uint32)
56
+ self._key_to_index = pd.Series(row_indices, index=keys, dtype=np.uint32)
57
+
58
+ @override
59
+ def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
60
+ assert self._embeddings is not None and self._key_to_index is not None, (
61
+ 'The vector store has no embeddings. Call load() or add() first.')
62
+ if not keys:
63
+ return self._embeddings
64
+ locs = self._key_to_index.loc[cast(list[str], keys)]
65
+ return self._embeddings.take(locs, axis=0)
66
+
67
+ @override
68
+ def topk(self,
69
+ query: np.ndarray,
70
+ k: int,
71
+ keys: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]:
72
+ assert self._embeddings is not None and self._key_to_index is not None, (
73
+ 'The vector store has no embeddings. Call load() or add() first.')
74
+ if keys is not None:
75
+ row_indices = self._key_to_index.loc[cast(list[str], keys)]
76
+ embeddings = self._embeddings.take(row_indices, axis=0)
77
+ keys = list(keys)
78
+ else:
79
+ keys, embeddings = cast(list[VectorKey], self._key_to_index.index.tolist()), self._embeddings
80
+
81
+ query = query.astype(embeddings.dtype)
82
+ similarities: np.ndarray = np.dot(embeddings, query).reshape(-1)
83
+ k = min(k, len(similarities))
84
+
85
+ # We do a partition + sort only top K to save time: O(n + klogk) instead of O(nlogn).
86
+ indices = np.argpartition(similarities, -k)[-k:]
87
+ # Indices sorted by value from largest to smallest.
88
+ indices = indices[np.argsort(similarities[indices])][::-1]
89
+
90
+ topk_similarities = similarities[indices]
91
+ topk_keys = [keys[idx] for idx in indices]
92
+ return list(zip(topk_keys, topk_similarities))
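And the same idea with `NumpyVectorStore` used directly, to make the dot-product ranking concrete (illustrative):

import numpy as np

from lilac.embeddings.vector_store_numpy import NumpyVectorStore

store = NumpyVectorStore()
store.add([('doc1', 0, 0), ('doc2', 0, 0)], np.eye(2, dtype=np.float32))

(key, score), = store.topk(np.array([0.0, 1.0], dtype=np.float32), k=1)
assert key == ('doc2', 0, 0) and score == 1.0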
lilac/env.py ADDED
@@ -0,0 +1,63 @@
1
+ """Load environment variables from .env file."""
2
+ import os
3
+ from typing import Any, Literal, Optional, Union, cast
4
+
5
+ from dotenv import load_dotenv
6
+
7
+ EnvironmentKeys = Union[Literal['LILAC_DATA_PATH'],
8
+ # Authentication on the demo.
9
+ Literal['LILAC_AUTH_ENABLED'], Literal['GOOGLE_CLIENT_ID'],
10
+ Literal['GOOGLE_CLIENT_SECRET'], Literal['LILAC_OAUTH_SECRET_KEY'],
11
+ # DuckDB accessing GCS.
12
+ Literal['GCS_REGION'], Literal['GCS_ACCESS_KEY'], Literal['GCS_SECRET_KEY'],
13
+ # Embedding API keys.
14
+ Literal['OPENAI_API_KEY'], Literal['COHERE_API_KEY'],
15
+ Literal['PALM_API_KEY'],
16
+ # HuggingFace demos.
17
+ Literal['HF_USERNAME'], Literal['HF_STAGING_DEMO_REPO'],
18
+ Literal['SPACE_ID'], Literal['HF_ACCESS_TOKEN'],
19
+ # DuckDB
20
+ Literal['DUCKDB_USE_VIEWS'],
21
+ # Debugging
22
+ Literal['DEBUG'], Literal['DISABLE_LOGS']]
23
+
24
+
25
+ def _init_env() -> None:
26
+ in_test = os.environ.get('LILAC_TEST', None)
27
+ # Load the .env files into the environment in order of highest to lowest priority.
28
+
29
+ if not in_test: # Skip local environment variables when testing.
30
+ load_dotenv('.env.local')
31
+ load_dotenv('.env.demo')
32
+ load_dotenv('.env')
33
+
34
+ if os.environ.get('LILAC_AUTH_ENABLED', None):
35
+ if not os.environ.get('GOOGLE_CLIENT_ID', None) or not os.environ.get(
36
+ 'GOOGLE_CLIENT_SECRET', None):
37
+ raise ValueError(
38
+ 'Missing `GOOGLE_CLIENT_ID` or `GOOGLE_CLIENT_SECRET` when `LILAC_AUTH_ENABLED=true`')
39
+ SECRET_KEY = os.environ.get('LILAC_OAUTH_SECRET_KEY', None)
40
+ if not SECRET_KEY:
41
+ raise ValueError('Missing `LILAC_OAUTH_SECRET_KEY` when `LILAC_AUTH_ENABLED=true`')
42
50
+
51
+
52
+ def env(key: EnvironmentKeys, default: Optional[Any] = None) -> Any:
53
+ """Return the value of an environment variable."""
54
+ return os.environ.get(key, default)
55
+
56
+
57
+ def data_path() -> str:
58
+ """Return the base path for data."""
59
+ return cast(str, env('LILAC_DATA_PATH', './data'))
60
+
61
+
62
+ # Initialize the environment at import time.
63
+ _init_env()
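Downstream modules are expected to read configuration through these helpers rather than `os.environ` directly, e.g.:

from lilac.env import data_path, env

if env('DEBUG', False):
  print('Debug logging is enabled.')
print('Datasets and caches live under', data_path())  # Defaults to ./data.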
lilac/load.py ADDED
@@ -0,0 +1,214 @@
1
+ """A script to load a dataset or set of datasets from a config for a Lilac instance.
2
+
3
+ Usage:
4
+
5
+ poetry run python -m lilac.load \
6
+ --output_dir=demo_data \
7
+ --config_path=demo.yml
8
+ """
9
+
10
+ import gc
11
+ import json
12
+ import os
13
+ import pathlib
14
+ import shutil
15
+
16
+ import click
17
+ import dask
18
+ import psutil
19
+ import yaml
20
+ from distributed import Client
21
+
22
+ from .config import Config, EmbeddingConfig, SignalConfig
23
+ from .data_loader import process_source
24
+ from .db_manager import get_dataset
25
+ from .schema import UUID_COLUMN
26
+ from .tasks import TaskManager, TaskStepId
27
+ from .utils import DebugTimer, get_datasets_dir, list_datasets
28
+
29
+
30
+ @click.command()
31
+ @click.option(
32
+ '--output_dir', required=True, type=str, help='The output directory to write files to.')
33
+ @click.option(
34
+ '--config_path',
35
+ required=True,
36
+ type=str,
37
+ help='The path to a json or yml file describing the configuration. '
38
+ 'The file contents should be an instance of `lilac.Config`.')
39
+ @click.option(
40
+ '--overwrite',
41
+ help='When True, runs all data from scratch, overwriting existing data. When false, only '
42
+ 'loads new datasets, embeddings, and signals.',
43
+ type=bool,
44
+ is_flag=True,
45
+ default=False)
46
+ def load_command(output_dir: str, config_path: str, overwrite: bool) -> None:
47
+ """Run the source loader as a binary."""
48
+ load(output_dir, config_path, overwrite)
49
+
50
+
51
+ def load(output_dir: str, config_path: str, overwrite: bool) -> None:
52
+ """Run the source loader as a binary."""
53
+ old_data_path = os.environ.get('LILAC_DATA_PATH')
54
+ os.environ['LILAC_DATA_PATH'] = output_dir
55
+ # Turn off debug logging.
56
+ del os.environ['DEBUG']
57
+
58
+ config_ext = pathlib.Path(config_path).suffix
59
+ if config_ext in ['.yml', '.yaml']:
60
+ with open(config_path, 'r') as f:
61
+ config_dict = yaml.safe_load(f)
62
+ elif config_ext in ['.json']:
63
+ with open(config_path, 'r') as f:
64
+ config_dict = json.load(f)
65
+ else:
66
+ raise ValueError(f'Unsupported config file extension: {config_ext}')
67
+
68
+ config = Config(**config_dict)
69
+
70
+ # Explicitly create a dask client in sync mode.
71
+ dask.config.set({'distributed.worker.daemon': False})
72
+ total_memory_gb = psutil.virtual_memory().total / (1024**3)
73
+ task_manager = TaskManager(Client(memory_limit=f'{total_memory_gb} GB'))
74
+
75
+ if overwrite:
76
+ shutil.rmtree(get_datasets_dir(output_dir), ignore_errors=True)
77
+
78
+ existing_datasets = [f'{d.namespace}/{d.dataset_name}' for d in list_datasets(output_dir)]
79
+
80
+ print()
81
+ print('*** Load datasets ***')
82
+ if overwrite:
83
+ datasets_to_load = config.datasets
84
+ else:
85
+ datasets_to_load = [
86
+ d for d in config.datasets if f'{d.namespace}/{d.name}' not in existing_datasets
87
+ ]
88
+ skipped_datasets = [
89
+ d for d in config.datasets if f'{d.namespace}/{d.name}' in existing_datasets
90
+ ]
91
+ print('Skipping loaded datasets:', ', '.join([d.name for d in skipped_datasets]))
92
+
93
+ with DebugTimer(f'Loading datasets: {", ".join([d.name for d in datasets_to_load])}'):
94
+ for d in datasets_to_load:
95
+ shutil.rmtree(os.path.join(output_dir, d.name), ignore_errors=True)
96
+ task_id = task_manager.task_id(f'Load dataset {d.namespace}/{d.name}')
97
+ task_manager.execute(task_id, process_source, output_dir, d.namespace, d.name, d.source,
98
+ (task_id, 0))
99
+
100
+ task_manager.wait()
101
+
102
+ print()
103
+ total_num_rows = 0
104
+ for d in datasets_to_load:
105
+ num_rows = get_dataset(d.namespace, d.name).select_rows([UUID_COLUMN], limit=1).total_num_rows
106
+ print(f'{d.namespace}/{d.name} loaded with {num_rows:,} rows.')
107
+ gc.collect()
108
+ total_num_rows += num_rows
109
+
110
+ print(f'Done loading {len(datasets_to_load)} datasets with {total_num_rows:,} rows.')
111
+
112
+ print('*** Dataset settings ***')
113
+ for d in config.datasets:
114
+ if d.settings:
115
+ dataset = get_dataset(d.namespace, d.name)
116
+ dataset.update_settings(d.settings)
117
+
118
+ print()
119
+ print('*** Compute embeddings ***')
120
+ with DebugTimer('Loading embeddings'):
121
+ for d in config.datasets:
122
+ # If embeddings are explicitly set, use only those.
123
+ embeddings = d.embeddings or []
124
+ # If embeddings are not explicitly set, use the media paths and preferred embedding from
125
+ # settings.
126
+ if not embeddings:
127
+ if d.settings and d.settings.ui:
128
+ for path in d.settings.ui.media_paths or []:
129
+ if d.settings.preferred_embedding:
130
+ embeddings.append(
131
+ EmbeddingConfig(path=path, embedding=d.settings.preferred_embedding))
132
+ print('emb configs', embeddings)
133
+ for e in embeddings:
134
+ task_id = task_manager.task_id(f'Compute embedding {e.embedding} on {e.path}')
135
+ task_manager.execute(task_id, _compute_embedding, d.namespace, d.name, e, output_dir,
136
+ overwrite, (task_id, 0))
137
+ task_manager.wait()
138
139
+ print()
140
+ print('*** Compute signals ***')
141
+ with DebugTimer('Computing signals'):
142
+ for d in config.datasets:
143
+ # If signals are explicitly set, use only those.
144
+ signals = d.signals or []
145
+ # If signals are not explicitly set, use the media paths and config.signals.
146
+ if not signals:
147
+ if d.settings and d.settings.ui:
148
+ for path in d.settings.ui.media_paths or []:
149
+ for signal in config.signals or []:
150
+ signals.append(SignalConfig(path=path, signal=signal))
151
+
152
+ for s in signals:
153
+ task_id = task_manager.task_id(f'Compute signal {s.signal} on {s.path}')
154
+ task_manager.execute(task_id, _compute_signal, d.namespace, d.name, s, output_dir,
155
+ overwrite, (task_id, 0))
156
+ task_manager.wait()
157
+
158
+ print()
159
+ print('Done!')
160
+
161
+ if old_data_path:
162
+ os.environ['LILAC_DATA_PATH'] = old_data_path
163
+
164
+
165
+ def _compute_signal(namespace: str, name: str, signal_config: SignalConfig, output_dir: str,
166
+ overwrite: bool, task_step_id: TaskStepId) -> None:
167
+ os.environ['LILAC_DATA_PATH'] = output_dir
168
+ # Turn off debug logging.
169
+ if 'DEBUG' in os.environ:
170
+ del os.environ['DEBUG']
171
+
172
+ compute_signal = False
173
+ if overwrite:
174
+ compute_signal = True
175
+
176
+ dataset = get_dataset(namespace, name)
177
+
178
+ if not compute_signal:
179
+ field = dataset.manifest().data_schema.get_field(signal_config.path)
180
+ signal_field = (field.fields or {}).get(signal_config.signal.key())
181
+ if not signal_field or signal_field.signal != signal_config.signal.dict():
182
+ compute_signal = True
183
+ if compute_signal:
184
+ dataset.compute_signal(signal_config.signal, signal_config.path, task_step_id)
185
+
186
+ gc.collect()
187
+
188
+
189
+ def _compute_embedding(namespace: str, name: str, embedding_config: EmbeddingConfig,
190
+ output_dir: str, overwrite: bool, task_step_id: TaskStepId) -> None:
191
+ os.environ['LILAC_DATA_PATH'] = output_dir
192
+ # Turn off debug logging.
193
+ if 'DEBUG' in os.environ:
194
+ del os.environ['DEBUG']
195
+
196
+ compute_embedding = False
197
+ if overwrite:
198
+ compute_embedding = True
199
+
200
+ dataset = get_dataset(namespace, name)
201
+ if not compute_embedding:
202
+ field = dataset.manifest().data_schema.get_field(embedding_config.path)
203
+ embedding_field = (field.fields or {}).get(embedding_config.embedding)
204
+ if not embedding_field:
205
+ compute_embedding = True
206
+
207
+ if compute_embedding:
208
+ dataset.compute_embedding(embedding_config.embedding, embedding_config.path, task_step_id)
209
+
210
+ gc.collect()
211
+
212
+
213
+ if __name__ == '__main__':
214
+ load_command()
lilac/make_openapi.py ADDED
@@ -0,0 +1,29 @@
1
+ """Writes the openapi.json file to the specified output.
2
+
3
+ This is meant to run as a standalone script. It lives in lilac/ so we can import the FastAPI app.
4
+ """
5
+ import json
6
+
7
+ import click
8
+ from fastapi.openapi.utils import get_openapi
9
+
10
+ from .server import app
11
+
12
+
13
+ @click.command()
14
+ @click.option(
15
+ '--output', required=True, type=str, help='The output filepath for the openapi.json file.')
16
+ def main(output: str) -> None:
17
+ """Create the openapi.json file for the API to generate TypeScript stubs."""
18
+ with open(output, 'w') as f:
19
+ json.dump(
20
+ get_openapi(
21
+ title=app.title,
22
+ version=app.version,
23
+ openapi_version=app.openapi_version,
24
+ description=app.description,
25
+ routes=app.routes), f)
26
+
27
+
28
+ if __name__ == '__main__':
29
+ main()
lilac/parquet_writer.py ADDED
@@ -0,0 +1,70 @@
1
+ """A Parquet file writer that wraps the pyarrow writer."""
2
+ from typing import IO, Optional
3
+
4
+ import pyarrow as pa
5
+ import pyarrow.parquet as pq
6
+
7
+ from .schema import Item, Schema, schema_to_arrow_schema
8
+
9
+
10
+ class ParquetWriter:
11
+ """A writer to parquet."""
12
+
13
+ def __init__(self,
14
+ schema: Schema,
15
+ codec: str = 'snappy',
16
+ row_group_buffer_size: int = 128 * 1024 * 1024,
17
+ record_batch_size: int = 10_000):
18
+ self._schema = schema_to_arrow_schema(schema)
19
+ self._codec = codec
20
+ self._row_group_buffer_size = row_group_buffer_size
21
+ self._buffer: list[list[Optional[Item]]] = [[] for _ in range(len(self._schema.names))]
22
+ self._buffer_size = record_batch_size
23
+ self._record_batches: list[pa.RecordBatch] = []
24
+ self._record_batches_byte_size = 0
25
+ self.writer: Optional[pq.ParquetWriter] = None
26
+
27
+ def open(self, file_handle: IO) -> None:
28
+ """Open the destination file for writing."""
29
+ self.writer = pq.ParquetWriter(file_handle, self._schema, compression=self._codec)
30
+
31
+ def write(self, record: Item) -> None:
32
+ """Write the record to the destination file."""
33
+ if len(self._buffer[0]) >= self._buffer_size:
34
+ self._flush_buffer()
35
+
36
+ if self._record_batches_byte_size >= self._row_group_buffer_size:
37
+ self._write_batches()
38
+
39
+ # reorder the data in columnar format.
40
+ for i, n in enumerate(self._schema.names):
41
+ self._buffer[i].append(record.get(n))
42
+
43
+ def close(self) -> None:
44
+ """Flushes the write buffer and closes the destination file."""
45
+ if len(self._buffer[0]) > 0:
46
+ self._flush_buffer()
47
+ if self._record_batches_byte_size > 0:
48
+ self._write_batches()
49
+
50
+ self.writer.close()
51
+
52
+ def _write_batches(self) -> None:
53
+ table = pa.Table.from_batches(self._record_batches, schema=self._schema)
54
+ self._record_batches = []
55
+ self._record_batches_byte_size = 0
56
+ self.writer.write_table(table)
57
+
58
+ def _flush_buffer(self) -> None:
59
+ arrays: list[pa.Array] = [[] for _ in range(len(self._schema.names))]
60
+ for x, y in enumerate(self._buffer):
61
+ arrays[x] = pa.array(y, type=self._schema.types[x])
62
+ self._buffer[x] = []
63
+ rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
64
+ self._record_batches.append(rb)
65
+ size = 0
66
+ for x in arrays:
67
+ for b in x.buffers(): # type: ignore
68
+ if b is not None:
69
+ size = size + b.size
70
+ self._record_batches_byte_size = self._record_batches_byte_size + size
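A minimal usage sketch of the writer above (the field names and file path are hypothetical, not taken from this diff): construct it from a lilac Schema, open a binary file handle, write row dictionaries, and close to flush the remaining buffers.

from lilac.parquet_writer import ParquetWriter
from lilac.schema import schema

# Hypothetical two-column schema; ParquetWriter converts it to an arrow schema internally.
writer = ParquetWriter(schema({'id': 'string', 'text': 'string'}))
with open('items.parquet', 'wb') as f:
  writer.open(f)
  writer.write({'id': '1', 'text': 'hello'})
  writer.write({'id': '2', 'text': 'world'})
  writer.close()  # Flushes buffered record batches and writes the parquet footer.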
lilac/router_concept.py ADDED
@@ -0,0 +1,209 @@
1
+ """Router for the concept database."""
2
+
3
+ from typing import Annotated, Iterable, Optional, cast
4
+
5
+ from fastapi import APIRouter, HTTPException
6
+ from fastapi.params import Depends
7
+ from openai_function_call import OpenAISchema
8
+ from pydantic import BaseModel, Field
9
+
10
+ from .auth import UserInfo, get_session_user
11
+ from .concepts.concept import DRAFT_MAIN, Concept, ConceptMetrics, DraftId, draft_examples
12
+ from .concepts.db_concept import DISK_CONCEPT_DB, DISK_CONCEPT_MODEL_DB, ConceptInfo, ConceptUpdate
13
+ from .env import env
14
+ from .router_utils import RouteErrorHandler, server_compute_concept
15
+ from .schema import RichData, SignalInputType
16
+ from .signals.concept_scorer import ConceptScoreSignal
17
+
18
+ router = APIRouter(route_class=RouteErrorHandler)
19
+
20
+
21
+ @router.get('/', response_model_exclude_none=True)
22
+ def get_concepts(
23
+ user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> list[ConceptInfo]:
24
+ """List the concepts."""
25
+ return DISK_CONCEPT_DB.list(user)
26
+
27
+
28
+ @router.get('/{namespace}/{concept_name}', response_model_exclude_none=True)
29
+ def get_concept(namespace: str,
30
+ concept_name: str,
31
+ draft: Optional[DraftId] = DRAFT_MAIN,
32
+ user: Annotated[Optional[UserInfo], Depends(get_session_user)] = None) -> Concept:
33
+ """Get a concept from a database."""
34
+ concept = DISK_CONCEPT_DB.get(namespace, concept_name, user)
35
+ if not concept:
36
+ raise HTTPException(
37
+ status_code=404,
38
+ detail=f'Concept "{namespace}/{concept_name}" was not found or user does not have access.')
39
+
40
+ # Only return the examples from the draft.
41
+ concept.data = draft_examples(concept, draft or DRAFT_MAIN)
42
+
43
+ return concept
44
+
45
+
46
+ class CreateConceptOptions(BaseModel):
47
+ """Options for creating a concept."""
48
+ # Namespace of the concept.
49
+ namespace: str
50
+ # Name of the concept.
51
+ name: str
52
+ # Input type (modality) of the concept.
53
+ type: SignalInputType
54
+ description: Optional[str] = None
55
+
56
+
57
+ @router.post('/create', response_model_exclude_none=True)
58
+ def create_concept(options: CreateConceptOptions,
59
+ user: Annotated[Optional[UserInfo],
60
+ Depends(get_session_user)]) -> Concept:
61
+ """Create a concept in the database."""
62
+ return DISK_CONCEPT_DB.create(options.namespace, options.name, options.type, options.description,
63
+ user)
64
+
65
+
66
+ @router.post('/{namespace}/{concept_name}', response_model_exclude_none=True)
67
+ def edit_concept(namespace: str, concept_name: str, change: ConceptUpdate,
68
+ user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> Concept:
69
+ """Edit a concept in the database."""
70
+ return DISK_CONCEPT_DB.edit(namespace, concept_name, change, user)
71
+
72
+
73
+ @router.delete('/{namespace}/{concept_name}')
74
+ def delete_concept(namespace: str, concept_name: str,
75
+ user: Annotated[Optional[UserInfo],
76
+ Depends(get_session_user)]) -> None:
77
+ """Deletes the concept from the database."""
78
+ DISK_CONCEPT_DB.remove(namespace, concept_name, user)
79
+
80
+
81
+ class MergeConceptDraftOptions(BaseModel):
82
+ """Merge a draft into main."""
83
+ draft: DraftId
84
+
85
+
86
+ @router.post('/{namespace}/{concept_name}/merge_draft', response_model_exclude_none=True)
87
+ def merge_concept_draft(namespace: str, concept_name: str, options: MergeConceptDraftOptions,
88
+ user: Annotated[Optional[UserInfo],
89
+ Depends(get_session_user)]) -> Concept:
90
+ """Merge a draft in the concept into main."""
91
+ return DISK_CONCEPT_DB.merge_draft(namespace, concept_name, options.draft, user)
92
+
93
+
94
+ class ScoreExample(BaseModel):
95
+ """Example to score along a specific concept."""
96
+ text: Optional[str] = None
97
+ img: Optional[bytes] = None
98
+
99
+
100
+ class ScoreBody(BaseModel):
101
+ """Request body for the score endpoint."""
102
+ examples: list[ScoreExample]
103
+ draft: str = DRAFT_MAIN
104
+
105
+
106
+ class ConceptModelInfo(BaseModel):
107
+ """Information about a concept model."""
108
+ namespace: str
109
+ concept_name: str
110
+ embedding_name: str
111
+ version: int
112
+ metrics: Optional[ConceptMetrics] = None
113
+
114
+
115
+ @router.get('/{namespace}/{concept_name}/model')
116
+ def get_concept_models(
117
+ namespace: str,
118
+ concept_name: str,
119
+ user: Annotated[Optional[UserInfo],
120
+ Depends(get_session_user)] = None) -> list[ConceptModelInfo]:
121
+ """Get a concept model from a database."""
122
+ concept = DISK_CONCEPT_DB.get(namespace, concept_name, user)
123
+ if not concept:
124
+ raise HTTPException(
125
+ status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')
126
+ models = DISK_CONCEPT_MODEL_DB.get_models(namespace, concept_name, user)
127
+
128
+ for m in models:
129
+ DISK_CONCEPT_MODEL_DB.sync(m, user)
130
+
131
+ return [
132
+ ConceptModelInfo(
133
+ namespace=m.namespace,
134
+ concept_name=m.concept_name,
135
+ embedding_name=m.embedding_name,
136
+ version=m.version,
137
+ metrics=m.get_metrics(concept)) for m in models
138
+ ]
139
+
140
+
141
+ @router.get('/{namespace}/{concept_name}/model/{embedding_name}')
142
+ def get_concept_model(
143
+ namespace: str,
144
+ concept_name: str,
145
+ embedding_name: str,
146
+ user: Annotated[Optional[UserInfo], Depends(get_session_user)] = None) -> ConceptModelInfo:
147
+ """Get a concept model from a database."""
148
+ concept = DISK_CONCEPT_DB.get(namespace, concept_name, user)
149
+ if not concept:
150
+ raise HTTPException(
151
+ status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')
152
+
153
+ model = DISK_CONCEPT_MODEL_DB.get(namespace, concept_name, embedding_name, user=user)
154
+ if not model:
155
+ model = DISK_CONCEPT_MODEL_DB.create(namespace, concept_name, embedding_name, user=user)
156
+ DISK_CONCEPT_MODEL_DB.sync(model)
157
+ model_info = ConceptModelInfo(
158
+ namespace=model.namespace,
159
+ concept_name=model.concept_name,
160
+ embedding_name=model.embedding_name,
161
+ version=model.version,
162
+ metrics=model.get_metrics(concept))
163
+ return model_info
164
+
165
+
166
+ @router.post(
167
+ '/{namespace}/{concept_name}/model/{embedding_name}/score', response_model_exclude_none=True)
168
+ def score(namespace: str, concept_name: str, embedding_name: str, body: ScoreBody,
169
+ user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> list[list[dict]]:
170
+ """Score examples along the specified concept."""
171
+ concept_scorer = ConceptScoreSignal(
172
+ namespace=namespace, concept_name=concept_name, embedding=embedding_name)
173
+ return cast(
174
+ list[list[dict]],
175
+ server_compute_concept(concept_scorer, cast(Iterable[RichData],
176
+ [e.text for e in body.examples]), user))
177
+
178
+
179
+ class Examples(OpenAISchema):
180
+ """Generated text examples."""
181
+ examples: list[str] = Field(..., description='List of generated examples')
182
+
183
+
184
+ @router.get('/generate_examples')
185
+ def generate_examples(description: str) -> list[str]:
186
+ """Generate positive examples for a given concept using an LLM model."""
187
+ try:
188
+ import openai
189
+ except ImportError:
190
+ raise ImportError('Could not import the "openai" python package. '
191
+ 'Please install it with `pip install openai`.')
192
+
193
+ openai.api_key = env('OPENAI_API_KEY')
194
+ completion = openai.ChatCompletion.create(
195
+ model='gpt-3.5-turbo-0613',
196
+ functions=[Examples.openai_schema],
197
+ messages=[
198
+ {
199
+ 'role': 'system',
200
+ 'content': 'You must call the `Examples` function with the generated examples',
201
+ },
202
+ {
203
+ 'role': 'user',
204
+ 'content': f'Write 5 diverse, unnumbered, and concise examples of "{description}"',
205
+ },
206
+ ],
207
+ )
208
+ result = Examples.from_response(completion)
209
+ return result.examples
lilac/router_data_loader.py ADDED
@@ -0,0 +1,80 @@
1
+ """The source loader runner which loads data into parquet files for the app.
2
+
3
+ To run the source loader as a binary directly:
4
+
5
+ poetry run python -m lilac.datasets.loader \
6
+ --dataset_name=$DATASET \
7
+ --output_dir=./data/ \
8
+ --config_path=./datasets/the_movies_dataset.json
9
+ """
10
+ from typing import Any
11
+
12
+ from fastapi import APIRouter, HTTPException, Request
13
+ from pydantic import BaseModel
14
+
15
+ from .auth import get_user_access
16
+ from .data_loader import process_source
17
+ from .env import data_path
18
+ from .router_utils import RouteErrorHandler
19
+ from .sources.source_registry import get_source_cls, registered_sources
20
+ from .tasks import TaskId, task_manager
21
+
22
+ REQUEST_TIMEOUT_SEC = 30 * 60 # 30 mins.
23
+
24
+ router = APIRouter(route_class=RouteErrorHandler)
25
+
26
+
27
+ class ProcessSourceRequest(BaseModel):
28
+ """The interface to the /process_source endpoint."""
29
+ username: str
30
+ dataset_name: str
31
+
32
+
33
+ class SourcesList(BaseModel):
34
+ """The interface to the /process_source endpoint."""
35
+ sources: list[str]
36
+
37
+
38
+ @router.get('/')
39
+ def get_sources() -> SourcesList:
40
+ """Get the list of available sources."""
41
+ sources = registered_sources()
42
+ return SourcesList(sources=list(sources.keys()))
43
+
44
+
45
+ @router.get('/{source_name}')
46
+ def get_source_schema(source_name: str) -> dict[str, Any]:
47
+ """Get the fields for a source."""
48
+ source_cls = get_source_cls(source_name)
49
+ return source_cls.schema()
50
+
51
+
52
+ class LoadDatasetOptions(BaseModel):
53
+ """Options for loading a dataset."""
54
+ namespace: str
55
+ dataset_name: str
56
+ config: dict[str, Any]
57
+
58
+
59
+ class LoadDatasetResponse(BaseModel):
60
+ """Response of the load dataset endpoint."""
61
+ task_id: TaskId
62
+
63
+
64
+ @router.post('/{source_name}/load')
65
+ async def load(source_name: str, options: LoadDatasetOptions,
66
+ request: Request) -> LoadDatasetResponse:
67
+ """Load a dataset."""
68
+ if not get_user_access().create_dataset:
69
+ raise HTTPException(401, 'User does not have access to load a dataset.')
70
+
71
+ source_cls = get_source_cls(source_name)
72
+ source = source_cls(**options.config)
73
+
74
+ task_id = task_manager().task_id(
75
+ name=f'[{options.namespace}/{options.dataset_name}] Load dataset',
76
+ description=f'Loader: {source.name}. \n Config: {source}')
77
+ task_manager().execute(task_id, process_source, data_path(), options.namespace,
78
+ options.dataset_name, source, (task_id, 0))
79
+
80
+ return LoadDatasetResponse(task_id=task_id)
lilac/router_dataset.py ADDED
@@ -0,0 +1,303 @@
1
+ """Router for the dataset database."""
2
+ from typing import Annotated, Optional, Sequence, Union, cast
3
+ from urllib.parse import unquote
4
+
5
+ from fastapi import APIRouter, HTTPException, Response
6
+ from fastapi.params import Depends
7
+ from fastapi.responses import ORJSONResponse
8
+ from pydantic import BaseModel, validator
9
+
10
+ from .auth import UserInfo, get_session_user, get_user_access
11
+ from .data.dataset import BinaryOp
12
+ from .data.dataset import Column as DBColumn
13
+ from .data.dataset import DatasetManifest, DatasetSettings, FeatureListValue, FeatureValue
14
+ from .data.dataset import Filter as PyFilter
15
+ from .data.dataset import (
16
+ GroupsSortBy,
17
+ ListOp,
18
+ Search,
19
+ SelectGroupsResult,
20
+ SelectRowsSchemaResult,
21
+ SortOrder,
22
+ StatsResult,
23
+ UnaryOp,
24
+ )
25
+ from .db_manager import get_dataset, remove_dataset_from_cache
26
+ from .env import data_path
27
+ from .router_utils import RouteErrorHandler
28
+ from .schema import Bin, Path, normalize_path
29
+ from .signals.concept_labels import ConceptLabelsSignal
30
+ from .signals.concept_scorer import ConceptScoreSignal
31
+ from .signals.semantic_similarity import SemanticSimilaritySignal
32
+ from .signals.signal import Signal, TextEmbeddingSignal, TextSignal, resolve_signal
33
+ from .signals.substring_search import SubstringSignal
34
+ from .tasks import TaskId, task_manager
35
+ from .utils import DatasetInfo, list_datasets
36
+
37
+ router = APIRouter(route_class=RouteErrorHandler)
38
+
39
+
40
+ @router.get('/', response_model_exclude_none=True)
41
+ def get_datasets() -> list[DatasetInfo]:
42
+ """List the datasets."""
43
+ return list_datasets(data_path())
44
+
45
+
46
+ class WebManifest(BaseModel):
47
+ """Information about a dataset."""
48
+ dataset_manifest: DatasetManifest
49
+
50
+
51
+ @router.get('/{namespace}/{dataset_name}')
52
+ def get_manifest(namespace: str, dataset_name: str) -> WebManifest:
53
+ """Get the web manifest for the dataset."""
54
+ dataset = get_dataset(namespace, dataset_name)
55
+ res = WebManifest(dataset_manifest=dataset.manifest())
56
+ # Avoids the error that Signal abstract class is not serializable.
57
+ return cast(WebManifest, ORJSONResponse(res.dict(exclude_none=True)))
58
+
59
+
60
+ class ComputeSignalOptions(BaseModel):
61
+ """The request for the compute signal endpoint."""
62
+ signal: Signal
63
+
64
+ # The leaf path to compute the signal on.
65
+ leaf_path: Path
66
+
67
+ @validator('signal', pre=True)
68
+ def parse_signal(cls, signal: dict) -> Signal:
69
+ """Parse a signal to its specific subclass instance."""
70
+ return resolve_signal(signal)
71
+
72
+
73
+ @router.delete('/{namespace}/{dataset_name}')
74
+ def delete_dataset(namespace: str, dataset_name: str) -> None:
75
+ """Delete the dataset."""
76
+ if not get_user_access().dataset.delete_dataset:
77
+ raise HTTPException(401, 'User does not have access to delete this dataset.')
78
+
79
+ dataset = get_dataset(namespace, dataset_name)
80
+ dataset.delete()
81
+ remove_dataset_from_cache(namespace, dataset_name)
82
+
83
+
84
+ class ComputeSignalResponse(BaseModel):
85
+ """Response of the compute signal column endpoint."""
86
+ task_id: TaskId
87
+
88
+
89
+ @router.post('/{namespace}/{dataset_name}/compute_signal')
90
+ def compute_signal(namespace: str, dataset_name: str,
91
+ options: ComputeSignalOptions) -> ComputeSignalResponse:
92
+ """Compute a signal for a dataset."""
93
+ if not get_user_access().dataset.compute_signals:
94
+ raise HTTPException(401, 'User does not have access to compute signals over this dataset.')
95
+
96
+ def _task_compute_signal(namespace: str, dataset_name: str, options_dict: dict,
97
+ task_id: TaskId) -> None:
98
+ # NOTE: We manually call .dict() to avoid the dask serializer, which doesn't call the underlying
99
+ # pydantic serializer.
100
+ options = ComputeSignalOptions(**options_dict)
101
+ dataset = get_dataset(namespace, dataset_name)
102
+ dataset.compute_signal(options.signal, options.leaf_path, task_step_id=(task_id, 0))
103
+
104
+ path_str = '.'.join(map(str, options.leaf_path))
105
+ task_id = task_manager().task_id(
106
+ name=f'[{namespace}/{dataset_name}] Compute signal "{options.signal.name}" on "{path_str}"',
107
+ description=f'Config: {options.signal}')
108
+ task_manager().execute(task_id, _task_compute_signal, namespace, dataset_name, options.dict(),
109
+ task_id)
110
+
111
+ return ComputeSignalResponse(task_id=task_id)
112
+
113
+
114
+ class DeleteSignalOptions(BaseModel):
115
+ """The request for the delete signal endpoint."""
116
+ # The signal path holding the data from the signal.
117
+ signal_path: Path
118
+
119
+
120
+ class DeleteSignalResponse(BaseModel):
121
+ """Response of the compute signal column endpoint."""
122
+ completed: bool
123
+
124
+
125
+ @router.delete('/{namespace}/{dataset_name}/delete_signal')
126
+ def delete_signal(namespace: str, dataset_name: str,
127
+ options: DeleteSignalOptions) -> DeleteSignalResponse:
128
+ """Delete a signal from a dataset."""
129
+ if not get_user_access().dataset.delete_signals:
130
+ raise HTTPException(401, 'User does not have access to delete this signal.')
131
+
132
+ dataset = get_dataset(namespace, dataset_name)
133
+ dataset.delete_signal(options.signal_path)
134
+ return DeleteSignalResponse(completed=True)
135
+
136
+
137
+ class GetStatsOptions(BaseModel):
138
+ """The request for the get stats endpoint."""
139
+ leaf_path: Path
140
+
141
+
142
+ @router.post('/{namespace}/{dataset_name}/stats')
143
+ def get_stats(namespace: str, dataset_name: str, options: GetStatsOptions) -> StatsResult:
144
+ """Get the stats for the dataset."""
145
+ dataset = get_dataset(namespace, dataset_name)
146
+ return dataset.stats(options.leaf_path)
147
+
148
+
149
+ class BinaryFilter(BaseModel):
150
+ """A filter on a column."""
151
+ path: Path
152
+ op: BinaryOp
153
+ value: FeatureValue
154
+
155
+
156
+ class UnaryFilter(BaseModel):
157
+ """A filter on a column."""
158
+ path: Path
159
+ op: UnaryOp
160
+ value: None = None
161
+
162
+
163
+ class ListFilter(BaseModel):
164
+ """A filter on a column."""
165
+ path: Path
166
+ op: ListOp
167
+ value: FeatureListValue
168
+
169
+
170
+ Filter = Union[BinaryFilter, UnaryFilter, ListFilter]
171
+
172
+ AllSignalTypes = Union[ConceptScoreSignal, ConceptLabelsSignal, SubstringSignal,
173
+ SemanticSimilaritySignal, TextEmbeddingSignal, TextSignal, Signal]
174
+
175
+
176
+ # We override the `Column` class so we can add explicitly all signal types for better OpenAPI spec.
177
+ class Column(DBColumn):
178
+ """A column in the dataset."""
179
+ signal_udf: Optional[AllSignalTypes] = None
180
+
181
+
182
+ class SelectRowsOptions(BaseModel):
183
+ """The request for the select rows endpoint."""
184
+ columns: Optional[Sequence[Union[Path, Column]]] = None
185
+ searches: Optional[Sequence[Search]] = None
186
+ filters: Optional[Sequence[Filter]] = None
187
+ sort_by: Optional[Sequence[Path]] = None
188
+ sort_order: Optional[SortOrder] = SortOrder.DESC
189
+ limit: Optional[int] = None
190
+ offset: Optional[int] = None
191
+ combine_columns: Optional[bool] = None
192
+
193
+
194
+ class SelectRowsSchemaOptions(BaseModel):
195
+ """The request for the select rows schema endpoint."""
196
+ columns: Optional[Sequence[Union[Path, Column]]] = None
197
+ searches: Optional[Sequence[Search]] = None
198
+ sort_by: Optional[Sequence[Path]] = None
199
+ sort_order: Optional[SortOrder] = SortOrder.DESC
200
+ combine_columns: Optional[bool] = None
201
+
202
+
203
+ class SelectRowsResponse(BaseModel):
204
+ """The response for the select rows endpoint."""
205
+ rows: list[dict]
206
+ total_num_rows: int
207
+
208
+
209
+ @router.get('/{namespace}/{dataset_name}/select_rows_download', response_model=None)
210
+ def select_rows_download(
211
+ namespace: str, dataset_name: str, url_safe_options: str,
212
+ user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> list[dict]:
213
+ """Select rows from the dataset database and downloads them."""
214
+ options = SelectRowsOptions.parse_raw(unquote(url_safe_options))
215
+ return select_rows(namespace, dataset_name, options, user).rows
216
+
217
+
218
+ @router.post('/{namespace}/{dataset_name}/select_rows', response_model_exclude_none=True)
219
+ def select_rows(
220
+ namespace: str, dataset_name: str, options: SelectRowsOptions,
221
+ user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> SelectRowsResponse:
222
+ """Select rows from the dataset database."""
223
+ dataset = get_dataset(namespace, dataset_name)
224
+
225
+ sanitized_filters = [
226
+ PyFilter(path=normalize_path(f.path), op=f.op, value=f.value) for f in (options.filters or [])
227
+ ]
228
+
229
+ res = dataset.select_rows(
230
+ columns=options.columns,
231
+ searches=options.searches or [],
232
+ filters=sanitized_filters,
233
+ sort_by=options.sort_by,
234
+ sort_order=options.sort_order,
235
+ limit=options.limit,
236
+ offset=options.offset,
237
+ combine_columns=options.combine_columns or False,
238
+ user=user)
239
+
240
+ return SelectRowsResponse(rows=list(res), total_num_rows=res.total_num_rows)
241
+
242
+
243
+ @router.post('/{namespace}/{dataset_name}/select_rows_schema', response_model_exclude_none=True)
244
+ def select_rows_schema(namespace: str, dataset_name: str,
245
+ options: SelectRowsSchemaOptions) -> SelectRowsSchemaResult:
246
+ """Get the schema of a select rows query."""
247
+ dataset = get_dataset(namespace, dataset_name)
248
+ return dataset.select_rows_schema(
249
+ columns=options.columns,
250
+ searches=options.searches or [],
251
+ sort_by=options.sort_by,
252
+ sort_order=options.sort_order,
253
+ combine_columns=options.combine_columns or False)
254
+
255
+
256
+ class SelectGroupsOptions(BaseModel):
257
+ """The request for the select groups endpoint."""
258
+ leaf_path: Path
259
+ filters: Optional[Sequence[Filter]] = None
260
+ sort_by: Optional[GroupsSortBy] = GroupsSortBy.COUNT
261
+ sort_order: Optional[SortOrder] = SortOrder.DESC
262
+ limit: Optional[int] = 100
263
+ bins: Optional[list[Bin]] = None
264
+
265
+
266
+ @router.post('/{namespace}/{dataset_name}/select_groups')
267
+ def select_groups(namespace: str, dataset_name: str,
268
+ options: SelectGroupsOptions) -> SelectGroupsResult:
269
+ """Select groups from the dataset database."""
270
+ dataset = get_dataset(namespace, dataset_name)
271
+ sanitized_filters = [
272
+ PyFilter(path=normalize_path(f.path), op=f.op, value=f.value) for f in (options.filters or [])
273
+ ]
274
+ return dataset.select_groups(options.leaf_path, sanitized_filters, options.sort_by,
275
+ options.sort_order, options.limit, options.bins)
276
+
277
+
278
+ @router.get('/{namespace}/{dataset_name}/media')
279
+ def get_media(namespace: str, dataset_name: str, item_id: str, leaf_path: str) -> Response:
280
+ """Get the media for the dataset."""
281
+ dataset = get_dataset(namespace, dataset_name)
282
+ path = tuple(leaf_path.split('.'))
283
+ result = dataset.media(item_id, path)
284
+ # Return the response via HTTP.
285
+ return Response(content=result.data)
286
+
287
+
288
+ @router.get('/{namespace}/{dataset_name}/settings')
289
+ def get_settings(namespace: str, dataset_name: str) -> DatasetSettings:
290
+ """Get the settings for the dataset."""
291
+ dataset = get_dataset(namespace, dataset_name)
292
+ return dataset.settings()
293
+
294
+
295
+ @router.post('/{namespace}/{dataset_name}/settings', response_model_exclude_none=True)
296
+ def update_settings(namespace: str, dataset_name: str, settings: DatasetSettings) -> None:
297
+ """Update the settings for the dataset."""
298
+ if not get_user_access().dataset.compute_signals:
299
+ raise HTTPException(401, 'User does not have access to update the settings of this dataset.')
300
+
301
+ dataset = get_dataset(namespace, dataset_name)
302
+ dataset.update_settings(settings)
303
+ return None
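As a sketch of how a client builds a request body for the select_rows endpoint above (the column name and limit are hypothetical examples, and the router's mount prefix is defined elsewhere in the server), the Pydantic model can be serialized directly:

from lilac.router_dataset import SelectRowsOptions

options = SelectRowsOptions(columns=['text'], limit=10)
# POST this JSON to /{namespace}/{dataset_name}/select_rows.
payload = options.json(exclude_none=True)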
lilac/router_google_login.py ADDED
@@ -0,0 +1,60 @@
1
+ """Router for Google OAuth2 login."""
2
+
3
+ from urllib.parse import urlparse, urlunparse
4
+
5
+ from authlib.integrations.starlette_client import OAuth, OAuthError
6
+ from fastapi import APIRouter, Request, Response
7
+ from fastapi.responses import HTMLResponse
8
+ from starlette.config import Config
9
+ from starlette.responses import RedirectResponse
10
+
11
+ from .auth import UserInfo
12
+ from .env import env
13
+ from .router_utils import RouteErrorHandler
14
+
15
+ router = APIRouter(route_class=RouteErrorHandler)
16
+
17
+ if env('LILAC_AUTH_ENABLED'):
18
+ oauth = OAuth(
19
+ Config(
20
+ environ={
21
+ 'GOOGLE_CLIENT_ID': env('GOOGLE_CLIENT_ID'),
22
+ 'GOOGLE_CLIENT_SECRET': env('GOOGLE_CLIENT_SECRET')
23
+ }))
24
+ oauth.register(
25
+ name='google',
26
+ server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
27
+ client_kwargs={'scope': 'openid email profile'},
28
+ )
29
+
30
+
31
+ @router.get('/login')
32
+ async def login(request: Request, origin_url: str) -> RedirectResponse:
33
+ """Redirects to Google OAuth login page."""
34
+ auth_path = urlunparse(urlparse(origin_url)._replace(path='/google/auth'))
35
+ return await oauth.google.authorize_redirect(request, auth_path)
36
+
37
+
38
+ @router.get('/auth')
39
+ async def auth(request: Request) -> Response:
40
+ """Handles the Google OAuth callback."""
41
+ try:
42
+ token = await oauth.google.authorize_access_token(request)
43
+ except OAuthError as error:
44
+ return HTMLResponse(f'<h1>{error}</h1>')
45
+ userinfo = token['userinfo']
46
+ request.session['user'] = UserInfo(
47
+ id=str(userinfo['sub']),
48
+ email=userinfo['email'],
49
+ name=userinfo['name'],
50
+ given_name=userinfo['given_name'],
51
+ family_name=userinfo['family_name']).dict()
52
+
53
+ return RedirectResponse(url='/')
54
+
55
+
56
+ @router.get('/logout')
57
+ def logout(request: Request) -> RedirectResponse:
58
+ """Logs the user out."""
59
+ request.session.pop('user', None)
60
+ return RedirectResponse(url='/')
lilac/router_signal.py ADDED
@@ -0,0 +1,105 @@
1
+ """Router for the signal registry."""
2
+
3
+ import math
4
+ from typing import Annotated, Any, Optional
5
+
6
+ from fastapi import APIRouter, Depends
7
+ from pydantic import BaseModel, validator
8
+
9
+ from .auth import UserInfo, get_session_user
10
+ from .router_utils import RouteErrorHandler, server_compute_concept
11
+ from .schema import Field, SignalInputType
12
+ from .signals.concept_scorer import ConceptScoreSignal
13
+ from .signals.signal import SIGNAL_REGISTRY, Signal, TextEmbeddingSignal, resolve_signal
14
+
15
+ router = APIRouter(route_class=RouteErrorHandler)
16
+
17
+ EMBEDDING_SORT_PRIORITIES = ['gte-small', 'gte-base', 'openai', 'sbert']
18
+
19
+
20
+ class SignalInfo(BaseModel):
21
+ """Information about a signal."""
22
+ name: str
23
+ input_type: SignalInputType
24
+ json_schema: dict[str, Any]
25
+
26
+
27
+ @router.get('/', response_model_exclude_none=True)
28
+ def get_signals() -> list[SignalInfo]:
29
+ """List the signals."""
30
+ return [
31
+ SignalInfo(name=s.name, input_type=s.input_type, json_schema=s.schema())
32
+ for s in SIGNAL_REGISTRY.values()
33
+ if not issubclass(s, TextEmbeddingSignal)
34
+ ]
35
+
36
+
37
+ @router.get('/embeddings', response_model_exclude_none=True)
38
+ def get_embeddings() -> list[SignalInfo]:
39
+ """List the embeddings."""
40
+ embedding_infos = [
41
+ SignalInfo(name=s.name, input_type=s.input_type, json_schema=s.schema())
42
+ for s in SIGNAL_REGISTRY.values()
43
+ if issubclass(s, TextEmbeddingSignal)
44
+ ]
45
+
46
+ # Sort the embedding infos by priority.
47
+ embedding_infos = sorted(
48
+ embedding_infos,
49
+ key=lambda s: EMBEDDING_SORT_PRIORITIES.index(s.name)
50
+ if s.name in EMBEDDING_SORT_PRIORITIES else math.inf)
51
+
52
+ return embedding_infos
53
+
54
+
55
+ class SignalComputeOptions(BaseModel):
56
+ """The request for the standalone compute signal endpoint."""
57
+ signal: Signal
58
+ # The inputs to compute.
59
+ inputs: list[str]
60
+
61
+ @validator('signal', pre=True)
62
+ def parse_signal(cls, signal: dict) -> Signal:
63
+ """Parse a signal to its specific subclass instance."""
64
+ return resolve_signal(signal)
65
+
66
+
67
+ class SignalComputeResponse(BaseModel):
68
+ """The response for the standalone compute signal endpoint."""
69
+ items: list[Optional[Any]]
70
+
71
+
72
+ @router.post('/compute', response_model_exclude_none=True)
73
+ def compute(
74
+ options: SignalComputeOptions,
75
+ user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> SignalComputeResponse:
76
+ """Compute a signal over a set of inputs."""
77
+ signal = options.signal
78
+ if isinstance(signal, ConceptScoreSignal):
79
+ result = server_compute_concept(signal, options.inputs, user)
80
+ else:
81
+ signal.setup()
82
+ result = list(signal.compute(options.inputs))
83
+ return SignalComputeResponse(items=result)
84
+
85
+
86
+ class SignalSchemaOptions(BaseModel):
87
+ """The request for the signal schema endpoint."""
88
+ signal: Signal
89
+
90
+ @validator('signal', pre=True)
91
+ def parse_signal(cls, signal: dict) -> Signal:
92
+ """Parse a signal to its specific subclass instance."""
93
+ return resolve_signal(signal)
94
+
95
+
96
+ class SignalSchemaResponse(BaseModel):
97
+ """The response for the signal schema endpoint."""
98
+ fields: Field
99
+
100
+
101
+ @router.post('/schema', response_model_exclude_none=True)
102
+ def schema(options: SignalSchemaOptions) -> SignalSchemaResponse:
103
+ """Get the schema for a signal."""
104
+ signal = options.signal
105
+ return SignalSchemaResponse(fields=signal.fields())
lilac/router_tasks.py ADDED
@@ -0,0 +1,14 @@
1
+ """Router for tasks."""
2
+
3
+ from fastapi import APIRouter
4
+
5
+ from .router_utils import RouteErrorHandler
6
+ from .tasks import TaskManifest, task_manager
7
+
8
+ router = APIRouter(route_class=RouteErrorHandler)
9
+
10
+
11
+ @router.get('/')
12
+ async def get_task_manifest() -> TaskManifest:
13
+ """Get the tasks, both completed and pending."""
14
+ return await task_manager().manifest()
lilac/router_utils.py ADDED
@@ -0,0 +1,54 @@
1
+ """Utils for routers."""
2
+
3
+ import traceback
4
+ from typing import Callable, Iterable, Optional
5
+
6
+ from fastapi import HTTPException, Request, Response
7
+ from fastapi.routing import APIRoute
8
+
9
+ from .auth import UserInfo
10
+ from .concepts.db_concept import DISK_CONCEPT_DB, DISK_CONCEPT_MODEL_DB
11
+ from .schema import Item, RichData
12
+ from .signals.concept_scorer import ConceptScoreSignal
13
+
14
+
15
+ class RouteErrorHandler(APIRoute):
16
+ """Custom APIRoute that handles application errors and exceptions."""
17
+
18
+ def get_route_handler(self) -> Callable:
19
+ """Get the route handler."""
20
+ original_route_handler = super().get_route_handler()
21
+
22
+ async def custom_route_handler(request: Request) -> Response:
23
+ try:
24
+ return await original_route_handler(request)
25
+ except Exception as ex:
26
+ if isinstance(ex, HTTPException):
27
+ raise ex
28
+
29
+ print('Route error:', request.url)
30
+ print(ex)
31
+ print(traceback.format_exc())
32
+
33
+ # wrap error into pretty 500 exception
34
+ raise HTTPException(status_code=500, detail=traceback.format_exc()) from ex
35
+
36
+ return custom_route_handler
37
+
38
+
39
+ def server_compute_concept(signal: ConceptScoreSignal, examples: Iterable[RichData],
40
+ user: Optional[UserInfo]) -> list[Optional[Item]]:
41
+ """Compute a concept from the REST endpoints."""
42
+ # TODO(nsthorat): Move this to the setup() method in the concept_scorer.
43
+ concept = DISK_CONCEPT_DB.get(signal.namespace, signal.concept_name, user)
44
+ if not concept:
45
+ raise HTTPException(
46
+ status_code=404, detail=f'Concept "{signal.namespace}/{signal.concept_name}" was not found')
47
+ model = DISK_CONCEPT_MODEL_DB.get(
48
+ signal.namespace, signal.concept_name, signal.embedding, user=user)
49
+ if model is None:
50
+ model = DISK_CONCEPT_MODEL_DB.create(
51
+ signal.namespace, signal.concept_name, signal.embedding, user=user)
52
+ DISK_CONCEPT_MODEL_DB.sync(model, user)
53
+ texts = [example or '' for example in examples]
54
+ return list(signal.compute(texts))
lilac/schema.py ADDED
@@ -0,0 +1,600 @@
1
+ """Item: an individual entry in the dataset."""
2
+
3
+ import csv
4
+ import io
5
+ from collections import deque
6
+ from datetime import datetime
7
+ from enum import Enum
8
+ from typing import Any, Optional, Union, cast
9
+
10
+ import numpy as np
11
+ import pyarrow as pa
12
+ from pydantic import BaseModel, StrictInt, StrictStr, validator
13
+ from typing_extensions import TypedDict
14
+
15
+ MANIFEST_FILENAME = 'manifest.json'
16
+ PARQUET_FILENAME_PREFIX = 'data'
17
+
18
+ # We choose `__rowid__` inspired by the standard `rowid` pseudocolumn in DBs:
19
+ # https://docs.oracle.com/cd/B19306_01/server.102/b14200/pseudocolumns008.htm
20
+ UUID_COLUMN = '__rowid__'
21
+ PATH_WILDCARD = '*'
22
+ VALUE_KEY = '__value__'
23
+ SIGNAL_METADATA_KEY = '__metadata__'
24
+ TEXT_SPAN_START_FEATURE = 'start'
25
+ TEXT_SPAN_END_FEATURE = 'end'
26
+
27
+ EMBEDDING_KEY = 'embedding'
28
+
29
+ # Python doesn't work with recursive types. These types provide some notion of type-safety.
30
+ Scalar = Union[bool, datetime, int, float, str, bytes]
31
+ Item = Any
32
+
33
+ # Contains a string field name, a wildcard for repeated fields, or a specific integer index into a repeated field.
34
+ # This path represents a path to a particular column.
35
+ # Examples:
36
+ # ['article', 'field'] represents {'article': {'field': VALUES}}
37
+ # ['article', '*', 'field'] represents {'article': [{'field': VALUES}, {'field': VALUES}]}
38
+ # ['article', '0', 'field'] represents {'article': {'field': VALUES}}
39
+ PathTuple = tuple[StrictStr, ...]
40
+ Path = Union[PathTuple, StrictStr]
41
+
42
+ PathKeyedItem = tuple[Path, Item]
43
+
44
+ # These fields are for Python only and are not written to a schema.
45
+ RichData = Union[str, bytes]
46
+ VectorKey = tuple[Union[StrictStr, StrictInt], ...]
47
+ PathKey = VectorKey
48
+
49
+
50
+ class DataType(str, Enum):
51
+ """Enum holding the dtype for a field."""
52
+ STRING = 'string'
53
+ # Contains {start, end} offset integers with a reference_column.
54
+ STRING_SPAN = 'string_span'
55
+ BOOLEAN = 'boolean'
56
+
57
+ # Ints.
58
+ INT8 = 'int8'
59
+ INT16 = 'int16'
60
+ INT32 = 'int32'
61
+ INT64 = 'int64'
62
+ UINT8 = 'uint8'
63
+ UINT16 = 'uint16'
64
+ UINT32 = 'uint32'
65
+ UINT64 = 'uint64'
66
+
67
+ # Floats.
68
+ FLOAT16 = 'float16'
69
+ FLOAT32 = 'float32'
70
+ FLOAT64 = 'float64'
71
+
72
+ ### Time ###
73
+ # Time of day (no time zone).
74
+ TIME = 'time'
75
+ # Calendar date (year, month, day), no time zone.
76
+ DATE = 'date'
77
+ # An "Instant" stored as number of microseconds (µs) since 1970-01-01 00:00:00+00 (UTC time zone).
78
+ TIMESTAMP = 'timestamp'
79
+ # Time span, stored as microseconds.
80
+ INTERVAL = 'interval'
81
+
82
+ BINARY = 'binary'
83
+
84
+ EMBEDDING = 'embedding'
85
+
86
+ NULL = 'null'
87
+
88
+ def __repr__(self) -> str:
89
+ return self.value
90
+
91
+
92
+ class SignalInputType(str, Enum):
93
+ """Enum holding the signal input type."""
94
+ TEXT = 'text'
95
+ TEXT_EMBEDDING = 'text_embedding'
96
+ IMAGE = 'image'
97
+
98
+ def __repr__(self) -> str:
99
+ return self.value
100
+
101
+
102
+ SIGNAL_TYPE_TO_VALID_DTYPES: dict[SignalInputType, list[DataType]] = {
103
+ SignalInputType.TEXT: [DataType.STRING, DataType.STRING_SPAN],
104
+ SignalInputType.IMAGE: [DataType.BINARY],
105
+ }
106
+
107
+
108
+ def signal_type_supports_dtype(input_type: SignalInputType, dtype: DataType) -> bool:
109
+ """Returns True if the signal compute type supports the dtype."""
110
+ return dtype in SIGNAL_TYPE_TO_VALID_DTYPES[input_type]
111
+
112
+
113
+ Bin = tuple[str, Optional[Union[float, int]], Optional[Union[float, int]]]
114
+
115
+
116
+ class Field(BaseModel):
117
+ """Holds information for a field in the schema."""
118
+ repeated_field: Optional['Field'] = None
119
+ fields: Optional[dict[str, 'Field']] = None
120
+ dtype: Optional[DataType] = None
121
+ # Defined as the serialized signal when this field is the root result of a signal.
122
+ signal: Optional[dict[str, Any]] = None
123
+ # Maps a named bin to a tuple of (start, end) values.
124
+ bins: Optional[list[Bin]] = None
125
+ categorical: Optional[bool] = None
126
+
127
+ @validator('fields')
128
+ def either_fields_or_repeated_field_is_defined(
129
+ cls, fields: Optional[dict[str, 'Field']], values: dict[str,
130
+ Any]) -> Optional[dict[str, 'Field']]:
131
+ """Error if both `fields` and `repeated_fields` are defined."""
132
+ if not fields:
133
+ return fields
134
+ if values.get('repeated_field'):
135
+ raise ValueError('Both "fields" and "repeated_field" should not be defined')
136
+ if VALUE_KEY in fields:
137
+ raise ValueError(f'{VALUE_KEY} is a reserved field name.')
138
+ return fields
139
+
140
+ @validator('dtype', always=True)
141
+ def infer_default_dtype(cls, dtype: Optional[DataType], values: dict[str,
142
+ Any]) -> Optional[DataType]:
143
+ """Infers the default value for dtype if not explicitly provided."""
144
+ if dtype and values.get('repeated_field'):
145
+ raise ValueError('dtype and repeated_field cannot both be defined.')
146
+ if not values.get('repeated_field') and not values.get('fields') and not dtype:
147
+ raise ValueError('One of "fields", "repeated_field", or "dtype" should be defined')
148
+ return dtype
149
+
150
+ @validator('bins')
151
+ def validate_bins(cls, bins: list[Bin]) -> list[Bin]:
152
+ """Validate the bins."""
153
+ if len(bins) < 2:
154
+ raise ValueError('Please specify at least two bins.')
155
+ _, first_start, _ = bins[0]
156
+ if first_start is not None:
157
+ raise ValueError('The first bin should have a `None` start value.')
158
+ _, _, last_end = bins[-1]
159
+ if last_end is not None:
160
+ raise ValueError('The last bin should have a `None` end value.')
161
+ for i, (_, start, _) in enumerate(bins):
162
+ if i == 0:
163
+ continue
164
+ prev_bin = bins[i - 1]
165
+ _, _, prev_end = prev_bin
166
+ if start != prev_end:
167
+ raise ValueError(
168
+ f'Bin {i} start ({start}) should be equal to the previous bin end {prev_end}.')
169
+ return bins
170
+
171
+ @validator('categorical')
172
+ def validate_categorical(cls, categorical: bool, values: dict[str, Any]) -> bool:
173
+ """Validate the categorical field."""
174
+ if categorical and is_float(values['dtype']):
175
+ raise ValueError('Categorical fields cannot be float dtypes.')
176
+ return categorical
177
+
178
+ def __str__(self) -> str:
179
+ return _str_field(self, indent=0)
180
+
181
+ def __repr__(self) -> str:
182
+ return f' {self.__class__.__name__}::{self.json(exclude_none=True, indent=2)}'
183
+
184
+
185
+ class Schema(BaseModel):
186
+ """Database schema."""
187
+ fields: dict[str, Field]
188
+ # Cached leafs.
189
+ _leafs: Optional[dict[PathTuple, Field]] = None
190
+
191
+ class Config:
192
+ arbitrary_types_allowed = True
193
+ underscore_attrs_are_private = True
194
+
195
+ @property
196
+ def leafs(self) -> dict[PathTuple, Field]:
197
+ """Return all the leaf fields in the schema. A leaf is defined as a node that contains a value.
198
+
199
+ NOTE: Leafs may contain children. Leafs can be found as any node that has a dtype defined.
200
+ """
201
+ if self._leafs:
202
+ return self._leafs
203
+ result: dict[PathTuple, Field] = {}
204
+ q: deque[tuple[PathTuple, Field]] = deque([((), Field(fields=self.fields))])
205
+ while q:
206
+ path, field = q.popleft()
207
+ if field.dtype:
208
+ # Nodes with dtypes act as leafs. They also may have children.
209
+ result[path] = field
210
+ if field.fields:
211
+ for name, child_field in field.fields.items():
212
+ child_path = (*path, name)
213
+ q.append((child_path, child_field))
214
+ elif field.repeated_field:
215
+ child_path = (*path, PATH_WILDCARD)
216
+ q.append((child_path, field.repeated_field))
217
+
218
+ self._leafs = result
219
+ return result
220
+
221
+ def has_field(self, path: PathTuple) -> bool:
222
+ """Returns if the field is found at the given path."""
223
+ field = cast(Field, self)
224
+ for path_part in path:
225
+ if field.fields:
226
+ field = cast(Field, field.fields.get(path_part))
227
+ if not field:
228
+ return False
229
+ elif field.repeated_field:
230
+ if path_part != PATH_WILDCARD:
231
+ return False
232
+ field = field.repeated_field
233
+ else:
234
+ return False
235
+ return True
236
+
237
+ def get_field(self, path: PathTuple) -> Field:
238
+ """Returns the field at the given path."""
239
+ field = cast(Field, self)
240
+ for name in path:
241
+ if field.fields:
242
+ if name not in field.fields:
243
+ raise ValueError(f'Path {path} not found in schema')
244
+ field = field.fields[name]
245
+ elif field.repeated_field:
246
+ if name != PATH_WILDCARD:
247
+ raise ValueError(f'Invalid path {path}')
248
+ field = field.repeated_field
249
+ else:
250
+ raise ValueError(f'Invalid path {path}')
251
+ return field
252
+
253
+ def __str__(self) -> str:
254
+ return _str_fields(self.fields, indent=0)
255
+
256
+ def __repr__(self) -> str:
257
+ return self.json(exclude_none=True, indent=2)
258
+
259
+
260
+ def schema(schema_like: object) -> Schema:
261
+ """Parse a schema-like object to a Schema object."""
262
+ field = _parse_field_like(schema_like)
263
+ if not field.fields:
264
+ raise ValueError('Schema must have fields')
265
+ return Schema(fields=field.fields)
266
+
267
+
268
+ def field(
269
+ dtype: Optional[Union[DataType, str]] = None,
270
+ signal: Optional[dict] = None,
271
+ fields: Optional[object] = None,
272
+ bins: Optional[list[Bin]] = None,
273
+ categorical: Optional[bool] = None,
274
+ ) -> Field:
275
+ """Parse a field-like object to a Field object."""
276
+ field = _parse_field_like(fields or {}, dtype)
277
+ if signal:
278
+ field.signal = signal
279
+ if dtype:
280
+ if isinstance(dtype, str):
281
+ dtype = DataType(dtype)
282
+ field.dtype = dtype
283
+ if bins:
284
+ field.bins = bins
285
+ if categorical is not None:
286
+ field.categorical = categorical
287
+ return field
288
+
289
+
290
+ class SpanVector(TypedDict):
291
+ """A span with a vector."""
292
+ span: tuple[int, int]
293
+ vector: np.ndarray
294
+
295
+
296
+ def lilac_span(start: int, end: int, metadata: dict[str, Any] = {}) -> Item:
297
+ """Creates a lilac span item, representing a pointer to a slice of text."""
298
+ return {VALUE_KEY: {TEXT_SPAN_START_FEATURE: start, TEXT_SPAN_END_FEATURE: end}, **metadata}
299
+
300
+
301
+ def lilac_embedding(start: int, end: int, embedding: Optional[np.ndarray]) -> Item:
302
+ """Creates a lilac embedding item, representing a vector with a pointer to a slice of text."""
303
+ return lilac_span(start, end, {EMBEDDING_KEY: embedding})
304
+
305
+
306
+ def _parse_field_like(field_like: object, dtype: Optional[Union[DataType, str]] = None) -> Field:
307
+ if isinstance(field_like, Field):
308
+ return field_like
309
+ elif isinstance(field_like, dict):
310
+ fields: dict[str, Field] = {}
311
+ for k, v in field_like.items():
312
+ fields[k] = _parse_field_like(v)
313
+ if isinstance(dtype, str):
314
+ dtype = DataType(dtype)
315
+ return Field(fields=fields or None, dtype=dtype)
316
+ elif isinstance(field_like, str):
317
+ return Field(dtype=DataType(field_like))
318
+ elif isinstance(field_like, list):
319
+ return Field(repeated_field=_parse_field_like(field_like[0], dtype=dtype))
320
+ else:
321
+ raise ValueError(f'Cannot parse field like: {field_like}')
322
+
323
+
324
+ def child_item_from_column_path(item: Item, path: Path) -> Item:
325
+ """Return the last (child) item from a column path."""
326
+ child_item_value = item
327
+ for path_part in path:
328
+ if path_part == PATH_WILDCARD:
329
+ raise ValueError(
330
+ 'child_item_from_column_path cannot be called with a path that contains a repeated '
331
+ f'wildcard: "{path}"')
332
+ # path_part can either be an integer or a string for a dictionary, both of which we can
333
+ # directly index with.
334
+ child_path = int(path_part) if path_part.isdigit() else path_part
335
+ child_item_value = child_item_value[child_path]
336
+ return child_item_value
337
+
338
+
339
+ def column_paths_match(path_match: Path, specific_path: Path) -> bool:
340
+ """Test whether two column paths match.
341
+
342
+ Args:
343
+ path_match: A column path that contains wildcards, and sub-paths. This path will be used for
344
+ testing the second specific path.
345
+ specific_path: A column path that specifically identifies an field.
346
+
347
+ Returns
348
+ Whether specific_path matches the path_match. This will only match when the
349
+ paths are equal length. If a user wants to enrich everything with an array, they must use the
350
+ path wildcard '*' in their path_match.
351
+ """
352
+ if isinstance(path_match, str):
353
+ path_match = (path_match,)
354
+ if isinstance(specific_path, str):
355
+ specific_path = (specific_path,)
356
+
357
+ if len(path_match) != len(specific_path):
358
+ return False
359
+
360
+ for path_match_p, specific_path_p in zip(path_match, specific_path):
361
+ if path_match_p == PATH_WILDCARD:
362
+ continue
363
+
364
+ if path_match_p != specific_path_p:
365
+ return False
366
+
367
+ return True
368
+
369
+
370
+ def normalize_path(path: Path) -> PathTuple:
371
+ """Normalizes a dot-separated path, but ignores dots inside quotes, like regular SQL.
372
+
373
+ Examples
374
+ - 'a.b.c' will be parsed as ('a', 'b', 'c').
375
+ - '"a.b".c' will be parsed as ('a.b', 'c').
376
+ - '"a".b.c' will be parsed as ('a', 'b', 'c').
377
+ """
378
+ if isinstance(path, str):
379
+ return tuple(next(csv.reader(io.StringIO(path), delimiter='.')))
380
+ return path
381
+
382
+
383
+ class ImageInfo(BaseModel):
384
+ """Info about an individual image."""
385
+ path: Path
386
+
387
+
388
+ class SourceManifest(BaseModel):
389
+ """The manifest that describes the dataset run, including schema and parquet files."""
390
+ # List of a parquet filepaths storing the data. The paths can be relative to `manifest.json`.
391
+ files: list[str]
392
+ # The data schema.
393
+ data_schema: Schema
394
+
395
+ # Image information for the dataset.
396
+ images: Optional[list[ImageInfo]] = None
397
+
398
+
399
+ def _str_fields(fields: dict[str, Field], indent: int) -> str:
400
+ prefix = ' ' * indent
401
+ out: list[str] = []
402
+ for name, field in fields.items():
403
+ out.append(f'{prefix}{name}:{_str_field(field, indent=indent + 2)}')
404
+ return '\n'.join(out)
405
+
406
+
407
+ def _str_field(field: Field, indent: int) -> str:
408
+ if field.fields:
409
+ prefix = '\n' if indent > 0 else ''
410
+ return f'{prefix}{_str_fields(field.fields, indent)}'
411
+ if field.repeated_field:
412
+ return f' list({_str_field(field.repeated_field, indent)})'
413
+ return f' {cast(DataType, field.dtype)}'
414
+
415
+
416
+ def dtype_to_arrow_schema(dtype: DataType) -> Union[pa.Schema, pa.DataType]:
417
+ """Convert the dtype to an arrow dtype."""
418
+ if dtype == DataType.STRING:
419
+ return pa.string()
420
+ elif dtype == DataType.BOOLEAN:
421
+ return pa.bool_()
422
+ elif dtype == DataType.FLOAT16:
423
+ return pa.float16()
424
+ elif dtype == DataType.FLOAT32:
425
+ return pa.float32()
426
+ elif dtype == DataType.FLOAT64:
427
+ return pa.float64()
428
+ elif dtype == DataType.INT8:
429
+ return pa.int8()
430
+ elif dtype == DataType.INT16:
431
+ return pa.int16()
432
+ elif dtype == DataType.INT32:
433
+ return pa.int32()
434
+ elif dtype == DataType.INT64:
435
+ return pa.int64()
436
+ elif dtype == DataType.UINT8:
437
+ return pa.uint8()
438
+ elif dtype == DataType.UINT16:
439
+ return pa.uint16()
440
+ elif dtype == DataType.UINT32:
441
+ return pa.uint32()
442
+ elif dtype == DataType.UINT64:
443
+ return pa.uint64()
444
+ elif dtype == DataType.BINARY:
445
+ return pa.binary()
446
+ elif dtype == DataType.TIME:
447
+ return pa.time64('us')
448
+ elif dtype == DataType.DATE:
449
+ return pa.date64()
450
+ elif dtype == DataType.TIMESTAMP:
451
+ return pa.timestamp('us')
452
+ elif dtype == DataType.INTERVAL:
453
+ return pa.duration('us')
454
+ elif dtype == DataType.EMBEDDING:
455
+ # We reserve an empty column for embeddings in parquet files so they can be queried.
456
+ # The values are *not* filled out. If parquet and duckdb support embeddings in the future, we
457
+ # can set this dtype to the relevant pyarrow type.
458
+ return pa.null()
459
+ elif dtype == DataType.STRING_SPAN:
460
+ return pa.struct({
461
+ VALUE_KEY: pa.struct({
462
+ TEXT_SPAN_START_FEATURE: pa.int32(),
463
+ TEXT_SPAN_END_FEATURE: pa.int32()
464
+ })
465
+ })
466
+ elif dtype == DataType.NULL:
467
+ return pa.null()
468
+ else:
469
+ raise ValueError(f'Can not convert dtype "{dtype}" to arrow dtype')
470
+
471
+
472
+ def schema_to_arrow_schema(schema: Union[Schema, Field]) -> pa.Schema:
473
+ """Convert our schema to arrow schema."""
474
+ arrow_schema = cast(pa.Schema, _schema_to_arrow_schema_impl(schema))
475
+ arrow_fields = {field.name: field.type for field in arrow_schema}
476
+ return pa.schema(arrow_fields)
477
+
478
+
479
+ def _schema_to_arrow_schema_impl(schema: Union[Schema, Field]) -> Union[pa.Schema, pa.DataType]:
480
+ """Convert a schema to an apache arrow schema."""
481
+ if schema.fields:
482
+ arrow_fields: dict[str, Union[pa.Schema, pa.DataType]] = {}
483
+ for name, field in schema.fields.items():
484
+ if name == UUID_COLUMN:
485
+ arrow_schema = dtype_to_arrow_schema(cast(DataType, field.dtype))
486
+ else:
487
+ arrow_schema = _schema_to_arrow_schema_impl(field)
488
+ arrow_fields[name] = arrow_schema
489
+
490
+ if isinstance(schema, Schema):
491
+ # Top-level schemas do not have __value__ fields.
492
+ return pa.schema(arrow_fields)
493
+ else:
494
+ # When nodes have both dtype and children, we add __value__ alongside the fields.
495
+ if schema.dtype:
496
+ value_schema = dtype_to_arrow_schema(schema.dtype)
497
+ if schema.dtype == DataType.STRING_SPAN:
498
+ value_schema = value_schema[VALUE_KEY].type
499
+ arrow_fields[VALUE_KEY] = value_schema
500
+
501
+ return pa.struct(arrow_fields)
502
+
503
+ field = cast(Field, schema)
504
+ if field.repeated_field:
505
+ return pa.list_(_schema_to_arrow_schema_impl(field.repeated_field))
506
+
507
+ return dtype_to_arrow_schema(cast(DataType, field.dtype))
508
+
509
+
510
+ def arrow_dtype_to_dtype(arrow_dtype: pa.DataType) -> DataType:
511
+ """Convert arrow dtype to our dtype."""
512
+ # Ints.
513
+ if arrow_dtype == pa.int8():
514
+ return DataType.INT8
515
+ elif arrow_dtype == pa.int16():
516
+ return DataType.INT16
517
+ elif arrow_dtype == pa.int32():
518
+ return DataType.INT32
519
+ elif arrow_dtype == pa.int64():
520
+ return DataType.INT64
521
+ elif arrow_dtype == pa.uint8():
522
+ return DataType.UINT8
523
+ elif arrow_dtype == pa.uint16():
524
+ return DataType.UINT16
525
+ elif arrow_dtype == pa.uint32():
526
+ return DataType.UINT32
527
+ elif arrow_dtype == pa.uint64():
528
+ return DataType.UINT64
529
+ # Floats.
530
+ elif arrow_dtype == pa.float16():
531
+ return DataType.FLOAT16
532
+ elif arrow_dtype == pa.float32():
533
+ return DataType.FLOAT32
534
+ elif arrow_dtype == pa.float64():
535
+ return DataType.FLOAT64
536
+ # Time.
537
+ elif pa.types.is_time(arrow_dtype):
538
+ return DataType.TIME
539
+ elif pa.types.is_date(arrow_dtype):
540
+ return DataType.DATE
541
+ elif pa.types.is_timestamp(arrow_dtype):
542
+ return DataType.TIMESTAMP
543
+ elif pa.types.is_duration(arrow_dtype):
544
+ return DataType.INTERVAL
545
+ # Others.
546
+ elif arrow_dtype == pa.string():
547
+ return DataType.STRING
548
+ elif pa.types.is_binary(arrow_dtype) or pa.types.is_fixed_size_binary(arrow_dtype):
549
+ return DataType.BINARY
550
+ elif pa.types.is_boolean(arrow_dtype):
551
+ return DataType.BOOLEAN
552
+ elif arrow_dtype == pa.null():
553
+ return DataType.NULL
554
+ else:
555
+ raise ValueError(f'Can not convert arrow dtype "{arrow_dtype}" to our dtype')
556
+
557
+
558
+ def arrow_schema_to_schema(schema: pa.Schema) -> Schema:
559
+ """Convert arrow schema to our schema."""
560
+ # TODO(nsthorat): Change this implementation to allow more complicated reading of arrow schemas
561
+ # into our schema by inferring values when {__value__: value} is present in the pyarrow schema.
562
+ # This isn't necessary today as this util is only needed by sources which do not have data in the
563
+ # lilac format.
564
+ return cast(Schema, _arrow_schema_to_schema_impl(schema))
565
+
566
+
567
+ def _arrow_schema_to_schema_impl(schema: Union[pa.Schema, pa.DataType]) -> Union[Schema, Field]:
568
+ """Convert an apache arrow schema to our schema."""
569
+ if isinstance(schema, (pa.Schema, pa.StructType)):
570
+ fields: dict[str, Field] = {
571
+ field.name: cast(Field, _arrow_schema_to_schema_impl(field.type)) for field in schema
572
+ }
573
+ return Schema(fields=fields) if isinstance(schema, pa.Schema) else Field(fields=fields)
574
+ elif isinstance(schema, pa.ListType):
575
+ return Field(repeated_field=cast(Field, _arrow_schema_to_schema_impl(schema.value_field.type)))
576
+ else:
577
+ return Field(dtype=arrow_dtype_to_dtype(schema))
578
+
579
+
580
+ def is_float(dtype: DataType) -> bool:
581
+ """Check if a dtype is a float dtype."""
582
+ return dtype in [DataType.FLOAT16, DataType.FLOAT32, DataType.FLOAT64]
583
+
584
+
585
+ def is_integer(dtype: DataType) -> bool:
586
+ """Check if a dtype is an integer dtype."""
587
+ return dtype in [
588
+ DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64, DataType.UINT8, DataType.UINT16,
589
+ DataType.UINT32, DataType.UINT64
590
+ ]
591
+
592
+
593
+ def is_temporal(dtype: DataType) -> bool:
594
+ """Check if a dtype is a temporal dtype."""
595
+ return dtype in [DataType.TIME, DataType.DATE, DataType.TIMESTAMP, DataType.INTERVAL]
596
+
597
+
598
+ def is_ordinal(dtype: DataType) -> bool:
599
+ """Check if a dtype is an ordinal dtype."""
600
+ return is_float(dtype) or is_integer(dtype) or is_temporal(dtype)
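To tie the helpers in this module together, here is a minimal sketch (the field names are hypothetical, not taken from this diff) of building a nested schema, checking a path, and converting it to a pyarrow schema:

from lilac.schema import DataType, field, normalize_path, schema, schema_to_arrow_schema

s = schema({
  'doc': {
    'title': 'string',
    'paragraphs': [{
      'text': field('string'),
      'num_tokens': field(DataType.INT32),
    }],
  },
})
# Repeated fields are addressed with the '*' wildcard.
assert s.has_field(('doc', 'paragraphs', '*', 'text'))
assert normalize_path('doc.paragraphs.*.text') == ('doc', 'paragraphs', '*', 'text')
arrow_schema = schema_to_arrow_schema(s)  # A pyarrow schema mirroring the nested structure.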