This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .dockerignore +10 -0
- .env +40 -0
- .env.demo +4 -0
- .gitattributes +0 -35
- .gitignore +5 -0
- Dockerfile +29 -0
- LICENSE +201 -0
- README.md +6 -9
- lilac/.gitignore +1 -0
- lilac/__init__.py +33 -0
- lilac/auth.py +87 -0
- lilac/batch_utils.py +92 -0
- lilac/cli.py +39 -0
- lilac/concepts/__init__.py +0 -0
- lilac/concepts/concept.py +330 -0
- lilac/concepts/db_concept.py +520 -0
- lilac/config.py +80 -0
- lilac/conftest.py +28 -0
- lilac/data/__init__.py +9 -0
- lilac/data/dataset.py +485 -0
- lilac/data/dataset_duckdb.py +1717 -0
- lilac/data/dataset_test_utils.py +127 -0
- lilac/data/dataset_utils.py +308 -0
- lilac/data/duckdb_utils.py +25 -0
- lilac/data_loader.py +110 -0
- lilac/db_manager.py +42 -0
- lilac/embeddings/__init__.py +0 -0
- lilac/embeddings/cohere.py +59 -0
- lilac/embeddings/default_vector_stores.py +10 -0
- lilac/embeddings/embedding.py +110 -0
- lilac/embeddings/gte.py +63 -0
- lilac/embeddings/openai.py +68 -0
- lilac/embeddings/palm.py +62 -0
- lilac/embeddings/sbert.py +38 -0
- lilac/embeddings/transformer_utils.py +35 -0
- lilac/embeddings/vector_store.py +200 -0
- lilac/embeddings/vector_store_hnsw.py +106 -0
- lilac/embeddings/vector_store_numpy.py +92 -0
- lilac/env.py +63 -0
- lilac/load.py +214 -0
- lilac/make_openapi.py +29 -0
- lilac/parquet_writer.py +70 -0
- lilac/router_concept.py +209 -0
- lilac/router_data_loader.py +80 -0
- lilac/router_dataset.py +303 -0
- lilac/router_google_login.py +60 -0
- lilac/router_signal.py +105 -0
- lilac/router_tasks.py +14 -0
- lilac/router_utils.py +54 -0
- lilac/schema.py +600 -0
.dockerignore
ADDED
@@ -0,0 +1,10 @@
# Python
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
# Ignore unit tests.
**/*_test.py

# Mac OS.
.DS_Store
.env
ADDED
@@ -0,0 +1,40 @@
# To overwrite these variables, create a .env.local file

# The path to the directory where the data will be downloaded on machine
LILAC_DATA_PATH=./data

# Set to 1 for duckdb to use views instead of materialized tables (lower memory usage, but slower).
DUCKDB_USE_VIEWS=0

# Set to true to enable read-only mode, disabling the ability to add datasets & compute dataset
# signals.
# LILAC_AUTH_ENABLED=true

# Variables that can be set in .env.local
#
# Get key from https://dashboard.cohere.ai/api-keys
# COHERE_API_KEY=

# GCS_REGION=
# GCS_ACCESS_KEY=
# GCS_SECRET_KEY=

# Get key from https://platform.openai.com/account/api-keys
# OPENAI_API_KEY=
# Get key from https://makersuite.google.com/app/apikey
# PALM_API_KEY=

# HuggingFace demos: machine that uploads to HuggingFace.

# For authenticating with HuggingFace to deploy to a Space.
# HF_USERNAME=
# The default repo to deploy to for a staging demo. Can be overridden by a command line flag.
# HF_STAGING_DEMO_REPO='HF_ORG/HF_REPO_NAME'

# For Google-login. This is generated from the Google Cloud Console for a web client.
# See: https://developers.google.com/identity/protocols/oauth2
GOOGLE_CLIENT_ID='279475920249-i8llm8vbos1vj5m1qocir8narb3r0enu.apps.googleusercontent.com'
# The client secret of the above client.
# GOOGLE_CLIENT_SECRET=
# A random string for oauth sessions.
# LILAC_OAUTH_SECRET_KEY=
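Illustrative sketch of the .env / .env.local layering described above. It uses python-dotenv for brevity; this is an assumption for illustration only and not necessarily how lilac's own env loader reads these files:

# Hypothetical loader sketch (python-dotenv), not lilac's actual implementation.
from dotenv import load_dotenv

load_dotenv('.env')                        # defaults checked into the repo
load_dotenv('.env.local', override=True)   # local secrets and overrides win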
.env.demo
ADDED
@@ -0,0 +1,4 @@
LILAC_DATA_PATH='/data'
HF_HOME='/data/.huggingface'
TRANSFORMERS_CACHE='/data/.cache'
XDG_CACHE_HOME='/data/.cache'
.gitattributes
DELETED
@@ -1,35 +0,0 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,5 @@
__pycache__/
**/*.pyc
**/*.pyo
**/*.pyd
**/*_test.py
Dockerfile
ADDED
@@ -0,0 +1,29 @@
# NOTE: When we upgrade to 3.11 we can use a slimmer docker image which comes with gcc.
FROM python:3.9-bullseye

# Allow statements and log messages to immediately appear in the Knative logs
ENV PYTHONUNBUFFERED True

# Set the working directory in the container.
WORKDIR /server

# Install the dependencies. This requires exporting requirements.txt from poetry first, which
# happens from ./build_docker.sh.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY .env .
COPY .env.demo .
COPY LICENSE .

# Copy python files.
COPY /lilac ./lilac/

# Copy the data files. We use glob so docker copy won't fail if the directory doesn't exist.
COPY /dat[a] ./data/

CMD [ \
  "gunicorn", "lilac.server:app", \
  "--bind", "0.0.0.0:5432", \
  "-k", "uvicorn.workers.UvicornWorker" \
]
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2023 Lilac AI Inc.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1,11 +1,8 @@
 ---
-title: Lilac
-emoji:
-colorFrom:
-colorTo:
+title: Lilac Blueprint
+emoji: 🌷
+colorFrom: purple
+colorTo: purple
 sdk: docker
-
-
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+app_port: 5432
+---
lilac/.gitignore
ADDED
@@ -0,0 +1 @@
web/
lilac/__init__.py
ADDED
@@ -0,0 +1,33 @@
from importlib import metadata

from .data import *  # noqa: F403
from .data.dataset_duckdb import DatasetDuckDB
from .data_loader import create_dataset
from .db_manager import get_dataset, set_default_dataset_cls
from .embeddings.default_vector_stores import register_default_vector_stores
from .server import start_server, stop_server
from .signals import *  # noqa: F403
from .signals.default_signals import register_default_signals
from .sources import *  # noqa: F403
from .sources.default_sources import register_default_sources

try:
  __version__ = metadata.version('lilacai')
except metadata.PackageNotFoundError:
  __version__ = ''

register_default_sources()
register_default_signals()
register_default_vector_stores()
set_default_dataset_cls(DatasetDuckDB)

# Avoids polluting the results of dir(__package__).
del (metadata, register_default_sources, register_default_signals, set_default_dataset_cls,
     DatasetDuckDB)

__all__ = [
  'start_server',
  'stop_server',
  'create_dataset',
  'get_dataset',
]
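Illustrative usage sketch of the API exported here (not a file in this commit); the start_server call mirrors the one made by lilac/cli.py below:

# Minimal sketch, assuming the lilacai package is installed.
import lilac as ll

ll.start_server(host='127.0.0.1', port=5432, open=True)  # serve the web UI locally
# ... later, shut it down ...
ll.stop_server()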
lilac/auth.py
ADDED
@@ -0,0 +1,87 @@
"""Authentication and ACL configuration."""

from typing import Optional

from fastapi import Request
from pydantic import BaseModel, ValidationError

from .env import env


class ConceptAuthorizationException(Exception):
  """Authorization exceptions thrown by the concept database."""
  pass


class DatasetUserAccess(BaseModel):
  """User access for datasets."""
  # Whether the user can compute a signal.
  compute_signals: bool
  # Whether the user can delete a dataset.
  delete_dataset: bool
  # Whether the user can delete a signal.
  delete_signals: bool
  # Whether the user can update settings.
  update_settings: bool


class ConceptUserAccess(BaseModel):
  """User access for concepts."""
  # Whether the user can delete any concept (not their own).
  delete_any_concept: bool


class UserAccess(BaseModel):
  """User access."""
  create_dataset: bool

  # TODO(nsthorat): Make this keyed to each dataset and concept.
  dataset: DatasetUserAccess
  concept: ConceptUserAccess


class UserInfo(BaseModel):
  """User information."""
  id: str
  email: str
  name: str
  given_name: str
  family_name: str


class AuthenticationInfo(BaseModel):
  """Authentication information for the user."""
  user: Optional[UserInfo] = None
  access: UserAccess
  auth_enabled: bool


def get_session_user(request: Request) -> Optional[UserInfo]:
  """Get the user from the session."""
  if not env('LILAC_AUTH_ENABLED'):
    return None
  user_info_dict = request.session.get('user', None)
  if user_info_dict:
    try:
      return UserInfo.parse_obj(user_info_dict)
    except ValidationError:
      return None
  return None


def get_user_access() -> UserAccess:
  """Get the user access."""
  auth_enabled = env('LILAC_AUTH_ENABLED')
  if isinstance(auth_enabled, str):
    auth_enabled = auth_enabled.lower() == 'true'
  if auth_enabled:
    return UserAccess(
      create_dataset=False,
      dataset=DatasetUserAccess(
        compute_signals=False, delete_dataset=False, delete_signals=False, update_settings=False),
      concept=ConceptUserAccess(delete_any_concept=False))
  return UserAccess(
    create_dataset=True,
    dataset=DatasetUserAccess(
      compute_signals=True, delete_dataset=True, delete_signals=True, update_settings=True),
    concept=ConceptUserAccess(delete_any_concept=True))
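Illustrative sketch (not a file in this commit): building the read-only UserAccess payload that get_user_access() returns when LILAC_AUTH_ENABLED is true, and serializing it as the API would:

from lilac.auth import ConceptUserAccess, DatasetUserAccess, UserAccess

# The same locked-down access returned when auth is enabled.
readonly = UserAccess(
  create_dataset=False,
  dataset=DatasetUserAccess(
    compute_signals=False, delete_dataset=False, delete_signals=False, update_settings=False),
  concept=ConceptUserAccess(delete_any_concept=False))
print(readonly.json(indent=2))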
lilac/batch_utils.py
ADDED
@@ -0,0 +1,92 @@
"""Utils for the python server."""
import itertools
from typing import Any, Callable, Generator, Iterable, Iterator, TypeVar, Union, cast

from .schema import Item
from .utils import chunks, is_primitive


def _deep_flatten(input: Union[Iterator, object],
                  is_primitive_predicate: Callable[[object], bool]) -> Generator:
  """Flattens a nested iterable."""
  if is_primitive_predicate(input):
    yield input
  elif isinstance(input, dict):
    yield input
  elif is_primitive(input):
    yield input
  else:
    for elem in cast(Iterator, input):
      yield from _deep_flatten(elem, is_primitive_predicate)


def deep_flatten(input: Union[Iterator, Iterable],
                 is_primitive_predicate: Callable[[object], bool] = is_primitive) -> Iterator:
  """Flattens a deeply nested iterator.

  Primitives and dictionaries are not flattened. The user can also provide a predicate to determine
  what is a primitive.
  """
  return _deep_flatten(input, is_primitive_predicate)


def _deep_unflatten(flat_input: Iterator[list[object]], original_input: Union[Iterable, object],
                    is_primitive_predicate: Callable[[object], bool]) -> Union[list, dict]:
  """Unflattens a deeply flattened iterable according to the original iterable's structure."""
  if is_primitive_predicate(original_input):
    return next(flat_input)
  else:
    values: Iterable
    if isinstance(original_input, dict):
      values = original_input.values()
    else:
      values = cast(Iterable, original_input)
    return [_deep_unflatten(flat_input, orig_elem, is_primitive_predicate) for orig_elem in values]


def deep_unflatten(flat_input: Union[Iterable, Iterator],
                   original_input: Union[Iterable, object],
                   is_primitive_predicate: Callable[[object], bool] = is_primitive) -> list:
  """Unflattens a deeply flattened iterable according to the original iterable's structure."""
  return cast(list, _deep_unflatten(iter(flat_input), original_input, is_primitive_predicate))


TFlatten = TypeVar('TFlatten')


def flatten(inputs: Iterable[Iterable[TFlatten]]) -> Iterator[TFlatten]:
  """Flattens a nested iterator.

  Only supports flattening one level deep.
  """
  for input in inputs:
    yield from input


TUnflatten = TypeVar('TUnflatten')


def unflatten(flat_inputs: Union[Iterable[TUnflatten], Iterator[TUnflatten]],
              original_inputs: Iterable[Iterable[Any]]) -> Iterator[list[TUnflatten]]:
  """Unflattens a flattened iterable according to the original iterable's structure."""
  flat_inputs_iter = iter(flat_inputs)
  for original_input in original_inputs:
    yield [next(flat_inputs_iter) for _ in original_input]


TFlatBatchedInput = TypeVar('TFlatBatchedInput')
TFlatBatchedOutput = TypeVar('TFlatBatchedOutput')


def flat_batched_compute(input: Iterable[Iterable[TFlatBatchedInput]],
                         f: Callable[[list[TFlatBatchedInput]], Iterable[TFlatBatchedOutput]],
                         batch_size: int) -> Iterable[Iterable[TFlatBatchedOutput]]:
  """Flatten the input, batched call f, and return the output unflattened."""
  # Tee the input so we can use it twice for the input and output shapes.
  input_1, input_2 = itertools.tee(input, 2)
  batches = chunks(flatten(input_1), batch_size)
  batched_outputs = flatten((f(batch) for batch in batches))
  return unflatten(batched_outputs, input_2)


TBatchSpanVectorOutput = TypeVar('TBatchSpanVectorOutput', bound=Item)
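Illustrative sketch of the flatten/unflatten pairs above (not a file in this commit); it assumes is_primitive from lilac/utils.py treats ints as leaves:

from lilac.batch_utils import deep_flatten, deep_unflatten, flatten, unflatten

nested = [[1, 2], [3, [4, 5]]]
flat = list(deep_flatten(nested))        # [1, 2, 3, 4, 5]
restored = deep_unflatten(flat, nested)  # [[1, 2], [3, [4, 5]]]

# One-level flatten/unflatten round-trips a per-element computation over batches.
batches = [[1, 2], [3]]
doubled = unflatten((x * 2 for x in flatten(batches)), batches)
print(restored, list(doubled))           # [[1, 2], [3, [4, 5]]] [[2, 4], [6]]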
lilac/cli.py
ADDED
@@ -0,0 +1,39 @@
"""Lilac CLI."""

import click

from . import __version__
from .load import load_command as load
from .server import start_server


@click.command()
@click.option(
  '--host',
  help='The host address where the web server will listen to.',
  default='0.0.0.0',
  type=str)
@click.option('--port', help='The port number of the web-server', type=int, default=5432)
def start(host: str, port: int) -> None:
  """Starts the Lilac web server."""
  start_server(host=host, port=port, open=True)


@click.command()
def version() -> None:
  """Prints the version of Lilac."""
  print(__version__)


@click.group()
def cli() -> None:
  """Lilac CLI."""
  pass


cli.add_command(start)
cli.add_command(version)
cli.add_command(load)

if __name__ == '__main__':
  cli()
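Illustrative sketch (not a file in this commit): the click group above can be exercised without a console entry point via click's test runner:

from click.testing import CliRunner
from lilac.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ['version'])
print(result.output)  # prints the installed Lilac version

# Assuming a console script named `lilac` is registered, this mirrors:
#   lilac start --host 0.0.0.0 --port 5432
# runner.invoke(cli, ['start', '--host', '0.0.0.0', '--port', '5432'])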
lilac/concepts/__init__.py
ADDED
File without changes
lilac/concepts/concept.py
ADDED
@@ -0,0 +1,330 @@
"""Defines the concept and the concept models."""
import dataclasses
from enum import Enum
from typing import Callable, Literal, Optional, Union

import numpy as np
from joblib import Parallel, delayed
from pydantic import BaseModel, validator
from scipy.interpolate import interp1d
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve, roc_auc_score
from sklearn.model_selection import KFold

from ..embeddings.embedding import get_embed_fn
from ..schema import SignalInputType
from ..signals.signal import TextEmbeddingSignal, get_signal_cls
from ..utils import DebugTimer

LOCAL_CONCEPT_NAMESPACE = 'local'

# The maximum number of cross-validation models to train.
MAX_NUM_CROSS_VAL_MODELS = 15
# The β weight to use for the F-beta score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html
# β = 0.5 means we value precision 2x as much as recall.
# β = 2 means we value recall 2x as much as precision.
F_BETA_WEIGHT = 0.5


class ExampleOrigin(BaseModel):
  """The origin of an example."""
  # The namespace that holds the dataset.
  dataset_namespace: str

  # The name of the dataset.
  dataset_name: str

  # The id of row in the dataset that the example was added from.
  dataset_row_id: str


DraftId = Union[Literal['main'], str]
DRAFT_MAIN = 'main'


class ExampleIn(BaseModel):
  """An example in a concept without the id (used for adding new examples)."""
  label: bool
  text: Optional[str] = None
  img: Optional[bytes] = None
  origin: Optional[ExampleOrigin] = None
  # The name of the draft to put the example in. If None, puts it in the main draft.
  draft: Optional[DraftId] = DRAFT_MAIN

  @validator('text')
  def parse_text(cls, text: str) -> str:
    """Fixes surrogate errors in text: https://github.com/ijl/orjson/blob/master/README.md#str ."""
    return text.encode('utf-8', 'replace').decode('utf-8')


class Example(ExampleIn):
  """A single example in a concept used for training a concept model."""
  id: str


class Concept(BaseModel):
  """A concept is a collection of examples."""
  # The namespace of the concept.
  namespace: str
  # The name of the concept.
  concept_name: str
  # The type of the data format that this concept represents.
  type: SignalInputType
  data: dict[str, Example]
  version: int = 0

  description: Optional[str] = None

  def drafts(self) -> list[DraftId]:
    """Gets all the drafts for the concept."""
    drafts: set[DraftId] = set([DRAFT_MAIN])  # Always return the main draft.
    for example in self.data.values():
      if example.draft:
        drafts.add(example.draft)
    return list(sorted(drafts))


class OverallScore(str, Enum):
  """Enum holding the overall score."""
  NOT_GOOD = 'not_good'
  OK = 'ok'
  GOOD = 'good'
  VERY_GOOD = 'very_good'
  GREAT = 'great'


def _get_overall_score(f1_score: float) -> OverallScore:
  if f1_score < 0.5:
    return OverallScore.NOT_GOOD
  if f1_score < 0.8:
    return OverallScore.OK
  if f1_score < 0.9:
    return OverallScore.GOOD
  if f1_score < 0.95:
    return OverallScore.VERY_GOOD
  return OverallScore.GREAT


class ConceptMetrics(BaseModel):
  """Metrics for a concept."""
  # The average F1 score for the concept computed using cross validation.
  f1: float
  precision: float
  recall: float
  roc_auc: float
  overall: OverallScore


@dataclasses.dataclass
class LogisticEmbeddingModel:
  """A model that uses logistic regression with embeddings."""

  _metrics: Optional[ConceptMetrics] = None
  _threshold: float = 0.5

  def __post_init__(self) -> None:
    # See `notebooks/Toxicity.ipynb` for an example of training a concept model.
    self._model = LogisticRegression(
      class_weight='balanced', C=30, tol=1e-5, warm_start=True, max_iter=5_000, n_jobs=-1)

  def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
    """Get the scores for the provided embeddings."""
    y_probs = self._model.predict_proba(embeddings)[:, 1]
    # Map [0, threshold, 1] to [0, 0.5, 1].
    interpolate_fn = interp1d([0, self._threshold, 1], [0, 0.4999, 1])
    return interpolate_fn(y_probs)

  def _setup_training(self, X_train: np.ndarray,
                      labels: Union[list[bool], np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    y_train = np.array(labels)
    # Shuffle the data in unison.
    p = np.random.permutation(len(X_train))
    X_train = X_train[p]
    y_train = y_train[p]
    return X_train, y_train

  def fit(self, embeddings: np.ndarray, labels: list[bool]) -> None:
    """Fit the model to the provided embeddings and labels."""
    label_set = set(labels)
    if len(label_set) < 2:
      dim = embeddings.shape[1]
      random_vector = np.random.randn(dim).astype(np.float32)
      random_vector /= np.linalg.norm(random_vector)
      embeddings = np.vstack([embeddings, random_vector])
      labels.append(False if True in label_set else True)

    if len(labels) != len(embeddings):
      raise ValueError(
        f'Length of embeddings ({len(embeddings)}) must match length of labels ({len(labels)})')
    X_train, y_train = self._setup_training(embeddings, labels)
    self._model.fit(X_train, y_train)
    self._metrics, self._threshold = self._compute_metrics(embeddings, labels)

  def _compute_metrics(self, embeddings: np.ndarray,
                       labels: list[bool]) -> tuple[Optional[ConceptMetrics], float]:
    """Return the concept metrics."""
    labels_np = np.array(labels)
    n_splits = min(len(labels_np), MAX_NUM_CROSS_VAL_MODELS)
    fold = KFold(n_splits, shuffle=True, random_state=42)

    def _fit_and_score(model: LogisticRegression, X_train: np.ndarray, y_train: np.ndarray,
                       X_test: np.ndarray, y_test: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
      if len(set(y_train)) < 2:
        return np.array([]), np.array([])
      model.fit(X_train, y_train)
      y_pred = model.predict_proba(X_test)[:, 1]
      return y_test, y_pred

    # Compute the metrics for each validation fold in parallel.
    jobs: list[Callable] = []
    for (train_index, test_index) in fold.split(embeddings):
      X_train, y_train = embeddings[train_index], labels_np[train_index]
      X_train, y_train = self._setup_training(X_train, y_train)
      X_test, y_test = embeddings[test_index], labels_np[test_index]
      model = clone(self._model)
      jobs.append(delayed(_fit_and_score)(model, X_train, y_train, X_test, y_test))
    results = Parallel(n_jobs=-1)(jobs)

    y_test = np.concatenate([y_test for y_test, _ in results], axis=0)
    y_pred = np.concatenate([y_pred for _, y_pred in results], axis=0)
    if len(set(y_test)) < 2:
      return None, 0.5
    roc_auc_val = roc_auc_score(y_test, y_pred)
    precision, recall, thresholds = precision_recall_curve(y_test, y_pred)
    numerator = (1 + F_BETA_WEIGHT**2) * precision * recall
    denom = (F_BETA_WEIGHT**2 * precision) + recall
    f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))
    max_f1: float = np.max(f1_scores)
    max_f1_index = np.argmax(f1_scores)
    max_f1_thresh: float = thresholds[max_f1_index]
    max_f1_prec: float = precision[max_f1_index]
    max_f1_recall: float = recall[max_f1_index]
    metrics = ConceptMetrics(
      f1=max_f1,
      precision=max_f1_prec,
      recall=max_f1_recall,
      roc_auc=float(roc_auc_val),
      overall=_get_overall_score(max_f1))
    return metrics, max_f1_thresh


def draft_examples(concept: Concept, draft: DraftId) -> dict[str, Example]:
  """Get the examples in the provided draft by overriding the main draft."""
  draft_examples: dict[str, dict[str, Example]] = {}
  for id, example in concept.data.items():
    draft_examples.setdefault(example.draft or DRAFT_MAIN, {})[example.id] = example

  if draft == DRAFT_MAIN:
    return draft_examples.get(DRAFT_MAIN, {})

  if draft not in draft_examples:
    raise ValueError(
      f'Draft {draft} not found in concept. Found drafts: {list(draft_examples.keys())}')

  # Map the text of the draft to its id so we can dedupe with main.
  draft_text_ids = {example.text: id for id, example in draft_examples[draft].items()}

  # Write each of examples from main to the draft examples only if the text does not appear in the
  # draft.
  for id, example in draft_examples[DRAFT_MAIN].items():
    if example.text not in draft_text_ids:
      draft_examples[draft][id] = example

  return draft_examples[draft]


@dataclasses.dataclass
class ConceptModel:
  """A concept model. Stores all concept model drafts and manages syncing."""
  # The concept that this model is for.
  namespace: str
  concept_name: str

  # The name of the embedding for this model.
  embedding_name: str
  version: int = 0

  batch_size = 4096

  # The following fields are excluded from JSON serialization, but still pickle-able.
  # Maps a concept id to the embeddings.
  _embeddings: dict[str, np.ndarray] = dataclasses.field(default_factory=dict)
  _logistic_models: dict[DraftId, LogisticEmbeddingModel] = dataclasses.field(default_factory=dict)

  def get_metrics(self, concept: Concept) -> Optional[ConceptMetrics]:
    """Return the metrics for this model."""
    return self._get_logistic_model(DRAFT_MAIN)._metrics

  def score_embeddings(self, draft: DraftId, embeddings: np.ndarray) -> np.ndarray:
    """Get the scores for the provided embeddings."""
    return self._get_logistic_model(draft).score_embeddings(embeddings)

  def coef(self, draft: DraftId) -> np.ndarray:
    """Get the coefficients of the underlying ML model."""
    return self._get_logistic_model(draft)._model.coef_.reshape(-1)

  def _get_logistic_model(self, draft: DraftId) -> LogisticEmbeddingModel:
    """Get the logistic model for the provided draft."""
    if draft not in self._logistic_models:
      self._logistic_models[draft] = LogisticEmbeddingModel()
    return self._logistic_models[draft]

  def sync(self, concept: Concept) -> bool:
    """Update the model with the latest labeled concept data."""
    if concept.version == self.version:
      # The model is up to date.
      return False

    concept_path = (f'{self.namespace}/{self.concept_name}/'
                    f'{self.embedding_name}')
    with DebugTimer(f'Computing embeddings for "{concept_path}"'):
      self._compute_embeddings(concept)

    # Fit each of the drafts, sort by draft name for deterministic behavior.
    for draft in concept.drafts():
      examples = draft_examples(concept, draft)
      embeddings = np.array([self._embeddings[id] for id in examples.keys()])
      labels = [example.label for example in examples.values()]
      model = self._get_logistic_model(draft)
      with DebugTimer(f'Fitting model for "{concept_path}"'):
        model.fit(embeddings, labels)

    # Synchronize the model version with the concept version.
    self.version = concept.version

    return True

  def _compute_embeddings(self, concept: Concept) -> None:
    signal_cls = get_signal_cls(self.embedding_name)
    if not signal_cls:
      raise ValueError(f'Embedding signal "{self.embedding_name}" not found in the registry.')
    embedding_signal = signal_cls()
    if not isinstance(embedding_signal, TextEmbeddingSignal):
      raise ValueError(f'Only text embedding signals are currently supported for concepts. '
                       f'"{self.embedding_name}" is a {type(embedding_signal)}.')

    embed_fn = get_embed_fn(self.embedding_name, split=False)
    concept_embeddings: dict[str, np.ndarray] = {}

    examples = concept.data.items()
    if not examples:
      raise ValueError(f'Cannot sync concept "{concept.concept_name}". It has no examples.')

    # Compute the embeddings for the examples with cache miss.
    texts_of_missing_embeddings: dict[str, str] = {}
    for id, example in examples:
      if id in self._embeddings:
        # Cache hit.
        concept_embeddings[id] = self._embeddings[id]
      else:
        # Cache miss.
        # TODO(smilkov): Support images.
        texts_of_missing_embeddings[id] = example.text or ''

    missing_ids = texts_of_missing_embeddings.keys()
    missing_embeddings = embed_fn(list(texts_of_missing_embeddings.values()))

    for id, (embedding,) in zip(missing_ids, missing_embeddings):
      concept_embeddings[id] = embedding['vector']
    self._embeddings = concept_embeddings
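Illustrative sketch (not a file in this commit) that exercises LogisticEmbeddingModel on synthetic vectors; with random embeddings the scores are meaningless, this only demonstrates the fit/score API and its cross-validated metrics:

import numpy as np
from lilac.concepts.concept import LogisticEmbeddingModel

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(8, 16)).astype(np.float32)  # 8 fake examples, 16-dim vectors
labels = [True, True, True, True, False, False, False, False]

model = LogisticEmbeddingModel()
model.fit(embeddings, labels)                 # trains and computes cross-validation metrics
scores = model.score_embeddings(embeddings)   # values in [0, 1]; 0.5 sits at the tuned threshold
print(scores.round(2), model._metrics)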
lilac/concepts/db_concept.py
ADDED
@@ -0,0 +1,520 @@
"""The concept database."""

import abc
import glob
import json
import os
import pathlib
import pickle
import shutil

# NOTE: We have to import the module for uuid so it can be mocked.
import uuid
from typing import Any, List, Optional, Union, cast

from pydantic import BaseModel
from typing_extensions import override

from ..auth import ConceptAuthorizationException, UserInfo
from ..env import data_path, env
from ..schema import SignalInputType
from ..signals.signal import get_signal_cls
from ..utils import delete_file, file_exists, open_file
from .concept import DRAFT_MAIN, Concept, ConceptModel, DraftId, Example, ExampleIn

CONCEPTS_DIR = 'concept'
CONCEPT_JSON_FILENAME = 'concept.json'


class ConceptNamespaceACL(BaseModel):
  """The access control list for a namespace."""
  # Whether the current user can read concepts in the namespace.
  read: bool
  # Whether the current user can add concepts to the namespace.
  write: bool


class ConceptACL(BaseModel):
  """The access control list for an individual concept."""
  # Whether the current user can read the concept.
  read: bool
  # Whether the current user can edit the concept, including adding examples or deleting the
  # concept.
  write: bool


class ConceptInfo(BaseModel):
  """Information about a concept."""
  namespace: str
  name: str
  description: Optional[str] = None
  type: SignalInputType
  drafts: list[DraftId]

  acls: ConceptACL


class ConceptUpdate(BaseModel):
  """An update to a concept."""
  # List of examples to be inserted.
  insert: Optional[list[ExampleIn]] = []

  # List of examples to be updated.
  update: Optional[list[Example]] = []

  # The ids of the examples to be removed.
  remove: Optional[list[str]] = []


class ConceptDB(abc.ABC):
  """Interface for the concept database."""

  @abc.abstractmethod
  def list(self, user: Optional[UserInfo] = None) -> list[ConceptInfo]:
    """List all the concepts."""
    pass

  @abc.abstractmethod
  def namespace_acls(self, namespace: str, user: Optional[UserInfo] = None) -> ConceptNamespaceACL:
    """Return the ACL for a namespace."""
    pass

  @abc.abstractmethod
  def concept_acls(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> ConceptACL:
    """Return the ACL for a concept."""
    pass

  @abc.abstractmethod
  def get(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> Optional[Concept]:
    """Return a concept or None if there isn't one."""
    pass

  @abc.abstractmethod
  def create(self,
             namespace: str,
             name: str,
             type: SignalInputType,
             description: Optional[str] = None,
             user: Optional[UserInfo] = None) -> Concept:
    """Create a concept.

    Args:
      namespace: The namespace of the concept.
      name: The name of the concept.
      type: The input type of the concept.
      description: The description of the concept.
      user: The user creating the concept, if authentication is enabled.
    """
    pass

  @abc.abstractmethod
  def edit(self,
           namespace: str,
           name: str,
           change: ConceptUpdate,
           user: Optional[UserInfo] = None) -> Concept:
    """Edit a concept. If the concept doesn't exist, throw an error."""
    pass

  @abc.abstractmethod
  def remove(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> None:
    """Remove a concept."""
    pass

  @abc.abstractmethod
  def merge_draft(self,
                  namespace: str,
                  name: str,
                  draft: DraftId,
                  user: Optional[UserInfo] = None) -> Concept:
    """Merge a draft concept."""
    pass


class ConceptModelDB(abc.ABC):
  """Interface for the concept model database."""

  _concept_db: ConceptDB

  def __init__(self, concept_db: ConceptDB) -> None:
    self._concept_db = concept_db

  @abc.abstractmethod
  def create(self,
             namespace: str,
             concept_name: str,
             embedding_name: str,
             user: Optional[UserInfo] = None) -> ConceptModel:
    """Create the concept model."""
    pass

  @abc.abstractmethod
  def get(self,
          namespace: str,
          concept_name: str,
          embedding_name: str,
          user: Optional[UserInfo] = None) -> Optional[ConceptModel]:
    """Get the model associated with the provided concept and embedding.

    Returns None if the model does not exist.
    """
    pass

  @abc.abstractmethod
  def _save(self, model: ConceptModel) -> None:
    """Save the concept model."""
    pass

  def in_sync(self, model: ConceptModel, user: Optional[UserInfo] = None) -> bool:
    """Return True if the model is up to date with the concept."""
    concept = self._concept_db.get(model.namespace, model.concept_name, user=user)
    if not concept:
      raise ValueError(f'Concept "{model.namespace}/{model.concept_name}" does not exist.')
    return concept.version == model.version

  def sync(self, model: ConceptModel, user: Optional[UserInfo] = None) -> bool:
    """Sync the concept model. Returns true if the model was updated."""
    concept = self._concept_db.get(model.namespace, model.concept_name, user=user)
    if not concept:
      raise ValueError(f'Concept "{model.namespace}/{model.concept_name}" does not exist.')
    model_updated = model.sync(concept)
    if model_updated:
      self._save(model)
    return model_updated

  @abc.abstractmethod
  def remove(self, namespace: str, concept_name: str, embedding_name: str) -> None:
    """Remove the model of a concept."""
    pass

  @abc.abstractmethod
  def get_models(self, namespace: str, concept_name: str) -> list[ConceptModel]:
    """List all the models associated with a concept."""
    pass


class DiskConceptModelDB(ConceptModelDB):
  """Interface for the concept model database."""

  def __init__(self,
               concept_db: ConceptDB,
               base_dir: Optional[Union[str, pathlib.Path]] = None) -> None:
    super().__init__(concept_db)
    self._base_dir = base_dir

  def _get_base_dir(self) -> str:
    return str(self._base_dir) if self._base_dir else data_path()

  @override
  def create(self,
             namespace: str,
             concept_name: str,
             embedding_name: str,
             user: Optional[UserInfo] = None) -> ConceptModel:
    if self.get(namespace, concept_name, embedding_name, user=user):
      raise ValueError('Concept model already exists.')
    concept = self._concept_db.get(namespace, concept_name, user=user)
    if not concept:
      raise ValueError(f'Concept "{namespace}/{concept_name}" does not exist.')

    return ConceptModel(
      namespace=namespace, concept_name=concept_name, embedding_name=embedding_name)

  @override
  def get(self,
          namespace: str,
          concept_name: str,
          embedding_name: str,
          user: Optional[UserInfo] = None) -> Optional[ConceptModel]:
    # Make sure the concept exists.
    concept = self._concept_db.get(namespace, concept_name, user=user)
    if not concept:
      raise ValueError(f'Concept "{namespace}/{concept_name}" does not exist.')

    # Make sure that the embedding signal exists.
    if not get_signal_cls(embedding_name):
      raise ValueError(f'Embedding signal "{embedding_name}" not found in the registry.')

    concept_model_path = _concept_model_path(self._get_base_dir(), namespace, concept_name,
                                             embedding_name)
    if not file_exists(concept_model_path):
      return None

    with open_file(concept_model_path, 'rb') as f:
      return pickle.load(f)

  def _save(self, model: ConceptModel) -> None:
    """Save the concept model."""
    concept_model_path = _concept_model_path(self._get_base_dir(), model.namespace,
                                             model.concept_name, model.embedding_name)
    with open_file(concept_model_path, 'wb') as f:
      pickle.dump(model, f)

  @override
  def remove(self,
             namespace: str,
             concept_name: str,
             embedding_name: str,
             user: Optional[UserInfo] = None) -> None:
    concept_model_path = _concept_model_path(self._get_base_dir(), namespace, concept_name,
                                             embedding_name)

    if not file_exists(concept_model_path):
      raise ValueError(f'Concept model {namespace}/{concept_name}/{embedding_name} does not exist.')

    delete_file(concept_model_path)

  @override
  def get_models(self,
                 namespace: str,
                 concept_name: str,
                 user: Optional[UserInfo] = None) -> list[ConceptModel]:
    """List all the models associated with a concept."""
    model_files = glob.iglob(
      os.path.join(get_concept_output_dir(self._get_base_dir(), namespace, concept_name), '*.pkl'))
    models: list[ConceptModel] = []
    for model_file in model_files:
      embedding_name = os.path.basename(model_file)[:-len('.pkl')]
      model = self.get(namespace, concept_name, embedding_name, user=user)
      if model:
        models.append(model)
    return models


def get_concept_output_dir(base_dir: str, namespace: str, name: str) -> str:
  """Return the output directory for a given concept."""
  return os.path.join(base_dir, CONCEPTS_DIR, namespace, name)


def _concept_json_path(base_dir: str, namespace: str, name: str) -> str:
  return os.path.join(get_concept_output_dir(base_dir, namespace, name), CONCEPT_JSON_FILENAME)
|
291 |
+
|
292 |
+
|
293 |
+
def _concept_model_path(base_dir: str, namespace: str, concept_name: str,
|
294 |
+
embedding_name: str) -> str:
|
295 |
+
|
296 |
+
return os.path.join(
|
297 |
+
get_concept_output_dir(base_dir, namespace, concept_name), f'{embedding_name}.pkl')
|
298 |
+
|
299 |
+
|
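The path helpers above and below determine the on-disk layout for concepts and their models. A minimal sketch of the resulting layout, assuming CONCEPTS_DIR resolves to 'concepts' and CONCEPT_JSON_FILENAME to 'concept.json' (both constants are defined earlier in this file; their values are assumptions here), with 'my-embedding' as a placeholder embedding name:

# Layout under the stated assumptions:
#   <data_path()>/concepts/<namespace>/<name>/concept.json      -- the serialized Concept
#   <data_path()>/concepts/<namespace>/<name>/<embedding>.pkl   -- one pickled ConceptModel per embedding
path = _concept_model_path('/data', 'local', 'toxicity', 'my-embedding')
# -> '/data/concepts/local/toxicity/my-embedding.pkl'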
300 |
+
class DiskConceptDB(ConceptDB):
|
301 |
+
"""A concept database."""
|
302 |
+
|
303 |
+
def __init__(self, base_dir: Optional[Union[str, pathlib.Path]] = None) -> None:
|
304 |
+
self._base_dir = base_dir
|
305 |
+
|
306 |
+
def _get_base_dir(self) -> str:
|
307 |
+
return str(self._base_dir) if self._base_dir else data_path()
|
308 |
+
|
309 |
+
@override
|
310 |
+
def namespace_acls(self, namespace: str, user: Optional[UserInfo] = None) -> ConceptNamespaceACL:
|
311 |
+
if not env('LILAC_AUTH_ENABLED'):
|
312 |
+
return ConceptNamespaceACL(read=True, write=True)
|
313 |
+
|
314 |
+
if namespace == 'lilac':
|
315 |
+
return ConceptNamespaceACL(read=True, write=False)
|
316 |
+
if user and user.id == namespace:
|
317 |
+
return ConceptNamespaceACL(read=True, write=True)
|
318 |
+
|
319 |
+
return ConceptNamespaceACL(read=False, write=False)
|
320 |
+
|
321 |
+
@override
|
322 |
+
def concept_acls(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> ConceptACL:
|
323 |
+
namespace_acls = self.namespace_acls(namespace, user=user)
|
324 |
+
# Concept ACLs inherit from the namespace ACL. We currently don't have concept-specific
|
325 |
+
# ACL.
|
326 |
+
return ConceptACL(read=namespace_acls.read, write=namespace_acls.write)
|
327 |
+
|
328 |
+
@override
|
329 |
+
def list(self, user: Optional[UserInfo] = None) -> list[ConceptInfo]:
|
330 |
+
namespaces: Optional[list[str]] = None
|
331 |
+
if env('LILAC_AUTH_ENABLED'):
|
332 |
+
namespaces = ['lilac']
|
333 |
+
if user:
|
334 |
+
namespaces += [user.id]
|
335 |
+
|
336 |
+
# Read the concepts and return a ConceptInfo containing the namespace and name.
|
337 |
+
concept_infos = []
|
338 |
+
for root, _, files in os.walk(self._get_base_dir()):
|
339 |
+
for file in files:
|
340 |
+
if file == CONCEPT_JSON_FILENAME:
|
341 |
+
namespace, name = root.split('/')[-2:]
|
342 |
+
if namespaces and namespace not in namespaces:
|
343 |
+
# Ignore concepts that are not in the namespace, if provided.
|
344 |
+
continue
|
345 |
+
|
346 |
+
concept = cast(Concept, self.get(namespace, name, user=user))
|
347 |
+
concept_infos.append(
|
348 |
+
ConceptInfo(
|
349 |
+
namespace=namespace,
|
350 |
+
name=name,
|
351 |
+
description=concept.description,
|
352 |
+
type=SignalInputType.TEXT,
|
353 |
+
drafts=concept.drafts(),
|
354 |
+
acls=self.concept_acls(namespace, name, user=user)))
|
355 |
+
|
356 |
+
return concept_infos
|
357 |
+
|
358 |
+
@override
|
359 |
+
def get(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> Optional[Concept]:
|
360 |
+
# If the user does not have read access to the concept, raise an error.
|
361 |
+
acls = self.concept_acls(namespace, name, user=user)
|
362 |
+
if not acls.read:
|
363 |
+
raise ConceptAuthorizationException(
|
364 |
+
f'Concept "{namespace}/{name}" does not exist or user does not have access.')
|
365 |
+
|
366 |
+
concept_json_path = _concept_json_path(self._get_base_dir(), namespace, name)
|
367 |
+
if not file_exists(concept_json_path):
|
368 |
+
return None
|
369 |
+
|
370 |
+
with open_file(concept_json_path) as f:
|
371 |
+
obj: dict[str, Any] = json.load(f)
|
372 |
+
if 'namespace' not in obj:
|
373 |
+
obj['namespace'] = namespace
|
374 |
+
return Concept.parse_obj(obj)
|
375 |
+
|
376 |
+
@override
|
377 |
+
def create(self,
|
378 |
+
namespace: str,
|
379 |
+
name: str,
|
380 |
+
type: SignalInputType,
|
381 |
+
description: Optional[str] = None,
|
382 |
+
user: Optional[UserInfo] = None) -> Concept:
|
383 |
+
"""Create a concept."""
|
384 |
+
# If the user does not have write access to the concept namespace, throw.
|
385 |
+
acls = self.namespace_acls(namespace, user=user)
|
386 |
+
if not acls.write:
|
387 |
+
raise ConceptAuthorizationException(
|
388 |
+
f'Concept namespace "{namespace}" does not exist or user does not have access.')
|
389 |
+
|
390 |
+
concept_json_path = _concept_json_path(self._get_base_dir(), namespace, name)
|
391 |
+
if file_exists(concept_json_path):
|
392 |
+
raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" already exists.')
|
393 |
+
|
394 |
+
concept = Concept(
|
395 |
+
namespace=namespace, concept_name=name, type=type, data={}, description=description)
|
396 |
+
self._save(concept)
|
397 |
+
return concept
|
398 |
+
|
399 |
+
def _validate_examples(self, examples: List[Union[ExampleIn, Example]],
|
400 |
+
type: SignalInputType) -> None:
|
401 |
+
for example in examples:
|
402 |
+
inferred_type = 'text' if example.text else 'img'
|
403 |
+
if inferred_type != type:
|
404 |
+
raise ValueError(f'Example type "{inferred_type}" does not match concept type "{type}".')
|
405 |
+
|
406 |
+
@override
|
407 |
+
def edit(self,
|
408 |
+
namespace: str,
|
409 |
+
name: str,
|
410 |
+
change: ConceptUpdate,
|
411 |
+
user: Optional[UserInfo] = None) -> Concept:
|
412 |
+
# If the user does not have write access to the concept, raise an error.
|
413 |
+
acls = self.concept_acls(namespace, name, user=user)
|
414 |
+
if not acls.write:
|
415 |
+
raise ConceptAuthorizationException(
|
416 |
+
f'Concept "{namespace}/{name}" does not exist or user does not have access.')
|
417 |
+
|
418 |
+
concept_json_path = _concept_json_path(self._get_base_dir(), namespace, name)
|
419 |
+
|
420 |
+
if not file_exists(concept_json_path):
|
421 |
+
raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" does not exist. '
|
422 |
+
'Please call create() first.')
|
423 |
+
|
424 |
+
inserted_points = change.insert or []
|
425 |
+
updated_points = change.update or []
|
426 |
+
removed_points = change.remove or []
|
427 |
+
|
428 |
+
concept = cast(Concept, self.get(namespace, name, user=user))
|
429 |
+
|
430 |
+
self._validate_examples([*inserted_points, *updated_points], concept.type)
|
431 |
+
|
432 |
+
for remove_example in removed_points:
|
433 |
+
if remove_example not in concept.data:
|
434 |
+
raise ValueError(f'Example with id "{remove_example}" does not exist.')
|
435 |
+
concept.data.pop(remove_example)
|
436 |
+
|
437 |
+
for example in inserted_points:
|
438 |
+
id = uuid.uuid4().hex
|
439 |
+
concept.data[id] = Example(id=id, **example.dict())
|
440 |
+
|
441 |
+
for example in updated_points:
|
442 |
+
if example.id not in concept.data:
|
443 |
+
raise ValueError(f'Example with id "{example.id}" does not exist.')
|
444 |
+
|
445 |
+
# Replace the old example with the updated copy, keeping the same id.
|
446 |
+
concept.data.pop(example.id)
|
447 |
+
concept.data[example.id] = example.copy()
|
448 |
+
|
449 |
+
concept.version += 1
|
450 |
+
|
451 |
+
self._save(concept)
|
452 |
+
|
453 |
+
return concept
|
454 |
+
|
455 |
+
def _save(self, concept: Concept) -> None:
|
456 |
+
concept_json_path = _concept_json_path(self._get_base_dir(), concept.namespace,
|
457 |
+
concept.concept_name)
|
458 |
+
with open_file(concept_json_path, 'w') as f:
|
459 |
+
f.write(concept.json(exclude_none=True, indent=2, exclude_defaults=True))
|
460 |
+
|
461 |
+
@override
|
462 |
+
def remove(self, namespace: str, name: str, user: Optional[UserInfo] = None) -> None:
|
463 |
+
# If the user does not have write access to the concept, raise an error.
|
464 |
+
acls = self.concept_acls(namespace, name, user=user)
|
465 |
+
if not acls.write:
|
466 |
+
raise ConceptAuthorizationException(
|
467 |
+
f'Concept "{namespace}/{name}" does not exist or user does not have access.')
|
468 |
+
|
469 |
+
concept_dir = get_concept_output_dir(self._get_base_dir(), namespace, name)
|
470 |
+
|
471 |
+
if not file_exists(concept_dir):
|
472 |
+
raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" does not exist.')
|
473 |
+
|
474 |
+
shutil.rmtree(concept_dir, ignore_errors=True)
|
475 |
+
|
476 |
+
@override
|
477 |
+
def merge_draft(self,
|
478 |
+
namespace: str,
|
479 |
+
name: str,
|
480 |
+
draft: DraftId,
|
481 |
+
user: Optional[UserInfo] = None) -> Concept:
|
482 |
+
"""Merge a draft concept."""
|
483 |
+
# If the user does not have write access to the concept, raise an error.
|
484 |
+
acls = self.concept_acls(namespace, name, user=user)
|
485 |
+
if not acls.write:
|
486 |
+
raise ConceptAuthorizationException(
|
487 |
+
f'Concept "{namespace}/{name}" does not exist or user does not have access.')
|
488 |
+
|
489 |
+
concept = self.get(namespace, name, user=user)
|
490 |
+
if not concept:
|
491 |
+
raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" does not exist.')
|
492 |
+
|
493 |
+
if draft == DRAFT_MAIN:
|
494 |
+
return concept
|
495 |
+
|
496 |
+
# Map the text of examples in main so we can remove them if they are duplicates.
|
497 |
+
main_text_ids: dict[Optional[str], str] = {
|
498 |
+
example.text: id for id, example in concept.data.items() if example.draft == DRAFT_MAIN
|
499 |
+
}
|
500 |
+
|
501 |
+
draft_examples: dict[str, Example] = {
|
502 |
+
id: example for id, example in concept.data.items() if example.draft == draft
|
503 |
+
}
|
504 |
+
for example in draft_examples.values():
|
505 |
+
example.draft = DRAFT_MAIN
|
506 |
+
# Remove duplicates in main.
|
507 |
+
main_text_id = main_text_ids.get(example.text)
|
508 |
+
if main_text_id:
|
509 |
+
del concept.data[main_text_id]
|
510 |
+
|
511 |
+
concept.version += 1
|
512 |
+
|
513 |
+
self._save(concept)
|
514 |
+
|
515 |
+
return concept
|
516 |
+
|
517 |
+
|
518 |
+
# A singleton concept database.
|
519 |
+
DISK_CONCEPT_DB = DiskConceptDB()
|
520 |
+
DISK_CONCEPT_MODEL_DB = DiskConceptModelDB(DISK_CONCEPT_DB)
|
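Taken together, the two singletons above are the entry point for concept storage. A minimal usage sketch, assuming ConceptUpdate and ExampleIn are the request models imported at the top of this file and that ExampleIn carries text and label fields; the namespace, concept name, and example text are placeholders:

db = DISK_CONCEPT_DB
db.create(namespace='local', name='toxicity', type=SignalInputType.TEXT)
db.edit('local', 'toxicity',
        ConceptUpdate(insert=[ExampleIn(label=True, text='you are awful')]))
models = DISK_CONCEPT_MODEL_DB.get_models('local', 'toxicity')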
lilac/config.py
ADDED
@@ -0,0 +1,80 @@
1 |
+
"""Configurations for a dataset run."""
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
from pydantic import BaseModel, validator
|
5 |
+
|
6 |
+
from .data.dataset import DatasetSettings
|
7 |
+
from .schema import Path, PathTuple, normalize_path
|
8 |
+
from .signals.signal import Signal, TextEmbeddingSignal, get_signal_by_type, resolve_signal
|
9 |
+
from .sources.source import Source
|
10 |
+
from .sources.source_registry import resolve_source
|
11 |
+
|
12 |
+
|
13 |
+
class SignalConfig(BaseModel):
|
14 |
+
"""Configures a signal on a source path."""
|
15 |
+
path: PathTuple
|
16 |
+
signal: Signal
|
17 |
+
|
18 |
+
@validator('path', pre=True)
|
19 |
+
def parse_path(cls, path: Path) -> PathTuple:
|
20 |
+
"""Parse a path."""
|
21 |
+
return normalize_path(path)
|
22 |
+
|
23 |
+
@validator('signal', pre=True)
|
24 |
+
def parse_signal(cls, signal: dict) -> Signal:
|
25 |
+
"""Parse a signal to its specific subclass instance."""
|
26 |
+
return resolve_signal(signal)
|
27 |
+
|
28 |
+
|
29 |
+
class EmbeddingConfig(BaseModel):
|
30 |
+
"""Configures an embedding on a source path."""
|
31 |
+
path: PathTuple
|
32 |
+
embedding: str
|
33 |
+
|
34 |
+
@validator('path', pre=True)
|
35 |
+
def parse_path(cls, path: Path) -> PathTuple:
|
36 |
+
"""Parse a path."""
|
37 |
+
return normalize_path(path)
|
38 |
+
|
39 |
+
@validator('embedding', pre=True)
|
40 |
+
def validate_embedding(cls, embedding: str) -> str:
|
41 |
+
"""Validate the embedding is registered."""
|
42 |
+
get_signal_by_type(embedding, TextEmbeddingSignal)
|
43 |
+
return embedding
|
44 |
+
|
45 |
+
|
46 |
+
class DatasetConfig(BaseModel):
|
47 |
+
"""Configures a dataset with a source and transformations."""
|
48 |
+
# The namespace and name of the dataset.
|
49 |
+
namespace: str
|
50 |
+
name: str
|
51 |
+
|
52 |
+
# The source configuration.
|
53 |
+
source: Source
|
54 |
+
|
55 |
+
# Model configuration: embeddings and signals on paths.
|
56 |
+
embeddings: Optional[list[EmbeddingConfig]]
|
57 |
+
# When defined, uses this list of signals instead of running all signals.
|
58 |
+
signals: Optional[list[SignalConfig]]
|
59 |
+
|
60 |
+
# Dataset settings, default embeddings and UI settings like media paths.
|
61 |
+
settings: Optional[DatasetSettings]
|
62 |
+
|
63 |
+
@validator('source', pre=True)
|
64 |
+
def parse_source(cls, source: dict) -> Source:
|
65 |
+
"""Parse a source to its specific subclass instance."""
|
66 |
+
return resolve_source(source)
|
67 |
+
|
68 |
+
|
69 |
+
class Config(BaseModel):
|
70 |
+
"""Configures a set of datasets for a lilac instance."""
|
71 |
+
datasets: list[DatasetConfig]
|
72 |
+
|
73 |
+
# When defined, uses this list of signals to run over every dataset, over all media paths, unless
|
74 |
+
# signals is overridden by a specific dataset.
|
75 |
+
signals: list[Signal] = []
|
76 |
+
|
77 |
+
@validator('signals', pre=True)
|
78 |
+
def parse_signal(cls, signals: list[dict]) -> list[Signal]:
|
79 |
+
"""Parse alist of signals to their specific subclass instances."""
|
80 |
+
return [resolve_signal(signal) for signal in signals]
|
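A minimal sketch of a Config parsed from plain dicts, relying on the parse_source/parse_signal validators above. The discriminator keys ('source_name', 'signal_name'), the registry names ('json', 'pii', 'gte-small'), and the dataset fields are assumptions used only for illustration:

config = Config.parse_obj({
  'datasets': [{
    'namespace': 'local',
    'name': 'movies',
    'source': {'source_name': 'json', 'filepaths': ['./movies.jsonl']},
    'embeddings': [{'path': 'overview', 'embedding': 'gte-small'}],
    'signals': [{'path': 'overview', 'signal': {'signal_name': 'pii'}}],
  }],
})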
lilac/conftest.py
ADDED
@@ -0,0 +1,28 @@
1 |
+
"""Fixtures for dataset tests."""
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
from typing import Generator, Optional, Type
|
5 |
+
|
6 |
+
import pytest
|
7 |
+
from pytest_mock import MockerFixture
|
8 |
+
|
9 |
+
from .data.dataset import Dataset
|
10 |
+
from .data.dataset_duckdb import DatasetDuckDB
|
11 |
+
from .data.dataset_test_utils import make_dataset
|
12 |
+
from .db_manager import set_default_dataset_cls
|
13 |
+
from .schema import Item, Schema
|
14 |
+
|
15 |
+
|
16 |
+
@pytest.fixture(scope='function', params=[DatasetDuckDB])
|
17 |
+
def make_test_data(tmp_path: pathlib.Path, mocker: MockerFixture,
|
18 |
+
request: pytest.FixtureRequest) -> Generator:
|
19 |
+
"""A pytest fixture for creating temporary test datasets."""
|
20 |
+
mocker.patch.dict(os.environ, {'LILAC_DATA_PATH': str(tmp_path)})
|
21 |
+
dataset_cls: Type[Dataset] = request.param
|
22 |
+
set_default_dataset_cls(dataset_cls)
|
23 |
+
|
24 |
+
def _make_test_data(items: list[Item], schema: Optional[Schema] = None) -> Dataset:
|
25 |
+
return make_dataset(dataset_cls, tmp_path, items, schema)
|
26 |
+
|
27 |
+
# Return the factory for datasets that test methods can use.
|
28 |
+
yield _make_test_data
|
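A minimal sketch of a test that consumes the make_test_data factory above; the field name 'text' and the assertion are illustrative, and schema inference is assumed to handle plain string items:

def test_two_rows(make_test_data) -> None:
  dataset = make_test_data([{'text': 'hello'}, {'text': 'world'}])
  rows = list(dataset.select_rows())
  assert len(rows) == 2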
lilac/data/__init__.py
ADDED
@@ -0,0 +1,9 @@
1 |
+
from .dataset import Column, ConceptQuery, KeywordQuery, Search, SemanticQuery
|
2 |
+
|
3 |
+
__all__ = [
|
4 |
+
'Column',
|
5 |
+
'Search',
|
6 |
+
'KeywordQuery',
|
7 |
+
'ConceptQuery',
|
8 |
+
'SemanticQuery',
|
9 |
+
]
|
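A minimal sketch of using these re-exports; the column name 'text' and the query string are placeholders, and the path is passed as a plain string, which the Path type admits:

from lilac.data import KeywordQuery, Search

search = Search(path='text', query=KeywordQuery(search='error'))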
lilac/data/dataset.py
ADDED
@@ -0,0 +1,485 @@
1 |
+
"""The interface for the database."""
|
2 |
+
import abc
|
3 |
+
import enum
|
4 |
+
import pathlib
|
5 |
+
from concurrent.futures import ThreadPoolExecutor
|
6 |
+
from datetime import datetime
|
7 |
+
from typing import Any, Iterator, Literal, Optional, Sequence, Union
|
8 |
+
|
9 |
+
import pandas as pd
|
10 |
+
from pydantic import BaseModel
|
11 |
+
from pydantic import Field as PydanticField
|
12 |
+
from pydantic import StrictBool, StrictBytes, StrictFloat, StrictInt, StrictStr, validator
|
13 |
+
|
14 |
+
from ..auth import UserInfo
|
15 |
+
from ..schema import VALUE_KEY, Bin, DataType, Path, PathTuple, Schema, normalize_path
|
16 |
+
from ..signals.signal import Signal, TextEmbeddingSignal, get_signal_by_type, resolve_signal
|
17 |
+
from ..tasks import TaskStepId
|
18 |
+
|
19 |
+
# Threshold for rejecting certain queries (e.g. group by) for columns with large cardinality.
|
20 |
+
TOO_MANY_DISTINCT = 1_000_000
|
21 |
+
|
22 |
+
|
23 |
+
class SelectRowsResult:
|
24 |
+
"""The result of a select rows query."""
|
25 |
+
|
26 |
+
def __init__(self, df: pd.DataFrame, total_num_rows: int) -> None:
|
27 |
+
"""Initialize the result."""
|
28 |
+
self._df = df
|
29 |
+
self.total_num_rows = total_num_rows
|
30 |
+
|
31 |
+
def __iter__(self) -> Iterator:
|
32 |
+
return (row.to_dict() for _, row in self._df.iterrows())
|
33 |
+
|
34 |
+
def df(self) -> pd.DataFrame:
|
35 |
+
"""Convert the result to a pandas DataFrame."""
|
36 |
+
return self._df
|
37 |
+
|
38 |
+
|
39 |
+
class StatsResult(BaseModel):
|
40 |
+
"""The result of a stats() query."""
|
41 |
+
path: PathTuple
|
42 |
+
# The number of leaf values.
|
43 |
+
total_count: int
|
44 |
+
# The approximate number of distinct leaf values.
|
45 |
+
approx_count_distinct: int
|
46 |
+
|
47 |
+
# Defined for ordinal features.
|
48 |
+
min_val: Optional[Union[float, datetime]] = None
|
49 |
+
max_val: Optional[Union[float, datetime]] = None
|
50 |
+
|
51 |
+
# Defined for text features.
|
52 |
+
avg_text_length: Optional[float] = None
|
53 |
+
|
54 |
+
|
55 |
+
class MediaResult(BaseModel):
|
56 |
+
"""The result of a media() query."""
|
57 |
+
data: bytes
|
58 |
+
|
59 |
+
|
60 |
+
class BinaryOp(str, enum.Enum):
|
61 |
+
"""The comparison operator between a column and a feature value."""
|
62 |
+
EQUALS = 'equals'
|
63 |
+
NOT_EQUAL = 'not_equal'
|
64 |
+
GREATER = 'greater'
|
65 |
+
GREATER_EQUAL = 'greater_equal'
|
66 |
+
LESS = 'less'
|
67 |
+
LESS_EQUAL = 'less_equal'
|
68 |
+
|
69 |
+
|
70 |
+
SearchType = Union[Literal['keyword'], Literal['semantic'], Literal['concept']]
|
71 |
+
|
72 |
+
|
73 |
+
class UnaryOp(str, enum.Enum):
|
74 |
+
"""A unary operator on a feature."""
|
75 |
+
EXISTS = 'exists'
|
76 |
+
|
77 |
+
|
78 |
+
class ListOp(str, enum.Enum):
|
79 |
+
"""A list operator on a feature."""
|
80 |
+
IN = 'in'
|
81 |
+
|
82 |
+
|
83 |
+
class SortOrder(str, enum.Enum):
|
84 |
+
"""The sort order for a database query."""
|
85 |
+
DESC = 'DESC'
|
86 |
+
ASC = 'ASC'
|
87 |
+
|
88 |
+
|
89 |
+
class GroupsSortBy(str, enum.Enum):
|
90 |
+
"""The sort for groups queries.
|
91 |
+
|
92 |
+
Either "count" which sorts by the count of feature value, or "value" which sorts by the
|
93 |
+
feature value itself.
|
94 |
+
"""
|
95 |
+
COUNT = 'count'
|
96 |
+
VALUE = 'value'
|
97 |
+
|
98 |
+
|
99 |
+
class SortResult(BaseModel):
|
100 |
+
"""The information about what is sorted after combining searches and explicit sorts."""
|
101 |
+
# The column that was sorted.
|
102 |
+
path: PathTuple
|
103 |
+
# The sort order.
|
104 |
+
order: SortOrder
|
105 |
+
# The alias of the column if it was aliased.
|
106 |
+
alias: Optional[str] = None
|
107 |
+
# The search index if the sort is by a search.
|
108 |
+
search_index: Optional[int] = None
|
109 |
+
|
110 |
+
|
111 |
+
class SearchResultInfo(BaseModel):
|
112 |
+
"""The resulting sort order returned by the select rows schema."""
|
113 |
+
# The input path to the search.
|
114 |
+
search_path: PathTuple
|
115 |
+
# The resulting column that was searched.
|
116 |
+
result_path: PathTuple
|
117 |
+
# The alias of the UDF.
|
118 |
+
alias: Optional[str] = None
|
119 |
+
|
120 |
+
|
121 |
+
class SelectRowsSchemaUDF(BaseModel):
|
122 |
+
"""The UDF for a select rows schema query."""
|
123 |
+
path: PathTuple
|
124 |
+
alias: Optional[str] = None
|
125 |
+
|
126 |
+
|
127 |
+
class SelectRowsSchemaResult(BaseModel):
|
128 |
+
"""The result of a select rows schema query."""
|
129 |
+
data_schema: Schema
|
130 |
+
udfs: list[SelectRowsSchemaUDF] = []
|
131 |
+
search_results: list[SearchResultInfo] = []
|
132 |
+
sorts: Optional[list[SortResult]] = None
|
133 |
+
|
134 |
+
|
135 |
+
class Column(BaseModel):
|
136 |
+
"""A column in the dataset."""
|
137 |
+
path: PathTuple
|
138 |
+
alias: Optional[str] = None # This is the renamed column during querying and response.
|
139 |
+
|
140 |
+
# Defined when the feature is another column.
|
141 |
+
signal_udf: Optional[Signal] = None
|
142 |
+
|
143 |
+
class Config:
|
144 |
+
smart_union = True
|
145 |
+
|
146 |
+
def __init__(self,
|
147 |
+
path: Path,
|
148 |
+
alias: Optional[str] = None,
|
149 |
+
signal_udf: Optional[Signal] = None,
|
150 |
+
**kwargs: Any):
|
151 |
+
"""Initialize a column. We override __init__ to allow positional arguments for brevity."""
|
152 |
+
super().__init__(path=normalize_path(path), alias=alias, signal_udf=signal_udf, **kwargs)
|
153 |
+
|
154 |
+
@validator('signal_udf', pre=True)
|
155 |
+
def parse_signal_udf(cls, signal_udf: Optional[dict]) -> Optional[Signal]:
|
156 |
+
"""Parse a signal to its specific subclass instance."""
|
157 |
+
if not signal_udf:
|
158 |
+
return None
|
159 |
+
return resolve_signal(signal_udf)
|
160 |
+
|
161 |
+
|
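A couple of illustrative Column constructions using the positional path argument allowed by __init__ above; the field names are placeholders:

title_col = Column('title')  # a string path is normalized to the tuple ('title',)
name_col = Column(('people', '*', 'name'), alias='person_name')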
162 |
+
ColumnId = Union[Path, Column]
|
163 |
+
|
164 |
+
|
165 |
+
class DatasetUISettings(BaseModel):
|
166 |
+
"""The UI persistent settings for a dataset."""
|
167 |
+
media_paths: list[PathTuple] = []
|
168 |
+
markdown_paths: list[PathTuple] = []
|
169 |
+
|
170 |
+
@validator('media_paths', pre=True)
|
171 |
+
def parse_media_paths(cls, media_paths: list) -> list:
|
172 |
+
"""Parse a path, ensuring it is a tuple."""
|
173 |
+
return [normalize_path(path) for path in media_paths]
|
174 |
+
|
175 |
+
|
176 |
+
class DatasetSettings(BaseModel):
|
177 |
+
"""The persistent settings for a dataset."""
|
178 |
+
ui: Optional[DatasetUISettings] = None
|
179 |
+
preferred_embedding: Optional[str] = None
|
180 |
+
|
181 |
+
|
182 |
+
class DatasetManifest(BaseModel):
|
183 |
+
"""The manifest for a dataset."""
|
184 |
+
namespace: str
|
185 |
+
dataset_name: str
|
186 |
+
data_schema: Schema
|
187 |
+
# Number of items in the dataset.
|
188 |
+
num_items: int
|
189 |
+
|
190 |
+
|
191 |
+
def column_from_identifier(column: ColumnId) -> Column:
|
192 |
+
"""Create a column from a column identifier."""
|
193 |
+
if isinstance(column, Column):
|
194 |
+
return column.copy()
|
195 |
+
return Column(path=column)
|
196 |
+
|
197 |
+
|
198 |
+
FeatureValue = Union[StrictInt, StrictFloat, StrictBool, StrictStr, StrictBytes, datetime]
|
199 |
+
FeatureListValue = list[StrictStr]
|
200 |
+
BinaryFilterTuple = tuple[Path, BinaryOp, FeatureValue]
|
201 |
+
ListFilterTuple = tuple[Path, ListOp, FeatureListValue]
|
202 |
+
UnaryFilterTuple = tuple[Path, UnaryOp]
|
203 |
+
|
204 |
+
FilterOp = Union[BinaryOp, UnaryOp, ListOp]
|
205 |
+
|
206 |
+
|
207 |
+
class SelectGroupsResult(BaseModel):
|
208 |
+
"""The result of a select groups query."""
|
209 |
+
too_many_distinct: bool
|
210 |
+
counts: list[tuple[Optional[FeatureValue], int]]
|
211 |
+
bins: Optional[list[Bin]] = None
|
212 |
+
|
213 |
+
|
214 |
+
class Filter(BaseModel):
|
215 |
+
"""A filter on a column."""
|
216 |
+
path: PathTuple
|
217 |
+
op: FilterOp
|
218 |
+
value: Optional[Union[FeatureValue, FeatureListValue]] = None
|
219 |
+
|
220 |
+
|
221 |
+
FilterLike = Union[Filter, BinaryFilterTuple, UnaryFilterTuple, ListFilterTuple]
|
222 |
+
|
223 |
+
SearchValue = StrictStr
|
224 |
+
|
225 |
+
|
226 |
+
class KeywordQuery(BaseModel):
|
227 |
+
"""A keyword search query on a column."""
|
228 |
+
type: Literal['keyword'] = 'keyword'
|
229 |
+
search: SearchValue
|
230 |
+
|
231 |
+
|
232 |
+
class SemanticQuery(BaseModel):
|
233 |
+
"""A semantic search on a column."""
|
234 |
+
type: Literal['semantic'] = 'semantic'
|
235 |
+
search: SearchValue
|
236 |
+
embedding: str
|
237 |
+
|
238 |
+
|
239 |
+
class ConceptQuery(BaseModel):
|
240 |
+
"""A concept search query on a column."""
|
241 |
+
type: Literal['concept'] = 'concept'
|
242 |
+
concept_namespace: str
|
243 |
+
concept_name: str
|
244 |
+
embedding: str
|
245 |
+
|
246 |
+
|
247 |
+
class Search(BaseModel):
|
248 |
+
"""A search on a column."""
|
249 |
+
path: Path
|
250 |
+
query: Union[KeywordQuery, SemanticQuery, ConceptQuery] = PydanticField(discriminator='type')
|
251 |
+
|
252 |
+
|
253 |
+
class Dataset(abc.ABC):
|
254 |
+
"""The database implementation to query a dataset."""
|
255 |
+
|
256 |
+
namespace: str
|
257 |
+
dataset_name: str
|
258 |
+
|
259 |
+
def __init__(self, namespace: str, dataset_name: str):
|
260 |
+
"""Initialize a dataset.
|
261 |
+
|
262 |
+
Args:
|
263 |
+
namespace: The dataset namespace.
|
264 |
+
dataset_name: The dataset name.
|
265 |
+
"""
|
266 |
+
self.namespace = namespace
|
267 |
+
self.dataset_name = dataset_name
|
268 |
+
|
269 |
+
@abc.abstractmethod
|
270 |
+
def delete(self) -> None:
|
271 |
+
"""Deletes the dataset."""
|
272 |
+
pass
|
273 |
+
|
274 |
+
@abc.abstractmethod
|
275 |
+
def manifest(self) -> DatasetManifest:
|
276 |
+
"""Return the manifest for the dataset."""
|
277 |
+
pass
|
278 |
+
|
279 |
+
@abc.abstractmethod
|
280 |
+
def settings(self) -> DatasetSettings:
|
281 |
+
"""Return the persistent settings for the dataset."""
|
282 |
+
pass
|
283 |
+
|
284 |
+
@abc.abstractmethod
|
285 |
+
def update_settings(self, settings: DatasetSettings) -> None:
|
286 |
+
"""Update the settings for the dataset."""
|
287 |
+
pass
|
288 |
+
|
289 |
+
@abc.abstractmethod
|
290 |
+
def compute_signal(self,
|
291 |
+
signal: Signal,
|
292 |
+
leaf_path: Path,
|
293 |
+
task_step_id: Optional[TaskStepId] = None) -> None:
|
294 |
+
"""Compute a signal for a column.
|
295 |
+
|
296 |
+
Args:
|
297 |
+
signal: The signal to compute over the given columns.
|
298 |
+
leaf_path: The leaf path to compute the signal on.
|
299 |
+
task_step_id: The TaskManager `task_step_id` for this process run. This is used to update the
|
300 |
+
progress of the task.
|
301 |
+
"""
|
302 |
+
pass
|
303 |
+
|
304 |
+
def compute_embedding(self,
|
305 |
+
embedding: str,
|
306 |
+
path: Path,
|
307 |
+
task_step_id: Optional[TaskStepId] = None) -> None:
|
308 |
+
"""Compute an embedding for a given field path."""
|
309 |
+
signal = get_signal_by_type(embedding, TextEmbeddingSignal)()
|
310 |
+
self.compute_signal(signal, path)
|
311 |
+
|
312 |
+
@abc.abstractmethod
|
313 |
+
def delete_signal(self, signal_path: Path) -> None:
|
314 |
+
"""Delete a computed signal from the dataset.
|
315 |
+
|
316 |
+
Args:
|
317 |
+
signal_path: The path holding the computed data of the signal.
|
318 |
+
"""
|
319 |
+
pass
|
320 |
+
|
321 |
+
@abc.abstractmethod
|
322 |
+
def select_groups(
|
323 |
+
self,
|
324 |
+
leaf_path: Path,
|
325 |
+
filters: Optional[Sequence[FilterLike]] = None,
|
326 |
+
sort_by: Optional[GroupsSortBy] = None,
|
327 |
+
sort_order: Optional[SortOrder] = SortOrder.DESC,
|
328 |
+
limit: Optional[int] = None,
|
329 |
+
bins: Optional[Union[Sequence[Bin], Sequence[float]]] = None) -> SelectGroupsResult:
|
330 |
+
"""Select grouped columns to power a histogram.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
leaf_path: The leaf path to group by. The path can be a dot-separated string path, or a tuple
|
334 |
+
of fields.
|
335 |
+
filters: The filters to apply to the query.
|
336 |
+
sort_by: What to sort by, either "count" or "value".
|
337 |
+
sort_order: The sort order.
|
338 |
+
limit: The maximum number of rows to return.
|
339 |
+
bins: The bins to use when bucketizing a float column.
|
340 |
+
|
341 |
+
Returns
|
342 |
+
A `SelectGroupsResult` iterator where each row is a group.
|
343 |
+
"""
|
344 |
+
raise NotImplementedError
|
345 |
+
|
346 |
+
@abc.abstractmethod
|
347 |
+
def select_rows(self,
|
348 |
+
columns: Optional[Sequence[ColumnId]] = None,
|
349 |
+
searches: Optional[Sequence[Search]] = None,
|
350 |
+
filters: Optional[Sequence[FilterLike]] = None,
|
351 |
+
sort_by: Optional[Sequence[Path]] = None,
|
352 |
+
sort_order: Optional[SortOrder] = SortOrder.DESC,
|
353 |
+
limit: Optional[int] = 100,
|
354 |
+
offset: Optional[int] = 0,
|
355 |
+
task_step_id: Optional[TaskStepId] = None,
|
356 |
+
resolve_span: bool = False,
|
357 |
+
combine_columns: bool = False,
|
358 |
+
user: Optional[UserInfo] = None) -> SelectRowsResult:
|
359 |
+
"""Select grouped columns to power a histogram.
|
360 |
+
|
361 |
+
Args:
|
362 |
+
columns: The columns to select. A column is an instance of `Column` which can either
|
363 |
+
define a path to a feature, or a column with an applied Transform, e.g. a Concept. If none,
|
364 |
+
it selects all columns.
|
365 |
+
searches: The searches to apply to the query.
|
366 |
+
filters: The filters to apply to the query.
|
367 |
+
sort_by: An ordered list of what to sort by. When defined, this is a list of aliases of column
|
368 |
+
names defined by the "alias" field in Column. If no alias is provided for a column, an
|
369 |
+
automatic alias is generated by combining each path element with a "."
|
370 |
+
For example: e.g. ('person', 'name') => person.name. For columns that are transform columns,
|
371 |
+
an alias must be provided explicitly. When sorting by a (nested) list of values, the sort
|
372 |
+
takes the minimum value when `sort_order` is `ASC`, and the maximum value when `sort_order`
|
373 |
+
is `DESC`.
|
374 |
+
sort_order: The sort order.
|
375 |
+
limit: The maximum number of rows to return.
|
376 |
+
offset: The offset to start returning rows from.
|
377 |
+
task_step_id: The TaskManager `task_step_id` for this process run. This is used to update the
|
378 |
+
progress.
|
379 |
+
resolve_span: Whether to resolve the span of the row.
|
380 |
+
combine_columns: Whether to combine columns into a single object. The object will be pruned
|
381 |
+
to only include sub-fields that correspond to the requested columns.
|
382 |
+
user: The authenticated user, if auth is enabled and the user is logged in. This is used to
|
383 |
+
apply ACL to the query, especially for concepts.
|
384 |
+
|
385 |
+
Returns
|
386 |
+
A SelectRowsResult iterator with rows of `Item`s.
|
387 |
+
"""
|
388 |
+
pass
|
389 |
+
|
390 |
+
@abc.abstractmethod
|
391 |
+
def select_rows_schema(self,
|
392 |
+
columns: Optional[Sequence[ColumnId]] = None,
|
393 |
+
sort_by: Optional[Sequence[Path]] = None,
|
394 |
+
sort_order: Optional[SortOrder] = SortOrder.DESC,
|
395 |
+
searches: Optional[Sequence[Search]] = None,
|
396 |
+
combine_columns: bool = False) -> SelectRowsSchemaResult:
|
397 |
+
"""Returns the schema of the result of `select_rows` above with the same arguments."""
|
398 |
+
pass
|
399 |
+
|
400 |
+
@abc.abstractmethod
|
401 |
+
def stats(self, leaf_path: Path) -> StatsResult:
|
402 |
+
"""Compute stats for a leaf path.
|
403 |
+
|
404 |
+
Args:
|
405 |
+
leaf_path: The leaf path to compute stats for.
|
406 |
+
|
407 |
+
Returns
|
408 |
+
A StatsResult.
|
409 |
+
"""
|
410 |
+
pass
|
411 |
+
|
412 |
+
@abc.abstractmethod
|
413 |
+
def media(self, item_id: str, leaf_path: Path) -> MediaResult:
|
414 |
+
"""Return the media for a leaf path.
|
415 |
+
|
416 |
+
Args:
|
417 |
+
item_id: The item id to get media for.
|
418 |
+
leaf_path: The leaf path for the media.
|
419 |
+
|
420 |
+
Returns
|
421 |
+
A MediaResult.
|
422 |
+
"""
|
423 |
+
pass
|
424 |
+
|
425 |
+
@abc.abstractmethod
|
426 |
+
def to_json(self, filepath: Union[str, pathlib.Path], jsonl: bool = True) -> None:
|
427 |
+
"""Export the dataset to a JSON file.
|
428 |
+
|
429 |
+
Args:
|
430 |
+
filepath: The path to the file to export to.
|
431 |
+
jsonl: Whether to export to JSONL or JSON.
|
432 |
+
"""
|
433 |
+
pass
|
434 |
+
|
435 |
+
@abc.abstractmethod
|
436 |
+
def to_pandas(self) -> pd.DataFrame:
|
437 |
+
"""Export the dataset to a pandas DataFrame."""
|
438 |
+
pass
|
439 |
+
|
440 |
+
@abc.abstractmethod
|
441 |
+
def to_parquet(self, filepath: Union[str, pathlib.Path]) -> None:
|
442 |
+
"""Export the dataset to a parquet file.
|
443 |
+
|
444 |
+
Args:
|
445 |
+
filepath: The path to the file to export to.
|
446 |
+
"""
|
447 |
+
pass
|
448 |
+
|
449 |
+
@abc.abstractmethod
|
450 |
+
def to_csv(self, filepath: Union[str, pathlib.Path]) -> None:
|
451 |
+
"""Export the dataset to a csv file.
|
452 |
+
|
453 |
+
Args:
|
454 |
+
filepath: The path to the file to export to.
|
455 |
+
"""
|
456 |
+
pass
|
457 |
+
|
458 |
+
|
459 |
+
def default_settings(dataset: Dataset) -> DatasetSettings:
|
460 |
+
"""Gets the default settings for a dataset."""
|
461 |
+
schema = dataset.manifest().data_schema
|
462 |
+
leaf_paths = [path for path, field in schema.leafs.items() if field.dtype == DataType.STRING]
|
463 |
+
pool = ThreadPoolExecutor()
|
464 |
+
stats: list[StatsResult] = list(pool.map(lambda leaf: dataset.stats(leaf), leaf_paths))
|
465 |
+
sorted_stats = sorted([stat for stat in stats if stat.avg_text_length],
|
466 |
+
key=lambda stat: stat.avg_text_length or -1.0)
|
467 |
+
media_paths: set[PathTuple] = set()
|
468 |
+
if sorted_stats:
|
469 |
+
media_paths = set([sorted_stats[-1].path])
|
470 |
+
|
471 |
+
return DatasetSettings(ui=DatasetUISettings(media_paths=media_paths))
|
472 |
+
|
473 |
+
|
474 |
+
def make_parquet_id(signal: Signal,
|
475 |
+
source_path: PathTuple,
|
476 |
+
is_computed_signal: Optional[bool] = False) -> str:
|
477 |
+
"""Return a unique identifier for this parquet table."""
|
478 |
+
# Don't use the VALUE_KEY as part of the parquet id to reduce the size of paths.
|
479 |
+
path = source_path[:-1] if source_path[-1] == VALUE_KEY else source_path
|
480 |
+
column_alias = '.'.join(map(str, path))
|
481 |
+
if column_alias.endswith('.*'):
|
482 |
+
# Remove the trailing .* from the column name.
|
483 |
+
column_alias = column_alias[:-2]
|
484 |
+
|
485 |
+
return f'{signal.key(is_computed_signal=is_computed_signal)}({column_alias})'
|
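A minimal end-to-end sketch of querying a dataset through this interface, assuming a dataset has already been loaded under the namespace/name ('local', 'movies') and using the DuckDB implementation defined in the next file; the column names, filter value, and limit are placeholders:

dataset = DatasetDuckDB('local', 'movies')
result = dataset.select_rows(
  columns=['title', 'overview'],
  filters=[(('rating',), BinaryOp.GREATER_EQUAL, 7.5)],
  sort_by=[('rating',)],
  sort_order=SortOrder.DESC,
  limit=10)
for row in result:
  print(row)  # each row is a plain dict yielded by SelectRowsResult.__iter__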
lilac/data/dataset_duckdb.py
ADDED
@@ -0,0 +1,1717 @@
1 |
+
"""The DuckDB implementation of the dataset database."""
|
2 |
+
import functools
|
3 |
+
import gc
|
4 |
+
import glob
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import pathlib
|
8 |
+
import re
|
9 |
+
import shutil
|
10 |
+
import threading
|
11 |
+
from typing import Any, Iterable, Iterator, Optional, Sequence, Union, cast
|
12 |
+
|
13 |
+
import duckdb
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
from pandas.api.types import is_object_dtype
|
17 |
+
from pydantic import BaseModel, validator
|
18 |
+
from typing_extensions import override
|
19 |
+
|
20 |
+
from ..auth import UserInfo
|
21 |
+
from ..batch_utils import deep_flatten, deep_unflatten
|
22 |
+
from ..embeddings.vector_store import VectorDBIndex
|
23 |
+
from ..env import data_path, env
|
24 |
+
from ..schema import (
|
25 |
+
MANIFEST_FILENAME,
|
26 |
+
PATH_WILDCARD,
|
27 |
+
TEXT_SPAN_END_FEATURE,
|
28 |
+
TEXT_SPAN_START_FEATURE,
|
29 |
+
UUID_COLUMN,
|
30 |
+
VALUE_KEY,
|
31 |
+
Bin,
|
32 |
+
DataType,
|
33 |
+
Field,
|
34 |
+
Item,
|
35 |
+
Path,
|
36 |
+
PathKey,
|
37 |
+
PathTuple,
|
38 |
+
RichData,
|
39 |
+
Schema,
|
40 |
+
SourceManifest,
|
41 |
+
column_paths_match,
|
42 |
+
is_float,
|
43 |
+
is_integer,
|
44 |
+
is_ordinal,
|
45 |
+
is_temporal,
|
46 |
+
normalize_path,
|
47 |
+
signal_type_supports_dtype,
|
48 |
+
)
|
49 |
+
from ..signals.concept_labels import ConceptLabelsSignal
|
50 |
+
from ..signals.concept_scorer import ConceptScoreSignal
|
51 |
+
from ..signals.semantic_similarity import SemanticSimilaritySignal
|
52 |
+
from ..signals.signal import (
|
53 |
+
Signal,
|
54 |
+
TextEmbeddingSignal,
|
55 |
+
VectorSignal,
|
56 |
+
get_signal_by_type,
|
57 |
+
resolve_signal,
|
58 |
+
)
|
59 |
+
from ..signals.substring_search import SubstringSignal
|
60 |
+
from ..tasks import TaskStepId, progress
|
61 |
+
from ..utils import DebugTimer, get_dataset_output_dir, log, open_file
|
62 |
+
from . import dataset
|
63 |
+
from .dataset import (
|
64 |
+
BinaryOp,
|
65 |
+
Column,
|
66 |
+
ColumnId,
|
67 |
+
Dataset,
|
68 |
+
DatasetManifest,
|
69 |
+
DatasetSettings,
|
70 |
+
FeatureListValue,
|
71 |
+
FeatureValue,
|
72 |
+
Filter,
|
73 |
+
FilterLike,
|
74 |
+
GroupsSortBy,
|
75 |
+
ListOp,
|
76 |
+
MediaResult,
|
77 |
+
Search,
|
78 |
+
SearchResultInfo,
|
79 |
+
SelectGroupsResult,
|
80 |
+
SelectRowsResult,
|
81 |
+
SelectRowsSchemaResult,
|
82 |
+
SelectRowsSchemaUDF,
|
83 |
+
SortOrder,
|
84 |
+
SortResult,
|
85 |
+
StatsResult,
|
86 |
+
UnaryOp,
|
87 |
+
column_from_identifier,
|
88 |
+
default_settings,
|
89 |
+
make_parquet_id,
|
90 |
+
)
|
91 |
+
from .dataset_utils import (
|
92 |
+
count_primitives,
|
93 |
+
create_signal_schema,
|
94 |
+
flatten_keys,
|
95 |
+
merge_schemas,
|
96 |
+
schema_contains_path,
|
97 |
+
sparse_to_dense_compute,
|
98 |
+
wrap_in_dicts,
|
99 |
+
write_embeddings_to_disk,
|
100 |
+
write_items_to_parquet,
|
101 |
+
)
|
102 |
+
|
103 |
+
UUID_INDEX_FILENAME = 'uuids.npy'
|
104 |
+
|
105 |
+
SIGNAL_MANIFEST_FILENAME = 'signal_manifest.json'
|
106 |
+
DATASET_SETTINGS_FILENAME = 'settings.json'
|
107 |
+
SOURCE_VIEW_NAME = 'source'
|
108 |
+
|
109 |
+
# Sample size for approximating the distinct count of a column.
|
110 |
+
SAMPLE_SIZE_DISTINCT_COUNT = 100_000
|
111 |
+
NUM_AUTO_BINS = 15
|
112 |
+
|
113 |
+
BINARY_OP_TO_SQL: dict[BinaryOp, str] = {
|
114 |
+
BinaryOp.EQUALS: '=',
|
115 |
+
BinaryOp.NOT_EQUAL: '!=',
|
116 |
+
BinaryOp.GREATER: '>',
|
117 |
+
BinaryOp.GREATER_EQUAL: '>=',
|
118 |
+
BinaryOp.LESS: '<',
|
119 |
+
BinaryOp.LESS_EQUAL: '<='
|
120 |
+
}
|
121 |
+
|
122 |
+
|
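An illustrative rendering of a binary filter into a SQL predicate using the BINARY_OP_TO_SQL table above; the actual filter-compilation code appears later in this file, so this is only a sketch with a placeholder column name and value:

op = BinaryOp.GREATER_EQUAL
predicate = f'"rating" {BINARY_OP_TO_SQL[op]} 7.5'  # -> "rating" >= 7.5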
123 |
+
class DuckDBSearchUDF(BaseModel):
|
124 |
+
"""The transformation of searches to column UDFs."""
|
125 |
+
udf: Column
|
126 |
+
search_path: PathTuple
|
127 |
+
output_path: PathTuple
|
128 |
+
sort: Optional[tuple[PathTuple, SortOrder]] = None
|
129 |
+
|
130 |
+
|
131 |
+
class DuckDBSearchUDFs(BaseModel):
|
132 |
+
"""The transformation of searches to column UDFs with sorts."""
|
133 |
+
udfs: list[Column]
|
134 |
+
output_paths: list[PathTuple]
|
135 |
+
sorts: list[tuple[PathTuple, SortOrder]]
|
136 |
+
|
137 |
+
|
138 |
+
class DatasetDuckDB(Dataset):
|
139 |
+
"""The DuckDB implementation of the dataset database."""
|
140 |
+
|
141 |
+
def __init__(self, namespace: str, dataset_name: str, vector_store: str = 'hnsw'):
|
142 |
+
super().__init__(namespace, dataset_name)
|
143 |
+
|
144 |
+
self.dataset_path = get_dataset_output_dir(data_path(), namespace, dataset_name)
|
145 |
+
|
146 |
+
# TODO: Infer the manifest from the parquet files so this is lighter weight.
|
147 |
+
self._source_manifest = read_source_manifest(self.dataset_path)
|
148 |
+
self._signal_manifests: list[SignalManifest] = []
|
149 |
+
self.con = duckdb.connect(database=':memory:')
|
150 |
+
|
151 |
+
# Maps a path and embedding to the vector index. This is lazily generated as needed.
|
152 |
+
self._vector_indices: dict[tuple[PathKey, str], VectorDBIndex] = {}
|
153 |
+
self.vector_store = vector_store
|
154 |
+
self._manifest_lock = threading.Lock()
|
155 |
+
|
156 |
+
# Calling settings creates the default settings JSON file if it doesn't exist.
|
157 |
+
self.settings()
|
158 |
+
|
159 |
+
@override
|
160 |
+
def delete(self) -> None:
|
161 |
+
"""Deletes the dataset."""
|
162 |
+
self.con.close()
|
163 |
+
shutil.rmtree(self.dataset_path, ignore_errors=True)
|
164 |
+
|
165 |
+
def _create_view(self, view_name: str, files: list[str]) -> None:
|
166 |
+
self.con.execute(f"""
|
167 |
+
CREATE OR REPLACE VIEW {_escape_col_name(view_name)} AS (SELECT * FROM read_parquet({files}));
|
168 |
+
""")
|
169 |
+
|
170 |
+
# NOTE: This is cached, but when the latest mtime of any file in the dataset directory changes
|
171 |
+
# the results are invalidated.
|
172 |
+
@functools.cache
|
173 |
+
def _recompute_joint_table(self, latest_mtime_micro_sec: int) -> DatasetManifest:
|
174 |
+
del latest_mtime_micro_sec # This is used as the cache key.
|
175 |
+
merged_schema = self._source_manifest.data_schema.copy(deep=True)
|
176 |
+
self._signal_manifests = []
|
177 |
+
# Make a joined view of all the column groups.
|
178 |
+
self._create_view(SOURCE_VIEW_NAME,
|
179 |
+
[os.path.join(self.dataset_path, f) for f in self._source_manifest.files])
|
180 |
+
|
181 |
+
# Add the signal column groups.
|
182 |
+
for root, _, files in os.walk(self.dataset_path):
|
183 |
+
for file in files:
|
184 |
+
if not file.endswith(SIGNAL_MANIFEST_FILENAME):
|
185 |
+
continue
|
186 |
+
|
187 |
+
with open_file(os.path.join(root, file)) as f:
|
188 |
+
signal_manifest = SignalManifest.parse_raw(f.read())
|
189 |
+
self._signal_manifests.append(signal_manifest)
|
190 |
+
signal_files = [os.path.join(root, f) for f in signal_manifest.files]
|
191 |
+
if signal_files:
|
192 |
+
self._create_view(signal_manifest.parquet_id, signal_files)
|
193 |
+
|
194 |
+
merged_schema = merge_schemas([self._source_manifest.data_schema] +
|
195 |
+
[m.data_schema for m in self._signal_manifests])
|
196 |
+
|
197 |
+
# The logic below generates the following example query:
|
198 |
+
# CREATE OR REPLACE VIEW t AS (
|
199 |
+
# SELECT
|
200 |
+
# source.*,
|
201 |
+
# "parquet_id1"."root_column" AS "parquet_id1",
|
202 |
+
# "parquet_id2"."root_column" AS "parquet_id2"
|
203 |
+
# FROM source JOIN "parquet_id1" USING (uuid,) JOIN "parquet_id2" USING (uuid,)
|
204 |
+
# );
|
205 |
+
# NOTE: "root_column" for each signal is defined as the top-level column.
|
206 |
+
select_sql = ', '.join([f'{SOURCE_VIEW_NAME}.*'] + [(
|
207 |
+
f'{_escape_col_name(manifest.parquet_id)}.{_escape_col_name(_root_column(manifest))} '
|
208 |
+
f'AS {_escape_col_name(manifest.parquet_id)}')
|
209 |
+
for manifest in self._signal_manifests
|
210 |
+
if manifest.files])
|
211 |
+
join_sql = ' '.join([SOURCE_VIEW_NAME] + [
|
212 |
+
f'join {_escape_col_name(manifest.parquet_id)} using ({UUID_COLUMN},)'
|
213 |
+
for manifest in self._signal_manifests
|
214 |
+
if manifest.files
|
215 |
+
])
|
216 |
+
view_or_table = 'TABLE'
|
217 |
+
use_views = env('DUCKDB_USE_VIEWS', 0) or 0
|
218 |
+
if int(use_views):
|
219 |
+
view_or_table = 'VIEW'
|
220 |
+
sql_cmd = f"""CREATE OR REPLACE {view_or_table} t AS (SELECT {select_sql} FROM {join_sql})"""
|
221 |
+
self.con.execute(sql_cmd)
|
222 |
+
|
223 |
+
# Get the total size of the table.
|
224 |
+
size_query = 'SELECT COUNT() as count FROM t'
|
225 |
+
size_query_result = cast(Any, self._query(size_query)[0])
|
226 |
+
num_items = cast(int, size_query_result[0])
|
227 |
+
|
228 |
+
return DatasetManifest(
|
229 |
+
namespace=self.namespace,
|
230 |
+
dataset_name=self.dataset_name,
|
231 |
+
data_schema=merged_schema,
|
232 |
+
num_items=num_items)
|
233 |
+
|
234 |
+
@override
|
235 |
+
def manifest(self) -> DatasetManifest:
|
236 |
+
# Use the latest modification time of all files under the dataset path as the cache key for
|
237 |
+
# re-computing the manifest and the joined view.
|
238 |
+
with self._manifest_lock:
|
239 |
+
all_dataset_files = glob.iglob(os.path.join(self.dataset_path, '**'), recursive=True)
|
240 |
+
latest_mtime = max(map(os.path.getmtime, all_dataset_files))
|
241 |
+
latest_mtime_micro_sec = int(latest_mtime * 1e6)
|
242 |
+
return self._recompute_joint_table(latest_mtime_micro_sec)
|
243 |
+
|
244 |
+
@override
|
245 |
+
def settings(self) -> DatasetSettings:
|
246 |
+
# Read the settings file from disk.
|
247 |
+
settings_filepath = _settings_filepath(self.namespace, self.dataset_name)
|
248 |
+
if not os.path.exists(settings_filepath):
|
249 |
+
self.update_settings(default_settings(self))
|
250 |
+
|
251 |
+
with open(settings_filepath) as f:
|
252 |
+
return DatasetSettings.parse_raw(f.read())
|
253 |
+
|
254 |
+
@override
|
255 |
+
def update_settings(self, settings: DatasetSettings) -> None:
|
256 |
+
# Write the settings file to disk.
|
257 |
+
settings_filepath = _settings_filepath(self.namespace, self.dataset_name)
|
258 |
+
with open(settings_filepath, 'w') as f:
|
259 |
+
f.write(settings.json())
|
260 |
+
|
261 |
+
def count(self, filters: Optional[list[FilterLike]] = None) -> int:
|
262 |
+
"""Count the number of rows."""
|
263 |
+
raise NotImplementedError('count is not yet implemented for DuckDB.')
|
264 |
+
|
265 |
+
def _get_vector_db_index(self, embedding: str, path: PathTuple) -> VectorDBIndex:
|
266 |
+
# Refresh the manifest to make sure we have the latest signal manifests.
|
267 |
+
self.manifest()
|
268 |
+
index_key = (path, embedding)
|
269 |
+
if index_key in self._vector_indices:
|
270 |
+
return self._vector_indices[index_key]
|
271 |
+
|
272 |
+
manifests = [
|
273 |
+
m for m in self._signal_manifests
|
274 |
+
if schema_contains_path(m.data_schema, path) and m.vector_store and m.signal.name == embedding
|
275 |
+
]
|
276 |
+
if not manifests:
|
277 |
+
raise ValueError(f'No embedding found for path {path}.')
|
278 |
+
if len(manifests) > 1:
|
279 |
+
raise ValueError(f'Multiple embeddings found for path {path}. Got: {manifests}')
|
280 |
+
manifest = manifests[0]
|
281 |
+
if not manifest.vector_store:
|
282 |
+
raise ValueError(f'Signal manifest for path {path} is not an embedding. '
|
283 |
+
f'Got signal manifest: {manifest}')
|
284 |
+
|
285 |
+
base_path = os.path.join(self.dataset_path, _signal_dir(manifest.enriched_path),
|
286 |
+
manifest.signal.name)
|
287 |
+
with DebugTimer(f'Loading vector store "{manifest.vector_store}" for "{path}"'
|
288 |
+
f' with embedding "{embedding}"'):
|
289 |
+
vector_index = VectorDBIndex(manifest.vector_store)
|
290 |
+
vector_index.load(base_path)
|
291 |
+
# Cache the vector index.
|
292 |
+
self._vector_indices[index_key] = vector_index
|
293 |
+
return vector_index
|
294 |
+
|
295 |
+
@override
|
296 |
+
def compute_signal(self,
|
297 |
+
signal: Signal,
|
298 |
+
leaf_path: Path,
|
299 |
+
task_step_id: Optional[TaskStepId] = None) -> None:
|
300 |
+
if isinstance(signal, TextEmbeddingSignal):
|
301 |
+
return self.compute_embedding(signal.name, leaf_path, task_step_id)
|
302 |
+
source_path = normalize_path(leaf_path)
|
303 |
+
manifest = self.manifest()
|
304 |
+
|
305 |
+
if task_step_id is None:
|
306 |
+
# Make a dummy task step so we report progress via tqdm.
|
307 |
+
task_step_id = ('', 0)
|
308 |
+
|
309 |
+
# The manifest may have changed after computing the dependencies.
|
310 |
+
manifest = self.manifest()
|
311 |
+
|
312 |
+
signal_col = Column(path=source_path, alias='value', signal_udf=signal)
|
313 |
+
select_rows_result = self.select_rows([signal_col],
|
314 |
+
task_step_id=task_step_id,
|
315 |
+
resolve_span=True)
|
316 |
+
df = select_rows_result.df()
|
317 |
+
values = df['value']
|
318 |
+
|
319 |
+
enriched_path = _col_destination_path(signal_col, is_computed_signal=True)
|
320 |
+
spec = _split_path_into_subpaths_of_lists(enriched_path)
|
321 |
+
output_dir = os.path.join(self.dataset_path, _signal_dir(enriched_path))
|
322 |
+
signal_schema = create_signal_schema(signal, source_path, manifest.data_schema)
|
323 |
+
enriched_signal_items = cast(Iterable[Item], wrap_in_dicts(values, spec))
|
324 |
+
for uuid, item in zip(df[UUID_COLUMN], enriched_signal_items):
|
325 |
+
item[UUID_COLUMN] = uuid
|
326 |
+
|
327 |
+
enriched_signal_items = list(enriched_signal_items)
|
328 |
+
parquet_filename, _ = write_items_to_parquet(
|
329 |
+
items=enriched_signal_items,
|
330 |
+
output_dir=output_dir,
|
331 |
+
schema=signal_schema,
|
332 |
+
filename_prefix='data',
|
333 |
+
shard_index=0,
|
334 |
+
num_shards=1)
|
335 |
+
|
336 |
+
signal_manifest = SignalManifest(
|
337 |
+
files=[parquet_filename],
|
338 |
+
data_schema=signal_schema,
|
339 |
+
signal=signal,
|
340 |
+
enriched_path=source_path,
|
341 |
+
parquet_id=make_parquet_id(signal, source_path, is_computed_signal=True))
|
342 |
+
signal_manifest_filepath = os.path.join(output_dir, SIGNAL_MANIFEST_FILENAME)
|
343 |
+
with open_file(signal_manifest_filepath, 'w') as f:
|
344 |
+
f.write(signal_manifest.json(exclude_none=True, indent=2))
|
345 |
+
log(f'Wrote signal output to {output_dir}')
|
346 |
+
|
347 |
+
@override
|
348 |
+
def compute_embedding(self,
|
349 |
+
embedding: str,
|
350 |
+
path: Path,
|
351 |
+
task_step_id: Optional[TaskStepId] = None) -> None:
|
352 |
+
source_path = normalize_path(path)
|
353 |
+
manifest = self.manifest()
|
354 |
+
|
355 |
+
if task_step_id is None:
|
356 |
+
# Make a dummy task step so we report progress via tqdm.
|
357 |
+
task_step_id = ('', 0)
|
358 |
+
|
359 |
+
signal = get_signal_by_type(embedding, TextEmbeddingSignal)()
|
360 |
+
signal_col = Column(path=source_path, alias='value', signal_udf=signal)
|
361 |
+
select_rows_result = self.select_rows([signal_col],
|
362 |
+
task_step_id=task_step_id,
|
363 |
+
resolve_span=True)
|
364 |
+
df = select_rows_result.df()
|
365 |
+
values = df['value']
|
366 |
+
|
367 |
+
enriched_path = _col_destination_path(signal_col, is_computed_signal=True)
|
368 |
+
output_dir = os.path.join(self.dataset_path, _signal_dir(enriched_path))
|
369 |
+
signal_schema = create_signal_schema(signal, source_path, manifest.data_schema)
|
370 |
+
|
371 |
+
write_embeddings_to_disk(
|
372 |
+
vector_store=self.vector_store,
|
373 |
+
uuids=df[UUID_COLUMN],
|
374 |
+
signal_items=values,
|
375 |
+
output_dir=output_dir)
|
376 |
+
|
377 |
+
del select_rows_result, df, values
|
378 |
+
gc.collect()
|
379 |
+
|
380 |
+
signal_manifest = SignalManifest(
|
381 |
+
files=[],
|
382 |
+
data_schema=signal_schema,
|
383 |
+
signal=signal,
|
384 |
+
enriched_path=source_path,
|
385 |
+
parquet_id=make_parquet_id(signal, source_path, is_computed_signal=True),
|
386 |
+
vector_store=self.vector_store)
|
387 |
+
signal_manifest_filepath = os.path.join(output_dir, SIGNAL_MANIFEST_FILENAME)
|
388 |
+
|
389 |
+
with open_file(signal_manifest_filepath, 'w') as f:
|
390 |
+
f.write(signal_manifest.json(exclude_none=True, indent=2))
|
391 |
+
log(f'Wrote embedding index to {output_dir}')
|
392 |
+
|
393 |
+
@override
|
394 |
+
def delete_signal(self, signal_path: Path) -> None:
|
395 |
+
signal_path = normalize_path(signal_path)
|
396 |
+
manifest = self.manifest()
|
397 |
+
if not manifest.data_schema.has_field(signal_path):
|
398 |
+
raise ValueError(f'Unknown signal path: {signal_path}')
|
399 |
+
|
400 |
+
output_dir = os.path.join(self.dataset_path, _signal_dir(signal_path))
|
401 |
+
shutil.rmtree(output_dir, ignore_errors=True)
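  # Illustrative usage sketch of the enrichment methods above (names are examples only and
  # assume a dataset with a string 'text' field and the referenced embedding/signal registered):
  #   dataset.compute_embedding('gte-small', 'text')          # writes a vector index + manifest
  #   dataset.compute_signal(LangDetectionSignal(), 'text')   # writes a signal parquet + manifest
  #   dataset.delete_signal(('text', 'lang_detection'))       # removes the enriched output dir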

  def _validate_filters(self, filters: Sequence[Filter], col_aliases: dict[str, PathTuple],
                        manifest: DatasetManifest) -> None:
    for filter in filters:
      if filter.path[0] in col_aliases:
        # This is a filter on a column alias, which is always allowed.
        continue

      current_field = Field(fields=manifest.data_schema.fields)
      for path_part in filter.path:
        if path_part == VALUE_KEY:
          if not current_field.dtype:
            raise ValueError(f'Unable to filter on path {filter.path}. The field has no value.')
          continue
        if current_field.fields:
          if path_part not in current_field.fields:
            raise ValueError(f'Unable to filter on path {filter.path}. '
                             f'Path part "{path_part}" not found in the dataset.')
          current_field = current_field.fields[str(path_part)]
          continue
        elif current_field.repeated_field:
          current_field = current_field.repeated_field
          continue
        else:
          raise ValueError(f'Unable to filter on path {filter.path}. '
                           f'Path part "{path_part}" is not defined on a primitive value.')

      while current_field.repeated_field:
        current_field = current_field.repeated_field
        filter.path = (*filter.path, PATH_WILDCARD)

      if not current_field.dtype:
        raise ValueError(f'Unable to filter on path {filter.path}. The field has no value.')

  def _validate_udfs(self, udf_cols: Sequence[Column], source_schema: Schema) -> None:
    for col in udf_cols:
      path = col.path

      # Signal transforms must operate on a leaf field.
      leaf = source_schema.leafs.get(path)
      if not leaf or not leaf.dtype:
        raise ValueError(f'Leaf "{path}" not found in dataset. '
                         'Signal transforms must operate on a leaf field.')

      # Signal transforms must have the same dtype as the leaf field.
      signal = cast(Signal, col.signal_udf)
      if not signal_type_supports_dtype(signal.input_type, leaf.dtype):
        raise ValueError(f'Leaf "{path}" has dtype "{leaf.dtype}" which is not supported '
                         f'by "{signal.key()}" with signal input type "{signal.input_type}".')

  def _validate_selection(self, columns: Sequence[Column], select_schema: Schema) -> None:
    # Validate all the columns and make sure they exist in the `select_schema`.
    for column in columns:
      current_field = Field(fields=select_schema.fields)
      path = column.path
      for path_part in path:
        if path_part == VALUE_KEY:
          if not current_field.dtype:
            raise ValueError(f'Unable to select path {path}. The field has no value.')
          continue
        if current_field.fields:
          if path_part not in current_field.fields:
            raise ValueError(f'Unable to select path {path}. '
                             f'Path part "{path_part}" not found in the dataset.')
          current_field = current_field.fields[path_part]
          continue
        elif current_field.repeated_field:
          if path_part.isdigit():
            raise ValueError(f'Unable to select path {path}. Selecting a specific index of '
                             'a repeated field is currently not supported.')
          if path_part != PATH_WILDCARD:
            raise ValueError(f'Unable to select path {path}. '
                             f'Path part "{path_part}" should be a wildcard.')
          current_field = current_field.repeated_field
        elif not current_field.dtype:
          raise ValueError(f'Unable to select path {path}. '
                           f'Path part "{path_part}" is not defined on a primitive value.')

  def _validate_columns(self, columns: Sequence[Column], source_schema: Schema,
                        select_schema: Schema) -> None:
    udf_cols = [col for col in columns if col.signal_udf]
    self._validate_udfs(udf_cols, source_schema)
    self._validate_selection(columns, select_schema)

  def _validate_sort_path(self, path: PathTuple, schema: Schema) -> None:
    current_field = Field(fields=schema.fields)
    for path_part in path:
      if path_part == VALUE_KEY:
        if not current_field.dtype:
          raise ValueError(f'Unable to sort by path {path}. The field has no value.')
        continue
      if current_field.fields:
        if path_part not in current_field.fields:
          raise ValueError(f'Unable to sort by path {path}. '
                           f'Path part "{path_part}" not found in the dataset.')
        current_field = current_field.fields[path_part]
        continue
      elif current_field.repeated_field:
        if path_part.isdigit():
          raise ValueError(f'Unable to sort by path {path}. Selecting a specific index of '
                           'a repeated field is currently not supported.')
        if path_part != PATH_WILDCARD:
          raise ValueError(f'Unable to sort by path {path}. '
                           f'Path part "{path_part}" should be a wildcard.')
        current_field = current_field.repeated_field
      elif not current_field.dtype:
        raise ValueError(f'Unable to sort by path {path}. '
                         f'Path part "{path_part}" is not defined on a primitive value.')
    if not current_field.dtype:
      raise ValueError(f'Unable to sort by path {path}. The field has no value.')

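  # Note: for trailing repeated fields the validators above append PATH_WILDCARD automatically,
  # so a filter on e.g. ('comments', 'text') is expanded to ('comments', 'text', '*') when the
  # 'text' leaf is repeated (field names here are illustrative).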
  @override
  def stats(self, leaf_path: Path) -> StatsResult:
    if not leaf_path:
      raise ValueError('leaf_path must be provided')
    path = normalize_path(leaf_path)
    manifest = self.manifest()
    leaf = manifest.data_schema.get_field(path)
    # Find the inner-most leaf in case this field is repeated.
    while leaf.repeated_field:
      leaf = leaf.repeated_field
      path = (*path, PATH_WILDCARD)

    if not leaf.dtype:
      raise ValueError(f'Leaf "{path}" not found in dataset')

    duckdb_path = self._leaf_path_to_duckdb_path(path, manifest.data_schema)
    inner_select = _select_sql(
      duckdb_path, flatten=True, unnest=True, span_from=self._get_span_from(path, manifest))

    # Compute approximate count by sampling the data to avoid OOM.
    sample_size = SAMPLE_SIZE_DISTINCT_COUNT
    avg_length_query = ''
    if leaf.dtype == DataType.STRING:
      avg_length_query = ', avg(length(val)) as avgTextLength'

    row: Optional[tuple[int, ...]] = None
    if leaf.dtype == DataType.BOOLEAN:
      approx_count_distinct = 2
    else:
      approx_count_query = f"""
        SELECT approx_count_distinct(val) as approxCountDistinct {avg_length_query}
        FROM (SELECT {inner_select} AS val FROM t LIMIT {sample_size});
      """
      row = self._query(approx_count_query)[0]
      approx_count_distinct = row[0]

    total_count_query = f'SELECT count(val) FROM (SELECT {inner_select} as val FROM t)'
    total_count = self._query(total_count_query)[0][0]

    if leaf.dtype != DataType.BOOLEAN:
      # Adjust the counts for the sample size.
      factor = max(1, total_count / sample_size)
      approx_count_distinct = round(approx_count_distinct * factor)

    result = StatsResult(
      path=path, total_count=total_count, approx_count_distinct=approx_count_distinct)

    if leaf.dtype == DataType.STRING and row:
      result.avg_text_length = row[1]

    # Compute min/max values for ordinal leafs, without sampling the data.
    if is_ordinal(leaf.dtype):
      min_max_query = f"""
        SELECT MIN(val) AS minVal, MAX(val) AS maxVal
        FROM (SELECT {inner_select} as val FROM t)
        {'WHERE NOT isnan(val)' if is_float(leaf.dtype) else ''}
      """
      row = self._query(min_max_query)[0]
      result.min_val, result.max_val = row

    return result
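    # Worked example of the sampling adjustment above (illustrative numbers): if the sample size
    # is 100_000, the table has 1_000_000 rows and approx_count_distinct over the sample is
    # 5_000, then factor = max(1, 1_000_000 / 100_000) = 10 and the reported estimate is 50_000.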

  @override
  def select_groups(
      self,
      leaf_path: Path,
      filters: Optional[Sequence[FilterLike]] = None,
      sort_by: Optional[GroupsSortBy] = GroupsSortBy.COUNT,
      sort_order: Optional[SortOrder] = SortOrder.DESC,
      limit: Optional[int] = None,
      bins: Optional[Union[Sequence[Bin], Sequence[float]]] = None) -> SelectGroupsResult:
    if not leaf_path:
      raise ValueError('leaf_path must be provided')
    path = normalize_path(leaf_path)
    manifest = self.manifest()
    leaf = manifest.data_schema.get_field(path)
    # Find the inner-most leaf in case this field is repeated.
    while leaf.repeated_field:
      leaf = leaf.repeated_field
      path = (*path, PATH_WILDCARD)

    if not leaf.dtype:
      raise ValueError(f'Leaf "{path}" not found in dataset')

    inner_val = 'inner_val'
    outer_select = inner_val
    # Normalize the bins to be `list[Bin]`.
    named_bins = _normalize_bins(bins or leaf.bins)
    stats = self.stats(leaf_path)

    leaf_is_float = is_float(leaf.dtype)
    leaf_is_integer = is_integer(leaf.dtype)
    if not leaf.categorical and (leaf_is_float or leaf_is_integer):
      if named_bins is None:
        # Auto-bin.
        named_bins = _auto_bins(stats, NUM_AUTO_BINS)

      sql_bounds = []
      for label, start, end in named_bins:
        if start is None:
          start = cast(float, "'-Infinity'")
        if end is None:
          end = cast(float, "'Infinity'")
        sql_bounds.append(f"('{label}', {start}, {end})")

      bin_index_col = 'col0'
      bin_min_col = 'col1'
      bin_max_col = 'col2'
      is_nan_filter = f'NOT isnan({inner_val}) AND' if leaf_is_float else ''

      # We cast the field to `double` so binning works for both `float` and `int` fields.
      outer_select = f"""(
        SELECT {bin_index_col} FROM (
          VALUES {', '.join(sql_bounds)}
        ) WHERE {is_nan_filter}
          {inner_val}::DOUBLE >= {bin_min_col} AND {inner_val}::DOUBLE < {bin_max_col}
      )"""
    else:
      if stats.approx_count_distinct >= dataset.TOO_MANY_DISTINCT:
        return SelectGroupsResult(too_many_distinct=True, counts=[], bins=named_bins)

    count_column = 'count'
    value_column = 'value'

    limit_query = f'LIMIT {limit}' if limit else ''
    duckdb_path = self._leaf_path_to_duckdb_path(path, manifest.data_schema)
    inner_select = _select_sql(
      duckdb_path, flatten=True, unnest=True, span_from=self._get_span_from(path, manifest))

    filters, _ = self._normalize_filters(filters, col_aliases={}, udf_aliases={}, manifest=manifest)
    filter_queries = self._create_where(manifest, filters, searches=[])

    where_query = ''
    if filter_queries:
      where_query = f"WHERE {' AND '.join(filter_queries)}"

    query = f"""
      SELECT {outer_select} AS {value_column}, COUNT() AS {count_column}
      FROM (SELECT {inner_select} AS {inner_val} FROM t {where_query})
      GROUP BY {value_column}
      ORDER BY {sort_by} {sort_order}
      {limit_query}
    """
    df = self._query_df(query)
    counts = list(df.itertuples(index=False, name=None))
    if is_temporal(leaf.dtype):
      # Replace any NaT with None and pd.Timestamp to native datetime objects.
      counts = [(None if pd.isnull(val) else val.to_pydatetime(), count) for val, count in counts]
    return SelectGroupsResult(too_many_distinct=False, counts=counts, bins=named_bins)
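    # Sketch of a grouped query over a numeric leaf (field name and bins are hypothetical):
    #   dataset.select_groups('age', bins=[('child', None, 18), ('adult', 18, None)])
    # buckets values into the named bins via the VALUES join built above, while categorical and
    # string leafs fall through to the plain GROUP BY branch.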
|
662 |
+
|
663 |
+
def _topk_udf_to_sort_by(
|
664 |
+
self,
|
665 |
+
udf_columns: list[Column],
|
666 |
+
sort_by: list[PathTuple],
|
667 |
+
limit: Optional[int],
|
668 |
+
sort_order: Optional[SortOrder],
|
669 |
+
) -> Optional[Column]:
|
670 |
+
if (sort_order != SortOrder.DESC) or (not limit) or (not sort_by):
|
671 |
+
return None
|
672 |
+
if len(sort_by) < 1:
|
673 |
+
return None
|
674 |
+
primary_sort_by = sort_by[0]
|
675 |
+
udf_cols_to_sort_by = [
|
676 |
+
udf_col for udf_col in udf_columns if udf_col.alias == primary_sort_by[0] or
|
677 |
+
_path_contains(_col_destination_path(udf_col), primary_sort_by)
|
678 |
+
]
|
679 |
+
if not udf_cols_to_sort_by:
|
680 |
+
return None
|
681 |
+
udf_col = udf_cols_to_sort_by[0]
|
682 |
+
if udf_col.signal_udf and not isinstance(udf_col.signal_udf, VectorSignal):
|
683 |
+
return None
|
684 |
+
return udf_col
|
685 |
+
|
686 |
+
def _normalize_columns(self, columns: Optional[Sequence[ColumnId]],
|
687 |
+
schema: Schema) -> list[Column]:
|
688 |
+
"""Normalizes the columns to a list of `Column` objects."""
|
689 |
+
cols = [column_from_identifier(col) for col in columns or []]
|
690 |
+
star_in_cols = any(col.path == ('*',) for col in cols)
|
691 |
+
if not cols or star_in_cols:
|
692 |
+
# Select all columns.
|
693 |
+
cols.extend([Column((name,)) for name in schema.fields.keys()])
|
694 |
+
if star_in_cols:
|
695 |
+
cols = [col for col in cols if col.path != ('*',)]
|
696 |
+
return cols
|
697 |
+
|
698 |
+
def _merge_sorts(self, search_udfs: list[DuckDBSearchUDF], sort_by: Optional[Sequence[Path]],
|
699 |
+
sort_order: Optional[SortOrder]) -> list[SortResult]:
|
700 |
+
# True when the user has explicitly sorted by the alias of a search UDF (e.g. in ASC order).
|
701 |
+
is_explicit_search_sort = False
|
702 |
+
for sort_by_path in sort_by or []:
|
703 |
+
for search_udf in search_udfs:
|
704 |
+
if column_paths_match(sort_by_path, search_udf.output_path):
|
705 |
+
is_explicit_search_sort = True
|
706 |
+
break
|
707 |
+
|
708 |
+
sort_results: list[SortResult] = []
|
709 |
+
if sort_by and not is_explicit_search_sort:
|
710 |
+
if not sort_order:
|
711 |
+
raise ValueError('`sort_order` is required when `sort_by` is specified.')
|
712 |
+
# If the user has explicitly set a sort by, and it's not a search UDF alias, override.
|
713 |
+
sort_results = [
|
714 |
+
SortResult(path=normalize_path(sort_by), order=sort_order) for sort_by in sort_by if sort_by
|
715 |
+
]
|
716 |
+
else:
|
717 |
+
search_udfs_with_sort = [search_udf for search_udf in search_udfs if search_udf.sort]
|
718 |
+
if search_udfs_with_sort:
|
719 |
+
# Override the sort by the last search sort order when the user hasn't provided an
|
720 |
+
# explicit sort order.
|
721 |
+
last_search_udf = search_udfs_with_sort[-1]
|
722 |
+
assert last_search_udf.sort, 'Expected search UDFs with sort to have a sort.'
|
723 |
+
udf_sort_path, udf_sort_order = last_search_udf.sort
|
724 |
+
sort_results = [
|
725 |
+
SortResult(
|
726 |
+
path=udf_sort_path,
|
727 |
+
order=sort_order or udf_sort_order,
|
728 |
+
search_index=len(search_udfs_with_sort) - 1)
|
729 |
+
]
|
730 |
+
|
731 |
+
return sort_results
|
732 |
+
|
733 |
+
@override
|
734 |
+
def select_rows(self,
|
735 |
+
columns: Optional[Sequence[ColumnId]] = None,
|
736 |
+
searches: Optional[Sequence[Search]] = None,
|
737 |
+
filters: Optional[Sequence[FilterLike]] = None,
|
738 |
+
sort_by: Optional[Sequence[Path]] = None,
|
739 |
+
sort_order: Optional[SortOrder] = SortOrder.DESC,
|
740 |
+
limit: Optional[int] = None,
|
741 |
+
offset: Optional[int] = 0,
|
742 |
+
task_step_id: Optional[TaskStepId] = None,
|
743 |
+
resolve_span: bool = False,
|
744 |
+
combine_columns: bool = False,
|
745 |
+
user: Optional[UserInfo] = None) -> SelectRowsResult:
|
746 |
+
manifest = self.manifest()
|
747 |
+
cols = self._normalize_columns(columns, manifest.data_schema)
|
748 |
+
|
749 |
+
# Always return the UUID column.
|
750 |
+
col_paths = [col.path for col in cols]
|
751 |
+
if (UUID_COLUMN,) not in col_paths:
|
752 |
+
cols.append(column_from_identifier(UUID_COLUMN))
|
753 |
+
|
754 |
+
schema = manifest.data_schema
|
755 |
+
|
756 |
+
if combine_columns:
|
757 |
+
schema = self.select_rows_schema(
|
758 |
+
columns, sort_by, sort_order, searches, combine_columns=True).data_schema
|
759 |
+
|
760 |
+
self._validate_columns(cols, manifest.data_schema, schema)
|
761 |
+
self._normalize_searches(searches, manifest)
|
762 |
+
search_udfs = self._search_udfs(searches, manifest)
|
763 |
+
cols.extend([search_udf.udf for search_udf in search_udfs])
|
764 |
+
udf_columns = [col for col in cols if col.signal_udf]
|
765 |
+
|
766 |
+
# Set extra information on any concept signals.
|
767 |
+
for udf_col in udf_columns:
|
768 |
+
if isinstance(udf_col.signal_udf, (ConceptScoreSignal, ConceptLabelsSignal)):
|
769 |
+
# Concept are access controlled so we tell it about the user.
|
770 |
+
udf_col.signal_udf.set_user(user)
|
771 |
+
|
772 |
+
# Decide on the exact sorting order.
|
773 |
+
sort_results = self._merge_sorts(search_udfs, sort_by, sort_order)
|
774 |
+
sort_by = cast(list[PathTuple],
|
775 |
+
[(sort.alias,) if sort.alias else sort.path for sort in sort_results])
|
776 |
+
# Choose the first sort order as we only support a single sort order for now.
|
777 |
+
sort_order = sort_results[0].order if sort_results else None
|
778 |
+
|
779 |
+
col_aliases: dict[str, PathTuple] = {col.alias: col.path for col in cols if col.alias}
|
780 |
+
udf_aliases: dict[str, PathTuple] = {
|
781 |
+
col.alias: col.path for col in cols if col.signal_udf and col.alias
|
782 |
+
}
|
783 |
+
path_to_udf_col_name: dict[PathTuple, str] = {}
|
784 |
+
for col in cols:
|
785 |
+
if col.signal_udf:
|
786 |
+
alias = col.alias or _unique_alias(col)
|
787 |
+
dest_path = _col_destination_path(col)
|
788 |
+
path_to_udf_col_name[dest_path] = alias
|
789 |
+
|
790 |
+
# Filtering and searching.
|
791 |
+
where_query = ''
|
792 |
+
filters, udf_filters = self._normalize_filters(filters, col_aliases, udf_aliases, manifest)
|
793 |
+
filter_queries = self._create_where(manifest, filters, searches)
|
794 |
+
if filter_queries:
|
795 |
+
where_query = f"WHERE {' AND '.join(filter_queries)}"
|
796 |
+
|
797 |
+
total_num_rows = manifest.num_items
|
798 |
+
con = self.con.cursor()
|
799 |
+
|
800 |
+
topk_udf_col = self._topk_udf_to_sort_by(udf_columns, sort_by, limit, sort_order)
|
801 |
+
if topk_udf_col:
|
802 |
+
path_keys: Optional[list[PathKey]] = None
|
803 |
+
if where_query:
|
804 |
+
# If there are filters, we need to send UUIDs to the top k query.
|
805 |
+
df = con.execute(f'SELECT {UUID_COLUMN} FROM t {where_query}').df()
|
806 |
+
total_num_rows = len(df)
|
807 |
+
# Convert UUIDs to path keys.
|
808 |
+
path_keys = [(uuid,) for uuid in df[UUID_COLUMN]]
|
809 |
+
|
810 |
+
if path_keys is not None and len(path_keys) == 0:
|
811 |
+
where_query = 'WHERE false'
|
812 |
+
else:
|
813 |
+
topk_signal = cast(VectorSignal, topk_udf_col.signal_udf)
|
814 |
+
# The input is an embedding.
|
815 |
+
vector_index = self._get_vector_db_index(topk_signal.embedding, topk_udf_col.path)
|
816 |
+
k = (limit or 0) + (offset or 0)
|
817 |
+
with DebugTimer(f'Compute topk on "{topk_udf_col.path}" using embedding '
|
818 |
+
f'"{topk_signal.embedding}" with vector store "{self.vector_store}"'):
|
819 |
+
topk = topk_signal.vector_compute_topk(k, vector_index, path_keys)
|
820 |
+
topk_uuids = list(dict.fromkeys([cast(str, uuid) for (uuid, *_), _ in topk]))
|
821 |
+
# Update the offset to account for the number of unique UUIDs.
|
822 |
+
offset = len(dict.fromkeys([cast(str, uuid) for (uuid, *_), _ in topk[:offset]]))
|
823 |
+
|
824 |
+
# Ignore all the other filters and filter DuckDB results only by the top k UUIDs.
|
825 |
+
uuid_filter = Filter(path=(UUID_COLUMN,), op=ListOp.IN, value=topk_uuids)
|
826 |
+
filter_query = self._create_where(manifest, [uuid_filter])[0]
|
827 |
+
where_query = f'WHERE {filter_query}'
|
828 |
+
|
829 |
+
# Map a final column name to a list of temporary namespaced column names that need to be merged.
|
830 |
+
columns_to_merge: dict[str, dict[str, Column]] = {}
|
831 |
+
temp_column_to_offset_column: dict[str, tuple[str, Field]] = {}
|
832 |
+
select_queries: list[str] = []
|
833 |
+
|
834 |
+
for column in cols:
|
835 |
+
path = column.path
|
836 |
+
# If the signal is vector-based, we don't need to select the actual data, just the uuids
|
837 |
+
# plus an arbitrarily nested array of `None`s`.
|
838 |
+
empty = bool(column.signal_udf and schema.get_field(path).dtype == DataType.EMBEDDING)
|
839 |
+
|
840 |
+
select_sqls: list[str] = []
|
841 |
+
final_col_name = column.alias or _unique_alias(column)
|
842 |
+
if final_col_name not in columns_to_merge:
|
843 |
+
columns_to_merge[final_col_name] = {}
|
844 |
+
|
845 |
+
duckdb_paths = self._column_to_duckdb_paths(column, schema, combine_columns)
|
846 |
+
span_from = self._get_span_from(path, manifest) if resolve_span or column.signal_udf else None
|
847 |
+
|
848 |
+
for parquet_id, duckdb_path in duckdb_paths:
|
849 |
+
sql = _select_sql(
|
850 |
+
duckdb_path, flatten=False, unnest=False, empty=empty, span_from=span_from)
|
851 |
+
temp_column_name = (
|
852 |
+
final_col_name if len(duckdb_paths) == 1 else f'{final_col_name}/{parquet_id}')
|
853 |
+
select_sqls.append(f'{sql} AS {_escape_string_literal(temp_column_name)}')
|
854 |
+
columns_to_merge[final_col_name][temp_column_name] = column
|
855 |
+
|
856 |
+
if column.signal_udf and span_from and _schema_has_spans(column.signal_udf.fields()):
|
857 |
+
sql = _select_sql(duckdb_path, flatten=False, unnest=False, empty=empty, span_from=None)
|
858 |
+
temp_offset_column_name = f'{temp_column_name}/offset'
|
859 |
+
temp_offset_column_name = temp_offset_column_name.replace("'", "\\'")
|
860 |
+
select_sqls.append(f'{sql} AS {_escape_string_literal(temp_offset_column_name)}')
|
861 |
+
temp_column_to_offset_column[temp_column_name] = (temp_offset_column_name,
|
862 |
+
column.signal_udf.fields())
|
863 |
+
|
864 |
+
# `select_sqls` can be empty if this column points to a path that will be created by a UDF.
|
865 |
+
if select_sqls:
|
866 |
+
select_queries.append(', '.join(select_sqls))
|
867 |
+
|
868 |
+
sort_sql_before_udf: list[str] = []
|
869 |
+
sort_sql_after_udf: list[str] = []
|
870 |
+
|
871 |
+
for path in sort_by:
|
872 |
+
# We only allow sorting by nodes with a value.
|
873 |
+
sort_path = path
|
874 |
+
first_subpath = str(path[0])
|
875 |
+
rest_of_path = path[1:]
|
876 |
+
signal_alias = '.'.join(map(str, path))
|
877 |
+
|
878 |
+
udf_path = _path_to_udf_duckdb_path(path, path_to_udf_col_name)
|
879 |
+
if not udf_path:
|
880 |
+
# Re-route the path if it starts with an alias by pointing it to the actual path.
|
881 |
+
if first_subpath in col_aliases:
|
882 |
+
path = (*col_aliases[first_subpath], *rest_of_path)
|
883 |
+
self._validate_sort_path(path, schema)
|
884 |
+
path = self._leaf_path_to_duckdb_path(path, schema)
|
885 |
+
else:
|
886 |
+
path = udf_path
|
887 |
+
|
888 |
+
sort_sql = _select_sql(path, flatten=True, unnest=False)
|
889 |
+
has_repeated_field = any(subpath == PATH_WILDCARD for subpath in path)
|
890 |
+
if has_repeated_field:
|
891 |
+
sort_sql = (f'list_min({sort_sql})'
|
892 |
+
if sort_order == SortOrder.ASC else f'list_max({sort_sql})')
|
893 |
+
|
894 |
+
# Separate sort columns into two groups: those that need to be sorted before and after UDFs.
|
895 |
+
if udf_path:
|
896 |
+
sort_sql_after_udf.append(sort_sql)
|
897 |
+
else:
|
898 |
+
sort_sql_before_udf.append(sort_sql)
|
899 |
+
|
900 |
+
order_query = ''
|
901 |
+
if sort_sql_before_udf:
|
902 |
+
order_query = (f'ORDER BY {", ".join(sort_sql_before_udf)} '
|
903 |
+
f'{cast(SortOrder, sort_order).value}')
|
904 |
+
|
905 |
+
limit_query = ''
|
906 |
+
if limit:
|
907 |
+
if topk_udf_col:
|
908 |
+
limit_query = f'LIMIT {limit + (offset or 0)}'
|
909 |
+
elif sort_sql_after_udf:
|
910 |
+
limit_query = ''
|
911 |
+
else:
|
912 |
+
limit_query = f'LIMIT {limit} OFFSET {offset or 0}'
|
913 |
+
|
914 |
+
if not topk_udf_col and where_query:
|
915 |
+
total_num_rows = cast(tuple,
|
916 |
+
con.execute(f'SELECT COUNT(*) FROM t {where_query}').fetchone())[0]
|
917 |
+
|
918 |
+
# Fetch the data from DuckDB.
|
919 |
+
df = con.execute(f"""
|
920 |
+
SELECT {', '.join(select_queries)} FROM t
|
921 |
+
{where_query}
|
922 |
+
{order_query}
|
923 |
+
{limit_query}
|
924 |
+
""").df()
|
925 |
+
df = _replace_nan_with_none(df)
|
926 |
+
|
927 |
+
# Run UDFs on the transformed columns.
|
928 |
+
for udf_col in udf_columns:
|
929 |
+
signal = cast(Signal, udf_col.signal_udf)
|
930 |
+
signal_alias = udf_col.alias or _unique_alias(udf_col)
|
931 |
+
temp_signal_cols = columns_to_merge[signal_alias]
|
932 |
+
if len(temp_signal_cols) != 1:
|
933 |
+
raise ValueError(
|
934 |
+
f'Unable to compute signal {signal.name}. Signal UDFs only operate on leafs, but got '
|
935 |
+
f'{len(temp_signal_cols)} underlying columns that contain data related to {udf_col.path}.'
|
936 |
+
)
|
937 |
+
signal_column = list(temp_signal_cols.keys())[0]
|
938 |
+
input = df[signal_column]
|
939 |
+
|
940 |
+
with DebugTimer(f'Computing signal "{signal.name}"'):
|
941 |
+
signal.setup()
|
942 |
+
|
943 |
+
if isinstance(signal, VectorSignal):
|
944 |
+
embedding_signal = signal
|
945 |
+
vector_store = self._get_vector_db_index(embedding_signal.embedding, udf_col.path)
|
946 |
+
flat_keys = list(flatten_keys(df[UUID_COLUMN], input))
|
947 |
+
signal_out = sparse_to_dense_compute(
|
948 |
+
iter(flat_keys), lambda keys: embedding_signal.vector_compute(keys, vector_store))
|
949 |
+
# Add progress.
|
950 |
+
if task_step_id is not None:
|
951 |
+
signal_out = progress(
|
952 |
+
signal_out,
|
953 |
+
task_step_id=task_step_id,
|
954 |
+
estimated_len=len(flat_keys),
|
955 |
+
step_description=f'Computing {signal.key()}')
|
956 |
+
df[signal_column] = deep_unflatten(signal_out, input)
|
957 |
+
else:
|
958 |
+
num_rich_data = count_primitives(input)
|
959 |
+
flat_input = cast(Iterator[Optional[RichData]], deep_flatten(input))
|
960 |
+
signal_out = sparse_to_dense_compute(
|
961 |
+
flat_input, lambda x: signal.compute(cast(Iterable[RichData], x)))
|
962 |
+
# Add progress.
|
963 |
+
if task_step_id is not None:
|
964 |
+
signal_out = progress(
|
965 |
+
signal_out,
|
966 |
+
task_step_id=task_step_id,
|
967 |
+
estimated_len=num_rich_data,
|
968 |
+
step_description=f'Computing {signal.key()}')
|
969 |
+
signal_out_list = list(signal_out)
|
970 |
+
if signal_column in temp_column_to_offset_column:
|
971 |
+
offset_column_name, field = temp_column_to_offset_column[signal_column]
|
972 |
+
nested_spans: Iterable[Item] = df[offset_column_name]
|
973 |
+
flat_spans = deep_flatten(nested_spans)
|
974 |
+
for span, item in zip(flat_spans, signal_out_list):
|
975 |
+
_offset_any_span(cast(int, span[VALUE_KEY][TEXT_SPAN_START_FEATURE]), item, field)
|
976 |
+
|
977 |
+
if len(signal_out_list) != num_rich_data:
|
978 |
+
raise ValueError(
|
979 |
+
f'The signal generated {len(signal_out_list)} values but the input data had '
|
980 |
+
f"{num_rich_data} values. This means the signal either didn't generate a "
|
981 |
+
'"None" for a sparse output, or generated too many items.')
|
982 |
+
|
983 |
+
df[signal_column] = deep_unflatten(signal_out_list, input)
|
984 |
+
|
985 |
+
signal.teardown()
|
986 |
+
if not df.empty and (udf_filters or sort_sql_after_udf):
|
987 |
+
# Re-upload the udf outputs to duckdb so we can filter/sort on them.
|
988 |
+
rel = con.from_df(df)
|
989 |
+
|
990 |
+
if udf_filters:
|
991 |
+
udf_filter_queries = self._create_where(manifest, udf_filters)
|
992 |
+
if udf_filter_queries:
|
993 |
+
rel = rel.filter(' AND '.join(udf_filter_queries))
|
994 |
+
total_num_rows = cast(tuple, rel.count('*').fetchone())[0]
|
995 |
+
|
996 |
+
if sort_sql_after_udf:
|
997 |
+
if not sort_order:
|
998 |
+
raise ValueError('`sort_order` is required when `sort_by` is specified.')
|
999 |
+
rel = rel.order(f'{", ".join(sort_sql_after_udf)} {sort_order.value}')
|
1000 |
+
|
1001 |
+
if limit:
|
1002 |
+
rel = rel.limit(limit, offset or 0)
|
1003 |
+
|
1004 |
+
df = _replace_nan_with_none(rel.df())
|
1005 |
+
|
1006 |
+
if combine_columns:
|
1007 |
+
all_columns: dict[str, Column] = {}
|
1008 |
+
for col_dict in columns_to_merge.values():
|
1009 |
+
all_columns.update(col_dict)
|
1010 |
+
columns_to_merge = {'*': all_columns}
|
1011 |
+
|
1012 |
+
for offset_column, _ in temp_column_to_offset_column.values():
|
1013 |
+
del df[offset_column]
|
1014 |
+
|
1015 |
+
for final_col_name, temp_columns in columns_to_merge.items():
|
1016 |
+
for temp_col_name, column in temp_columns.items():
|
1017 |
+
if combine_columns:
|
1018 |
+
dest_path = _col_destination_path(column)
|
1019 |
+
spec = _split_path_into_subpaths_of_lists(dest_path)
|
1020 |
+
df[temp_col_name] = wrap_in_dicts(df[temp_col_name], spec)
|
1021 |
+
|
1022 |
+
# If the temp col name is the same as the final name, we can skip merging. This happens when
|
1023 |
+
# we select a source leaf column.
|
1024 |
+
if temp_col_name == final_col_name:
|
1025 |
+
continue
|
1026 |
+
|
1027 |
+
if final_col_name not in df:
|
1028 |
+
df[final_col_name] = df[temp_col_name]
|
1029 |
+
else:
|
1030 |
+
df[final_col_name] = merge_series(df[final_col_name], df[temp_col_name])
|
1031 |
+
del df[temp_col_name]
|
1032 |
+
|
1033 |
+
con.close()
|
1034 |
+
|
1035 |
+
if combine_columns:
|
1036 |
+
# Since we aliased every column to `*`, the object with have only '*' as the key. We need to
|
1037 |
+
# elevate the all the columns under '*'.
|
1038 |
+
df = pd.DataFrame.from_records(df['*'])
|
1039 |
+
|
1040 |
+
return SelectRowsResult(df, total_num_rows)
|
1041 |
+
|
1042 |
+
@override
|
1043 |
+
def select_rows_schema(self,
|
1044 |
+
columns: Optional[Sequence[ColumnId]] = None,
|
1045 |
+
sort_by: Optional[Sequence[Path]] = None,
|
1046 |
+
sort_order: Optional[SortOrder] = None,
|
1047 |
+
searches: Optional[Sequence[Search]] = None,
|
1048 |
+
combine_columns: bool = False) -> SelectRowsSchemaResult:
|
1049 |
+
"""Returns the schema of the result of `select_rows` above with the same arguments."""
|
1050 |
+
if not combine_columns:
|
1051 |
+
raise NotImplementedError(
|
1052 |
+
'select_rows_schema with combine_columns=False is not yet supported.')
|
1053 |
+
manifest = self.manifest()
|
1054 |
+
cols = self._normalize_columns(columns, manifest.data_schema)
|
1055 |
+
|
1056 |
+
# Always return the UUID column.
|
1057 |
+
col_paths = [col.path for col in cols]
|
1058 |
+
if (UUID_COLUMN,) not in col_paths:
|
1059 |
+
cols.append(column_from_identifier(UUID_COLUMN))
|
1060 |
+
|
1061 |
+
self._normalize_searches(searches, manifest)
|
1062 |
+
search_udfs = self._search_udfs(searches, manifest)
|
1063 |
+
cols.extend([search_udf.udf for search_udf in search_udfs])
|
1064 |
+
|
1065 |
+
udfs: list[SelectRowsSchemaUDF] = []
|
1066 |
+
col_schemas: list[Schema] = []
|
1067 |
+
for col in cols:
|
1068 |
+
dest_path = _col_destination_path(col)
|
1069 |
+
if col.signal_udf:
|
1070 |
+
udfs.append(SelectRowsSchemaUDF(path=dest_path, alias=col.alias))
|
1071 |
+
field = col.signal_udf.fields()
|
1072 |
+
field.signal = col.signal_udf.dict()
|
1073 |
+
elif manifest.data_schema.has_field(dest_path):
|
1074 |
+
field = manifest.data_schema.get_field(dest_path)
|
1075 |
+
else:
|
1076 |
+
# This column might refer to an output of a udf. We postpone validation to later.
|
1077 |
+
continue
|
1078 |
+
col_schemas.append(_make_schema_from_path(dest_path, field))
|
1079 |
+
|
1080 |
+
sort_results = self._merge_sorts(search_udfs, sort_by, sort_order)
|
1081 |
+
|
1082 |
+
search_results = [
|
1083 |
+
SearchResultInfo(search_path=search_udf.search_path, result_path=search_udf.output_path)
|
1084 |
+
for search_udf in search_udfs
|
1085 |
+
]
|
1086 |
+
|
1087 |
+
new_schema = merge_schemas(col_schemas)
|
1088 |
+
|
1089 |
+
# Now that we have the new schema, we can validate all the column selections.
|
1090 |
+
self._validate_columns(cols, manifest.data_schema, new_schema)
|
1091 |
+
|
1092 |
+
return SelectRowsSchemaResult(
|
1093 |
+
data_schema=new_schema, udfs=udfs, search_results=search_results, sorts=sort_results or None)
|
1094 |
+
|
1095 |
+
@override
|
1096 |
+
def media(self, item_id: str, leaf_path: Path) -> MediaResult:
|
1097 |
+
raise NotImplementedError('Media is not yet supported for the DuckDB implementation.')
|
1098 |
+
|
1099 |
+
def _get_span_from(self, path: PathTuple, manifest: DatasetManifest) -> Optional[PathTuple]:
|
1100 |
+
leafs = manifest.data_schema.leafs
|
1101 |
+
# Remove the value key so we can check the dtype from leafs.
|
1102 |
+
span_path = path[:-1] if path[-1] == VALUE_KEY else path
|
1103 |
+
is_span = (span_path in leafs and leafs[span_path].dtype == DataType.STRING_SPAN)
|
1104 |
+
return _derived_from_path(path, manifest.data_schema) if is_span else None
|
1105 |
+
|
1106 |
+
def _leaf_path_to_duckdb_path(self, leaf_path: PathTuple, schema: Schema) -> PathTuple:
|
1107 |
+
((_, duckdb_path),) = self._column_to_duckdb_paths(
|
1108 |
+
Column(leaf_path), schema, combine_columns=False, select_leaf=True)
|
1109 |
+
return duckdb_path
|
1110 |
+
|
1111 |
+
def _column_to_duckdb_paths(self,
|
1112 |
+
column: Column,
|
1113 |
+
schema: Schema,
|
1114 |
+
combine_columns: bool,
|
1115 |
+
select_leaf: bool = False) -> list[tuple[str, PathTuple]]:
|
1116 |
+
path = column.path
|
1117 |
+
parquet_manifests: list[Union[SourceManifest, SignalManifest]] = [
|
1118 |
+
self._source_manifest, *self._signal_manifests
|
1119 |
+
]
|
1120 |
+
duckdb_paths: list[tuple[str, PathTuple]] = []
|
1121 |
+
source_has_path = False
|
1122 |
+
|
1123 |
+
select_leaf = select_leaf or column.signal_udf is not None
|
1124 |
+
|
1125 |
+
for m in parquet_manifests:
|
1126 |
+
if not m.files:
|
1127 |
+
continue
|
1128 |
+
# Skip this parquet file if it doesn't contain the path.
|
1129 |
+
if not schema_contains_path(m.data_schema, path):
|
1130 |
+
continue
|
1131 |
+
|
1132 |
+
if isinstance(m, SourceManifest):
|
1133 |
+
source_has_path = True
|
1134 |
+
|
1135 |
+
if isinstance(m, SignalManifest) and source_has_path and not combine_columns:
|
1136 |
+
# Skip this signal if the source already has the path and we are not combining columns.
|
1137 |
+
continue
|
1138 |
+
|
1139 |
+
# Skip this parquet file if the path doesn't have a dtype.
|
1140 |
+
if select_leaf and not m.data_schema.get_field(path).dtype:
|
1141 |
+
continue
|
1142 |
+
|
1143 |
+
if isinstance(m, SignalManifest) and path == (UUID_COLUMN,):
|
1144 |
+
# Do not select UUID from the signal because it's already in the source.
|
1145 |
+
continue
|
1146 |
+
|
1147 |
+
duckdb_path = path
|
1148 |
+
parquet_id = 'source'
|
1149 |
+
|
1150 |
+
if isinstance(m, SignalManifest):
|
1151 |
+
duckdb_path = (m.parquet_id, *path[1:])
|
1152 |
+
parquet_id = m.parquet_id
|
1153 |
+
|
1154 |
+
duckdb_paths.append((parquet_id, duckdb_path))
|
1155 |
+
|
1156 |
+
if not duckdb_paths:
|
1157 |
+
# This path is probably a result of a udf. Make sure the result schema contains it.
|
1158 |
+
if not schema.has_field(path):
|
1159 |
+
raise ValueError(f'Invalid path "{path}": No manifest contains path. Valid paths: '
|
1160 |
+
f'{list(schema.leafs.keys())}')
|
1161 |
+
|
1162 |
+
return duckdb_paths
|
1163 |
+
|
1164 |
+
def _normalize_filters(self, filter_likes: Optional[Sequence[FilterLike]],
|
1165 |
+
col_aliases: dict[str, PathTuple], udf_aliases: dict[str, PathTuple],
|
1166 |
+
manifest: DatasetManifest) -> tuple[list[Filter], list[Filter]]:
|
1167 |
+
"""Normalize `FilterLike` to `Filter` and split into filters on source and filters on UDFs."""
|
1168 |
+
filter_likes = filter_likes or []
|
1169 |
+
filters: list[Filter] = []
|
1170 |
+
udf_filters: list[Filter] = []
|
1171 |
+
|
1172 |
+
for filter in filter_likes:
|
1173 |
+
# Normalize `FilterLike` to `Filter`.
|
1174 |
+
if not isinstance(filter, Filter):
|
1175 |
+
if len(filter) == 3:
|
1176 |
+
path, op, value = filter # type: ignore
|
1177 |
+
elif len(filter) == 2:
|
1178 |
+
path, op = filter # type: ignore
|
1179 |
+
value = None
|
1180 |
+
else:
|
1181 |
+
raise ValueError(f'Invalid filter: {filter}. Must be a tuple with 2 or 3 elements.')
|
1182 |
+
filter = Filter(path=normalize_path(path), op=op, value=value)
|
1183 |
+
|
1184 |
+
if str(filter.path[0]) in udf_aliases:
|
1185 |
+
udf_filters.append(filter)
|
1186 |
+
else:
|
1187 |
+
filters.append(filter)
|
1188 |
+
|
1189 |
+
self._validate_filters(filters, col_aliases, manifest)
|
1190 |
+
return filters, udf_filters
|
1191 |
+
|
1192 |
+
def _normalize_searches(self, searches: Optional[Sequence[Search]],
|
1193 |
+
manifest: DatasetManifest) -> None:
|
1194 |
+
"""Validate searches."""
|
1195 |
+
if not searches:
|
1196 |
+
return
|
1197 |
+
|
1198 |
+
for search in searches:
|
1199 |
+
search.path = normalize_path(search.path)
|
1200 |
+
field = manifest.data_schema.get_field(search.path)
|
1201 |
+
if field.dtype != DataType.STRING:
|
1202 |
+
raise ValueError(f'Invalid search path: {search.path}. '
|
1203 |
+
f'Must be a string field, got dtype {field.dtype}')
|
1204 |
+
|
1205 |
+
def _search_udfs(self, searches: Optional[Sequence[Search]],
|
1206 |
+
manifest: DatasetManifest) -> list[DuckDBSearchUDF]:
|
1207 |
+
searches = searches or []
|
1208 |
+
"""Create a UDF for each search for finding the location of the text with spans."""
|
1209 |
+
search_udfs: list[DuckDBSearchUDF] = []
|
1210 |
+
for search in searches:
|
1211 |
+
search_path = normalize_path(search.path)
|
1212 |
+
if search.query.type == 'keyword':
|
1213 |
+
udf = Column(path=search_path, signal_udf=SubstringSignal(query=search.query.search))
|
1214 |
+
search_udfs.append(
|
1215 |
+
DuckDBSearchUDF(
|
1216 |
+
udf=udf,
|
1217 |
+
search_path=search_path,
|
1218 |
+
output_path=(*_col_destination_path(udf), PATH_WILDCARD)))
|
1219 |
+
elif search.query.type == 'semantic' or search.query.type == 'concept':
|
1220 |
+
embedding = search.query.embedding
|
1221 |
+
if not embedding:
|
1222 |
+
raise ValueError(f'Please provide an embedding for semantic search. Got search: {search}')
|
1223 |
+
|
1224 |
+
try:
|
1225 |
+
manifest.data_schema.get_field((*search_path, embedding))
|
1226 |
+
except Exception as e:
|
1227 |
+
raise ValueError(
|
1228 |
+
f'Embedding {embedding} has not been computed. '
|
1229 |
+
f'Please compute the embedding index before issuing a {search.query.type} query.'
|
1230 |
+
) from e
|
1231 |
+
|
1232 |
+
search_signal: Optional[Signal] = None
|
1233 |
+
if search.query.type == 'semantic':
|
1234 |
+
search_signal = SemanticSimilaritySignal(
|
1235 |
+
query=search.query.search, embedding=search.query.embedding)
|
1236 |
+
elif search.query.type == 'concept':
|
1237 |
+
search_signal = ConceptScoreSignal(
|
1238 |
+
namespace=search.query.concept_namespace,
|
1239 |
+
concept_name=search.query.concept_name,
|
1240 |
+
embedding=search.query.embedding)
|
1241 |
+
|
1242 |
+
# Add the label UDF.
|
1243 |
+
concept_labels_signal = ConceptLabelsSignal(
|
1244 |
+
namespace=search.query.concept_namespace, concept_name=search.query.concept_name)
|
1245 |
+
concept_labels_udf = Column(path=search_path, signal_udf=concept_labels_signal)
|
1246 |
+
search_udfs.append(
|
1247 |
+
DuckDBSearchUDF(
|
1248 |
+
udf=concept_labels_udf,
|
1249 |
+
search_path=search_path,
|
1250 |
+
output_path=_col_destination_path(concept_labels_udf),
|
1251 |
+
sort=None))
|
1252 |
+
|
1253 |
+
udf = Column(path=search_path, signal_udf=search_signal)
|
1254 |
+
|
1255 |
+
output_path = _col_destination_path(udf)
|
1256 |
+
search_udfs.append(
|
1257 |
+
DuckDBSearchUDF(
|
1258 |
+
udf=udf,
|
1259 |
+
search_path=search_path,
|
1260 |
+
output_path=_col_destination_path(udf),
|
1261 |
+
sort=((*output_path, PATH_WILDCARD, 'score'), SortOrder.DESC)))
|
1262 |
+
else:
|
1263 |
+
raise ValueError(f'Unknown search operator {search.query.type}.')
|
1264 |
+
|
1265 |
+
return search_udfs
|
1266 |
+
|
1267 |
+
def _create_where(self,
|
1268 |
+
manifest: DatasetManifest,
|
1269 |
+
filters: list[Filter],
|
1270 |
+
searches: Optional[Sequence[Search]] = []) -> list[str]:
|
1271 |
+
if not filters and not searches:
|
1272 |
+
return []
|
1273 |
+
searches = searches or []
|
1274 |
+
sql_filter_queries: list[str] = []
|
1275 |
+
|
1276 |
+
# Add search where queries.
|
1277 |
+
for search in searches:
|
1278 |
+
duckdb_path = self._leaf_path_to_duckdb_path(
|
1279 |
+
normalize_path(search.path), manifest.data_schema)
|
1280 |
+
select_str = _select_sql(duckdb_path, flatten=False, unnest=False)
|
1281 |
+
if search.query.type == 'keyword':
|
1282 |
+
sql_op = 'ILIKE'
|
1283 |
+
query_val = _escape_like_value(search.query.search)
|
1284 |
+
elif search.query.type == 'semantic' or search.query.type == 'concept':
|
1285 |
+
# Semantic search and concepts don't yet filter.
|
1286 |
+
continue
|
1287 |
+
else:
|
1288 |
+
raise ValueError(f'Unknown search operator {search.query.type}.')
|
1289 |
+
|
1290 |
+
filter_query = f'{select_str} {sql_op} {query_val}'
|
1291 |
+
|
1292 |
+
sql_filter_queries.append(filter_query)
|
1293 |
+
|
1294 |
+
# Add filter where queries.
|
1295 |
+
binary_ops = set(BinaryOp)
|
1296 |
+
unary_ops = set(UnaryOp)
|
1297 |
+
list_ops = set(ListOp)
|
1298 |
+
for f in filters:
|
1299 |
+
duckdb_path = self._leaf_path_to_duckdb_path(f.path, manifest.data_schema)
|
1300 |
+
select_str = _select_sql(
|
1301 |
+
duckdb_path, flatten=True, unnest=False, span_from=self._get_span_from(f.path, manifest))
|
1302 |
+
is_array = any(subpath == PATH_WILDCARD for subpath in f.path)
|
1303 |
+
|
1304 |
+
nan_filter = ''
|
1305 |
+
field = manifest.data_schema.get_field(f.path)
|
1306 |
+
filter_nans = field.dtype and is_float(field.dtype)
|
1307 |
+
|
1308 |
+
if f.op in binary_ops:
|
1309 |
+
sql_op = BINARY_OP_TO_SQL[cast(BinaryOp, f.op)]
|
1310 |
+
filter_val = cast(FeatureValue, f.value)
|
1311 |
+
if isinstance(filter_val, str):
|
1312 |
+
filter_val = _escape_string_literal(filter_val)
|
1313 |
+
elif isinstance(filter_val, bytes):
|
1314 |
+
filter_val = _bytes_to_blob_literal(filter_val)
|
1315 |
+
else:
|
1316 |
+
filter_val = str(filter_val)
|
1317 |
+
if is_array:
|
1318 |
+
nan_filter = 'NOT isnan(x) AND' if filter_nans else ''
|
1319 |
+
filter_query = (f'len(list_filter({select_str}, '
|
1320 |
+
f'x -> {nan_filter} x {sql_op} {filter_val})) > 0')
|
1321 |
+
else:
|
1322 |
+
nan_filter = f'NOT isnan({select_str}) AND' if filter_nans else ''
|
1323 |
+
filter_query = f'{nan_filter} {select_str} {sql_op} {filter_val}'
|
1324 |
+
elif f.op in unary_ops:
|
1325 |
+
if f.op == UnaryOp.EXISTS:
|
1326 |
+
filter_query = f'len({select_str}) > 0' if is_array else f'{select_str} IS NOT NULL'
|
1327 |
+
else:
|
1328 |
+
raise ValueError(f'Unary op: {f.op} is not yet supported')
|
1329 |
+
elif f.op in list_ops:
|
1330 |
+
if f.op == ListOp.IN:
|
1331 |
+
filter_list_val = cast(FeatureListValue, f.value)
|
1332 |
+
if not isinstance(filter_list_val, list):
|
1333 |
+
raise ValueError('filter with array value can only use the IN comparison')
|
1334 |
+
wrapped_filter_val = [f"'{part}'" for part in filter_list_val]
|
1335 |
+
filter_val = f'({", ".join(wrapped_filter_val)})'
|
1336 |
+
filter_query = f'{select_str} IN {filter_val}'
|
1337 |
+
else:
|
1338 |
+
raise ValueError(f'List op: {f.op} is not yet supported')
|
1339 |
+
else:
|
1340 |
+
raise ValueError(f'Invalid filter op: {f.op}')
|
1341 |
+
sql_filter_queries.append(filter_query)
|
1342 |
+
return sql_filter_queries
|
1343 |
+
|
1344 |
+
  def _execute(self, query: str) -> duckdb.DuckDBPyConnection:
    """Execute a query in duckdb."""
    # FastAPI is multi-threaded so we have to create a thread-specific connection cursor to allow
    # these queries to be thread-safe.
    local_con = self.con.cursor()
    if not env('DEBUG', False):
      return local_con.execute(query)

    # Debug mode.
    log('Executing:')
    log(query)
    with DebugTimer('Query'):
      return local_con.execute(query)

  def _query(self, query: str) -> list[tuple]:
    result = self._execute(query)
    rows = result.fetchall()
    result.close()
    return rows

  def _query_df(self, query: str) -> pd.DataFrame:
    """Execute a query that returns a data frame."""
    result = self._execute(query)
    df = _replace_nan_with_none(result.df())
    result.close()
    return df

  def _path_to_col(self, path: Path, quote_each_part: bool = True) -> str:
    """Convert a path to a column name."""
    if isinstance(path, str):
      path = (path,)
    return '.'.join([
      f'{_escape_col_name(path_comp)}' if quote_each_part else str(path_comp) for path_comp in path
    ])

  @override
  def to_json(self, filepath: Union[str, pathlib.Path], jsonl: bool = True) -> None:
    self._execute(f"COPY t TO '{filepath}' (FORMAT JSON, ARRAY {'FALSE' if jsonl else 'TRUE'})")
    log(f'Dataset exported to {filepath}')

  @override
  def to_pandas(self) -> pd.DataFrame:
    return self._query_df('SELECT * FROM t')

  @override
  def to_csv(self, filepath: Union[str, pathlib.Path]) -> None:
    self._execute(f"COPY t TO '{filepath}' (FORMAT CSV, HEADER)")
    log(f'Dataset exported to {filepath}')

  @override
  def to_parquet(self, filepath: Union[str, pathlib.Path]) -> None:
    self._execute(f"COPY t TO '{filepath}' (FORMAT PARQUET)")
    log(f'Dataset exported to {filepath}')
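  # Export sketch (file names are illustrative):
  #   dataset.to_json('export.jsonl')      # one JSON object per line when jsonl=True
  #   dataset.to_csv('export.csv')
  #   dataset.to_parquet('export.parquet')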


def _escape_string_literal(string: str) -> str:
  string = string.replace("'", "''")
  return f"'{string}'"


def _escape_col_name(col_name: str) -> str:
  col_name = col_name.replace('"', '""')
  return f'"{col_name}"'


def _escape_like_value(value: str) -> str:
  value = value.replace('%', '\\%').replace('_', '\\_')
  return f"'%{value}%' ESCAPE '\\'"


def _inner_select(sub_paths: list[PathTuple],
                  inner_var: Optional[str] = None,
                  empty: bool = False,
                  span_from: Optional[PathTuple] = None) -> str:
  """Recursively generate the inner select statement for a list of sub paths."""
  current_sub_path = sub_paths[0]
  lambda_var = inner_var + 'x' if inner_var else 'x'
  if not inner_var:
    lambda_var = 'x'
    inner_var = _escape_col_name(current_sub_path[0])
    current_sub_path = current_sub_path[1:]
  # Select the path inside structs. E.g. x['a']['b']['c'] given current_sub_path = [a, b, c].
  path_key = inner_var + ''.join([f'[{_escape_string_literal(p)}]' for p in current_sub_path])
  if len(sub_paths) == 1:
    if span_from:
      derived_col = _select_sql(span_from, flatten=False, unnest=False)
      path_key = (f'{derived_col}[{path_key}.{VALUE_KEY}.{TEXT_SPAN_START_FEATURE}+1:'
                  f'{path_key}.{VALUE_KEY}.{TEXT_SPAN_END_FEATURE}]')
    return 'NULL' if empty else path_key
  return (f'list_transform({path_key}, {lambda_var} -> '
          f'{_inner_select(sub_paths[1:], lambda_var, empty, span_from)})')


def _split_path_into_subpaths_of_lists(leaf_path: PathTuple) -> list[PathTuple]:
  """Split a path into a subpath of lists.

  E.g. [a, b, c, *, d, *, *] gets split into [[a, b, c], [d], [], []].
  """
  sub_paths: list[PathTuple] = []
  offset = 0
  while offset <= len(leaf_path):
    new_offset = leaf_path.index(PATH_WILDCARD,
                                 offset) if PATH_WILDCARD in leaf_path[offset:] else len(leaf_path)
    sub_path = leaf_path[offset:new_offset]
    sub_paths.append(sub_path)
    offset = new_offset + 1
  return sub_paths


def _select_sql(path: PathTuple,
                flatten: bool,
                unnest: bool,
                empty: bool = False,
                span_from: Optional[PathTuple] = None) -> str:
  """Create a select column for a path.

  Args:
    path: A path to a feature. E.g. ['a', 'b', 'c'].
    flatten: Whether to flatten the result.
    unnest: Whether to unnest the result.
    empty: Whether to return an empty list (used for embedding signals that don't need the data).
    span_from: The path this span is derived from. If specified, the span will be resolved
      to a substring of the original string.
  """
  sub_paths = _split_path_into_subpaths_of_lists(path)
  selection = _inner_select(sub_paths, None, empty, span_from)
  # We only flatten when the result is a nested list, to avoid a segfault.
  is_result_nested_list = len(sub_paths) >= 3  # E.g. subPaths = [[a, b, c], *, *].
  if flatten and is_result_nested_list:
    selection = f'flatten({selection})'
  # We only unnest when the result is a list. E.g. subPaths = [[a, b, c], *].
  is_result_a_list = len(sub_paths) >= 2
  if unnest and is_result_a_list:
    selection = f'unnest({selection})'
  return selection
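# Example (illustrative): for path ('comments', '*', 'text') the generated selection is roughly
# `list_transform("comments", x -> x['text'])`, optionally wrapped in flatten()/unnest()
# depending on the flags above.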


def read_source_manifest(dataset_path: str) -> SourceManifest:
  """Read the manifest file."""
  with open_file(os.path.join(dataset_path, MANIFEST_FILENAME), 'r') as f:
    return SourceManifest.parse_raw(f.read())


def _signal_dir(enriched_path: PathTuple) -> str:
  """Get the filename prefix for a signal parquet file."""
  path_without_wildcards = (p for p in enriched_path if p != PATH_WILDCARD)
  return os.path.join(*path_without_wildcards)


def split_column_name(column: str, split_name: str) -> str:
  """Get the name of a split column."""
  return f'{column}.{split_name}'


def split_parquet_prefix(column_name: str, splitter_name: str) -> str:
  """Get the filename prefix for a split parquet file."""
  return f'{column_name}.{splitter_name}'


def _bytes_to_blob_literal(bytes: bytes) -> str:
  """Convert bytes to a blob literal."""
  escaped_hex = re.sub(r'(.{2})', r'\\x\1', bytes.hex())
  return f"'{escaped_hex}'::BLOB"


class SignalManifest(BaseModel):
  """The manifest that describes a signal computation including schema and parquet files."""
  # List of parquet filepaths storing the data. The paths are relative to the manifest.
  files: list[str]

  # An identifier for this parquet table. Will be used as the view name in SQL.
  parquet_id: str

  data_schema: Schema
  signal: Signal

  # The column path that this signal is derived from.
  enriched_path: PathTuple

  # The name of the vector store. Present when the signal is an embedding.
  vector_store: Optional[str] = None

  @validator('signal', pre=True)
  def parse_signal(cls, signal: dict) -> Signal:
    """Parse a signal to its specific subclass instance."""
    return resolve_signal(signal)


def _merge_cells(dest_cell: Item, source_cell: Item) -> Item:
  if source_cell is None or isinstance(source_cell, float) and math.isnan(source_cell):
    # Nothing to merge here (missing value).
    return dest_cell
  if isinstance(dest_cell, dict):
    if isinstance(source_cell, list):
      raise ValueError(f'Failed to merge cells. Destination is a dict ({dest_cell!r}), '
                       f'but source is a list ({source_cell!r}).')
    if isinstance(source_cell, dict):
      res = {**dest_cell}
      for key, value in source_cell.items():
        res[key] = (value if key not in dest_cell else _merge_cells(dest_cell[key], value))
      return res
    else:
      return {VALUE_KEY: source_cell, **dest_cell}
  elif isinstance(dest_cell, list):
    if not isinstance(source_cell, list):
      raise ValueError('Failed to merge cells. Destination is a list, but source is not.')
    return [
      _merge_cells(dest_subcell, source_subcell)
      for dest_subcell, source_subcell in zip(dest_cell, source_cell)
    ]
  else:
    # The destination is a primitive.
    if isinstance(source_cell, list):
      raise ValueError(f'Failed to merge cells. Destination is a primitive ({dest_cell!r}), '
                       f'but source is a list ({source_cell!r}).')
    if isinstance(source_cell, dict):
      return {VALUE_KEY: dest_cell, **source_cell}
    else:
      # Primitives can be merged together if they are equal. This can happen if a user selects a
      # column that is the child of another.
      # NOTE: This can be removed if we fix https://github.com/lilacai/lilac/issues/166.
      if source_cell != dest_cell:
        raise ValueError(f'Cannot merge source "{source_cell!r}" into destination "{dest_cell!r}"')
      return dest_cell


def merge_series(destination: pd.Series, source: pd.Series) -> list[Item]:
  """Merge two series of values recursively."""
  return _merge_cells(destination.tolist(), source.tolist())


def _unique_alias(column: Column) -> str:
  """Get a unique alias for a selection column."""
  if column.signal_udf:
    return make_parquet_id(column.signal_udf, column.path)
  return '.'.join(map(str, column.path))


def _path_contains(parent_path: PathTuple, child_path: PathTuple) -> bool:
  """Check if a path contains another path."""
  if len(parent_path) > len(child_path):
    return False
  return all(parent_path[i] == child_path[i] for i in range(len(parent_path)))


def _path_to_udf_duckdb_path(path: PathTuple,
                             path_to_udf_col_name: dict[PathTuple, str]) -> Optional[PathTuple]:
  first_subpath, *rest_of_path = path
  for parent_path, udf_col_name in path_to_udf_col_name.items():
    # If the user selected udf(document.*.text) as "udf" and wanted to sort by "udf.len", we need
    # to sort by "udf.*.len" where the "*" came from the fact that the udf was applied to a list
    # of "text" fields.
    wildcards = [x for x in parent_path if x == PATH_WILDCARD]
    if _path_contains(parent_path, path):
|
1598 |
+
return (udf_col_name, *wildcards, *path[len(parent_path):])
|
1599 |
+
elif first_subpath == udf_col_name:
|
1600 |
+
return (udf_col_name, *wildcards, *rest_of_path)
|
1601 |
+
|
1602 |
+
return None
|
1603 |
+
|
1604 |
+
|
1605 |
+
def _col_destination_path(column: Column, is_computed_signal: Optional[bool] = False) -> PathTuple:
|
1606 |
+
"""Get the destination path where the output of this selection column will be stored."""
|
1607 |
+
source_path = column.path
|
1608 |
+
|
1609 |
+
if not column.signal_udf:
|
1610 |
+
return source_path
|
1611 |
+
|
1612 |
+
signal_key = column.signal_udf.key(is_computed_signal=is_computed_signal)
|
1613 |
+
# If we are enriching a value we should store the signal data in the value's parent.
|
1614 |
+
if source_path[-1] == VALUE_KEY:
|
1615 |
+
dest_path = (*source_path[:-1], signal_key)
|
1616 |
+
else:
|
1617 |
+
dest_path = (*source_path, signal_key)
|
1618 |
+
|
1619 |
+
return dest_path
|
1620 |
+
|
1621 |
+
|
1622 |
+
def _root_column(manifest: SignalManifest) -> str:
|
1623 |
+
"""Returns the root column of a signal manifest."""
|
1624 |
+
field_keys = manifest.data_schema.fields.keys()
|
1625 |
+
if len(field_keys) != 2:
|
1626 |
+
raise ValueError('Expected exactly two fields in signal manifest, '
|
1627 |
+
f'the row UUID and root this signal is enriching. Got {field_keys}.')
|
1628 |
+
return next(filter(lambda field: field != UUID_COLUMN, manifest.data_schema.fields.keys()))
|
1629 |
+
|
1630 |
+
|
1631 |
+
def _derived_from_path(path: PathTuple, schema: Schema) -> PathTuple:
|
1632 |
+
# Find the closest parent of `path` that is a signal root.
|
1633 |
+
for i in reversed(range(len(path))):
|
1634 |
+
sub_path = path[:i]
|
1635 |
+
if schema.get_field(sub_path).signal is not None:
|
1636 |
+
# Skip the signal name at the end to get the source path that was enriched.
|
1637 |
+
return sub_path[:-1]
|
1638 |
+
raise ValueError('Cannot find the source path for the enriched path: {path}')
|
1639 |
+
|
1640 |
+
|
1641 |
+
def _make_schema_from_path(path: PathTuple, field: Field) -> Schema:
|
1642 |
+
"""Returns a schema that contains only the given path."""
|
1643 |
+
for sub_path in reversed(path):
|
1644 |
+
if sub_path == PATH_WILDCARD:
|
1645 |
+
field = Field(repeated_field=field)
|
1646 |
+
else:
|
1647 |
+
field = Field(fields={sub_path: field})
|
1648 |
+
if not field.fields:
|
1649 |
+
raise ValueError(f'Invalid path: {path}. Must contain at least one field name.')
|
1650 |
+
return Schema(fields=field.fields)
|
1651 |
+
|
1652 |
+
|
1653 |
+
def _replace_nan_with_none(df: pd.DataFrame) -> pd.DataFrame:
|
1654 |
+
"""DuckDB returns np.nan for missing field in string column, replace with None for correctness."""
|
1655 |
+
# TODO(https://github.com/duckdb/duckdb/issues/4066): Remove this once duckdb fixes upstream.
|
1656 |
+
for col in df.columns:
|
1657 |
+
if is_object_dtype(df[col]):
|
1658 |
+
df[col].replace(np.nan, None, inplace=True)
|
1659 |
+
return df
|
1660 |
+
|
1661 |
+
|
1662 |
+
def _offset_any_span(offset: int, item: Item, schema: Field) -> None:
|
1663 |
+
"""Offsets any spans inplace by the given parent offset."""
|
1664 |
+
if schema.dtype == DataType.STRING_SPAN:
|
1665 |
+
item = cast(dict, item)
|
1666 |
+
item[VALUE_KEY][TEXT_SPAN_START_FEATURE] += offset
|
1667 |
+
item[VALUE_KEY][TEXT_SPAN_END_FEATURE] += offset
|
1668 |
+
if schema.fields:
|
1669 |
+
item = cast(dict, item)
|
1670 |
+
for key, sub_schema in schema.fields.items():
|
1671 |
+
_offset_any_span(offset, item[key], sub_schema)
|
1672 |
+
if schema.repeated_field:
|
1673 |
+
item = cast(list, item)
|
1674 |
+
for sub_item in item:
|
1675 |
+
_offset_any_span(offset, sub_item, schema.repeated_field)
|
1676 |
+
|
1677 |
+
|
1678 |
+
def _schema_has_spans(field: Field) -> bool:
|
1679 |
+
if field.dtype and field.dtype == DataType.STRING_SPAN:
|
1680 |
+
return True
|
1681 |
+
if field.fields:
|
1682 |
+
children_have_spans = any(_schema_has_spans(sub_field) for sub_field in field.fields.values())
|
1683 |
+
if children_have_spans:
|
1684 |
+
return True
|
1685 |
+
if field.repeated_field:
|
1686 |
+
return _schema_has_spans(field.repeated_field)
|
1687 |
+
return False
|
1688 |
+
|
1689 |
+
|
1690 |
+
def _normalize_bins(bins: Optional[Union[Sequence[Bin], Sequence[float]]]) -> Optional[list[Bin]]:
|
1691 |
+
if bins is None:
|
1692 |
+
return None
|
1693 |
+
if not isinstance(bins[0], (float, int)):
|
1694 |
+
return cast(list[Bin], bins)
|
1695 |
+
named_bins: list[Bin] = []
|
1696 |
+
for i in range(len(bins) + 1):
|
1697 |
+
start = cast(float, bins[i - 1]) if i > 0 else None
|
1698 |
+
end = cast(float, bins[i]) if i < len(bins) else None
|
1699 |
+
named_bins.append((str(i), start, end))
|
1700 |
+
return named_bins
|
1701 |
+
|
1702 |
+
|
1703 |
+
def _auto_bins(stats: StatsResult, num_bins: int) -> list[Bin]:
|
1704 |
+
min_val = cast(float, stats.min_val)
|
1705 |
+
max_val = cast(float, stats.max_val)
|
1706 |
+
bin_width = (max_val - min_val) / num_bins
|
1707 |
+
bins: list[Bin] = []
|
1708 |
+
for i in range(num_bins):
|
1709 |
+
start = None if i == 0 else min_val + i * bin_width
|
1710 |
+
end = None if i == num_bins - 1 else min_val + (i + 1) * bin_width
|
1711 |
+
bins.append((str(i), start, end))
|
1712 |
+
return bins
|
1713 |
+
|
1714 |
+
|
1715 |
+
def _settings_filepath(namespace: str, dataset_name: str) -> str:
|
1716 |
+
return os.path.join(
|
1717 |
+
get_dataset_output_dir(data_path(), namespace, dataset_name), DATASET_SETTINGS_FILENAME)
|
lilac/data/dataset_test_utils.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Tests utils of for dataset_test."""
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
from datetime import datetime
|
5 |
+
from typing import Optional, Type, cast
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from typing_extensions import Protocol
|
9 |
+
|
10 |
+
from ..embeddings.vector_store import VectorDBIndex
|
11 |
+
from ..schema import (
|
12 |
+
MANIFEST_FILENAME,
|
13 |
+
PARQUET_FILENAME_PREFIX,
|
14 |
+
VALUE_KEY,
|
15 |
+
DataType,
|
16 |
+
Field,
|
17 |
+
Item,
|
18 |
+
PathKey,
|
19 |
+
Schema,
|
20 |
+
SourceManifest,
|
21 |
+
)
|
22 |
+
from ..utils import get_dataset_output_dir, open_file
|
23 |
+
from .dataset import Dataset
|
24 |
+
from .dataset_utils import is_primitive, write_items_to_parquet
|
25 |
+
|
26 |
+
TEST_NAMESPACE = 'test_namespace'
|
27 |
+
TEST_DATASET_NAME = 'test_dataset'
|
28 |
+
|
29 |
+
|
30 |
+
def _infer_dtype(value: Item) -> DataType:
|
31 |
+
if isinstance(value, str):
|
32 |
+
return DataType.STRING
|
33 |
+
elif isinstance(value, bool):
|
34 |
+
return DataType.BOOLEAN
|
35 |
+
elif isinstance(value, bytes):
|
36 |
+
return DataType.BINARY
|
37 |
+
elif isinstance(value, float):
|
38 |
+
return DataType.FLOAT32
|
39 |
+
elif isinstance(value, int):
|
40 |
+
return DataType.INT32
|
41 |
+
elif isinstance(value, datetime):
|
42 |
+
return DataType.TIMESTAMP
|
43 |
+
else:
|
44 |
+
raise ValueError(f'Cannot infer dtype of primitive value: {value}')
|
45 |
+
|
46 |
+
|
47 |
+
def _infer_field(item: Item) -> Field:
|
48 |
+
"""Infer the schema from the items."""
|
49 |
+
if isinstance(item, dict):
|
50 |
+
fields: dict[str, Field] = {}
|
51 |
+
for k, v in item.items():
|
52 |
+
fields[k] = _infer_field(cast(Item, v))
|
53 |
+
dtype = None
|
54 |
+
if VALUE_KEY in fields:
|
55 |
+
dtype = fields[VALUE_KEY].dtype
|
56 |
+
del fields[VALUE_KEY]
|
57 |
+
return Field(fields=fields, dtype=dtype)
|
58 |
+
elif is_primitive(item):
|
59 |
+
return Field(dtype=_infer_dtype(item))
|
60 |
+
elif isinstance(item, list):
|
61 |
+
return Field(repeated_field=_infer_field(item[0]))
|
62 |
+
else:
|
63 |
+
raise ValueError(f'Cannot infer schema of item: {item}')
|
64 |
+
|
65 |
+
|
66 |
+
def _infer_schema(items: list[Item]) -> Schema:
|
67 |
+
"""Infer the schema from the items."""
|
68 |
+
schema = Schema(fields={})
|
69 |
+
for item in items:
|
70 |
+
field = _infer_field(item)
|
71 |
+
if not field.fields:
|
72 |
+
raise ValueError(f'Invalid schema of item. Expected an object, but got: {item}')
|
73 |
+
schema.fields = {**schema.fields, **field.fields}
|
74 |
+
return schema
|
75 |
+
|
76 |
+
|
77 |
+
class TestDataMaker(Protocol):
|
78 |
+
"""A function that creates a test dataset."""
|
79 |
+
|
80 |
+
def __call__(self, items: list[Item], schema: Optional[Schema] = None) -> Dataset:
|
81 |
+
"""Create a test dataset."""
|
82 |
+
...
|
83 |
+
|
84 |
+
|
85 |
+
def make_dataset(dataset_cls: Type[Dataset],
|
86 |
+
tmp_path: pathlib.Path,
|
87 |
+
items: list[Item],
|
88 |
+
schema: Optional[Schema] = None) -> Dataset:
|
89 |
+
"""Create a test dataset."""
|
90 |
+
schema = schema or _infer_schema(items)
|
91 |
+
_write_items(tmp_path, TEST_DATASET_NAME, items, schema)
|
92 |
+
return dataset_cls(TEST_NAMESPACE, TEST_DATASET_NAME)
|
93 |
+
|
94 |
+
|
95 |
+
def _write_items(tmpdir: pathlib.Path, dataset_name: str, items: list[Item],
|
96 |
+
schema: Schema) -> None:
|
97 |
+
"""Write the items JSON to the dataset format: manifest.json and parquet files."""
|
98 |
+
source_dir = get_dataset_output_dir(str(tmpdir), TEST_NAMESPACE, dataset_name)
|
99 |
+
os.makedirs(source_dir)
|
100 |
+
|
101 |
+
simple_parquet_files, _ = write_items_to_parquet(
|
102 |
+
items, source_dir, schema, filename_prefix=PARQUET_FILENAME_PREFIX, shard_index=0, num_shards=1)
|
103 |
+
manifest = SourceManifest(files=[simple_parquet_files], data_schema=schema)
|
104 |
+
with open_file(os.path.join(source_dir, MANIFEST_FILENAME), 'w') as f:
|
105 |
+
f.write(manifest.json(indent=2, exclude_none=True))
|
106 |
+
|
107 |
+
|
108 |
+
def enriched_item(value: Optional[Item] = None, metadata: dict[str, Item] = {}) -> Item:
|
109 |
+
"""Wrap a value in a dict with the value key."""
|
110 |
+
return {VALUE_KEY: value, **metadata}
|
111 |
+
|
112 |
+
|
113 |
+
def make_vector_index(vector_store: str, vector_dict: dict[PathKey,
|
114 |
+
list[list[float]]]) -> VectorDBIndex:
|
115 |
+
"""Make a vector index from a dictionary of vector keys to vectors."""
|
116 |
+
embeddings: list[np.ndarray] = []
|
117 |
+
spans: list[tuple[PathKey, list[tuple[int, int]]]] = []
|
118 |
+
for path_key, vectors in vector_dict.items():
|
119 |
+
vector_spans: list[tuple[int, int]] = []
|
120 |
+
for i, vector in enumerate(vectors):
|
121 |
+
embeddings.append(np.array(vector))
|
122 |
+
vector_spans.append((0, 0))
|
123 |
+
spans.append((path_key, vector_spans))
|
124 |
+
|
125 |
+
vector_index = VectorDBIndex(vector_store)
|
126 |
+
vector_index.add(spans, np.array(embeddings))
|
127 |
+
return vector_index
|
lilac/data/dataset_utils.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utilities for working with datasets."""
|
2 |
+
|
3 |
+
import gc
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import pprint
|
8 |
+
import secrets
|
9 |
+
from collections.abc import Iterable
|
10 |
+
from typing import Any, Callable, Iterator, Optional, Sequence, TypeVar, Union, cast
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import pyarrow as pa
|
14 |
+
|
15 |
+
from ..batch_utils import deep_flatten
|
16 |
+
from ..embeddings.vector_store import VectorDBIndex
|
17 |
+
from ..env import env
|
18 |
+
from ..parquet_writer import ParquetWriter
|
19 |
+
from ..schema import (
|
20 |
+
EMBEDDING_KEY,
|
21 |
+
PATH_WILDCARD,
|
22 |
+
TEXT_SPAN_END_FEATURE,
|
23 |
+
TEXT_SPAN_START_FEATURE,
|
24 |
+
UUID_COLUMN,
|
25 |
+
VALUE_KEY,
|
26 |
+
Field,
|
27 |
+
Item,
|
28 |
+
PathKey,
|
29 |
+
PathTuple,
|
30 |
+
Schema,
|
31 |
+
VectorKey,
|
32 |
+
field,
|
33 |
+
schema,
|
34 |
+
schema_to_arrow_schema,
|
35 |
+
)
|
36 |
+
from ..signals.signal import Signal
|
37 |
+
from ..utils import is_primitive, log, open_file
|
38 |
+
|
39 |
+
|
40 |
+
def _replace_embeddings_with_none(input: Union[Item, Item]) -> Union[Item, Item]:
|
41 |
+
if isinstance(input, np.ndarray):
|
42 |
+
return None
|
43 |
+
if isinstance(input, dict):
|
44 |
+
return {k: _replace_embeddings_with_none(v) for k, v in input.items()}
|
45 |
+
if isinstance(input, list):
|
46 |
+
return [_replace_embeddings_with_none(v) for v in input]
|
47 |
+
|
48 |
+
return input
|
49 |
+
|
50 |
+
|
51 |
+
def replace_embeddings_with_none(input: Union[Item, Item]) -> Item:
|
52 |
+
"""Replaces all embeddings with None."""
|
53 |
+
return cast(Item, _replace_embeddings_with_none(input))
|
54 |
+
|
55 |
+
|
56 |
+
def count_primitives(input: Union[Iterable, Iterator]) -> int:
|
57 |
+
"""Iterate through each element of the input, flattening each one, computing a count.
|
58 |
+
|
59 |
+
Sum the final set of counts. This is the important iterable not to exhaust.
|
60 |
+
"""
|
61 |
+
return sum((len(list(deep_flatten(i))) for i in input))
|
62 |
+
|
63 |
+
|
64 |
+
def _wrap_value_in_dict(input: Union[object, dict], props: PathTuple) -> Union[object, dict]:
|
65 |
+
# If the signal produced no value, or nan, we should return None so the parquet value is sparse.
|
66 |
+
if isinstance(input, float) and math.isnan(input):
|
67 |
+
input = None
|
68 |
+
for prop in reversed(props):
|
69 |
+
input = {prop: input}
|
70 |
+
return input
|
71 |
+
|
72 |
+
|
73 |
+
def _wrap_in_dicts(input: Union[object, Iterable[object]],
|
74 |
+
spec: list[PathTuple]) -> Union[object, Iterable[object]]:
|
75 |
+
"""Wraps an object or iterable in a dict according to the spec."""
|
76 |
+
props = spec[0] if spec else tuple()
|
77 |
+
if len(spec) == 1:
|
78 |
+
return _wrap_value_in_dict(input, props)
|
79 |
+
if input is None or isinstance(input, float) and math.isnan(input):
|
80 |
+
# Return empty dict for missing inputs.
|
81 |
+
return {}
|
82 |
+
res = [_wrap_in_dicts(elem, spec[1:]) for elem in cast(Iterable, input)]
|
83 |
+
return _wrap_value_in_dict(res, props)
|
84 |
+
|
85 |
+
|
86 |
+
def wrap_in_dicts(input: Iterable[object], spec: list[PathTuple]) -> Iterable[object]:
|
87 |
+
"""Wraps an object or iterable in a dict according to the spec."""
|
88 |
+
return [_wrap_in_dicts(elem, spec) for elem in input]
|
89 |
+
|
90 |
+
|
91 |
+
def _merge_field_into(schema: Field, destination: Field) -> None:
|
92 |
+
if isinstance(schema, Field):
|
93 |
+
destination.signal = destination.signal or schema.signal
|
94 |
+
destination.dtype = destination.dtype or schema.dtype
|
95 |
+
if schema.fields:
|
96 |
+
destination.fields = destination.fields or {}
|
97 |
+
for field_name, subfield in schema.fields.items():
|
98 |
+
if field_name not in destination.fields:
|
99 |
+
destination.fields[field_name] = subfield.copy(deep=True)
|
100 |
+
else:
|
101 |
+
_merge_field_into(subfield, destination.fields[field_name])
|
102 |
+
elif schema.repeated_field:
|
103 |
+
if not destination.repeated_field:
|
104 |
+
raise ValueError('Failed to merge schemas. Origin schema is repeated, but destination is not')
|
105 |
+
_merge_field_into(schema.repeated_field, destination.repeated_field)
|
106 |
+
else:
|
107 |
+
if destination.dtype != schema.dtype:
|
108 |
+
raise ValueError(f'Failed to merge schemas. Origin schema has dtype {schema.dtype}, '
|
109 |
+
f'but destination has dtype {destination.dtype}')
|
110 |
+
|
111 |
+
|
112 |
+
def merge_schemas(schemas: Sequence[Union[Schema, Field]]) -> Schema:
|
113 |
+
"""Merge a list of schemas."""
|
114 |
+
merged_schema = Schema(fields={})
|
115 |
+
for s in schemas:
|
116 |
+
_merge_field_into(cast(Field, s), cast(Field, merged_schema))
|
117 |
+
return merged_schema
|
118 |
+
|
119 |
+
|
120 |
+
def schema_contains_path(schema: Schema, path: PathTuple) -> bool:
|
121 |
+
"""Check if a schema contains a path."""
|
122 |
+
current_field = cast(Field, schema)
|
123 |
+
for path_part in path:
|
124 |
+
if path_part == PATH_WILDCARD:
|
125 |
+
if current_field.repeated_field is None:
|
126 |
+
return False
|
127 |
+
current_field = current_field.repeated_field
|
128 |
+
else:
|
129 |
+
if current_field.fields is None or path_part not in current_field.fields:
|
130 |
+
return False
|
131 |
+
current_field = current_field.fields[str(path_part)]
|
132 |
+
return True
|
133 |
+
|
134 |
+
|
135 |
+
def create_signal_schema(signal: Signal, source_path: PathTuple, current_schema: Schema) -> Schema:
|
136 |
+
"""Create a schema describing the enriched fields added an enrichment."""
|
137 |
+
leafs = current_schema.leafs
|
138 |
+
# Validate that the enrich fields are actually a valid leaf path.
|
139 |
+
if source_path not in leafs:
|
140 |
+
raise ValueError(f'"{source_path}" is not a valid leaf path. Leaf paths: {leafs.keys()}')
|
141 |
+
|
142 |
+
signal_schema = signal.fields()
|
143 |
+
signal_schema.signal = signal.dict()
|
144 |
+
|
145 |
+
enriched_schema = field(fields={signal.key(is_computed_signal=True): signal_schema})
|
146 |
+
|
147 |
+
for path_part in reversed(source_path):
|
148 |
+
if path_part == PATH_WILDCARD:
|
149 |
+
enriched_schema = Field(repeated_field=enriched_schema)
|
150 |
+
else:
|
151 |
+
enriched_schema = Field(fields={path_part: enriched_schema})
|
152 |
+
|
153 |
+
if not enriched_schema.fields:
|
154 |
+
raise ValueError('This should not happen since enriched_schema always has fields (see above)')
|
155 |
+
|
156 |
+
return schema({UUID_COLUMN: 'string', **cast(dict, enriched_schema.fields)})
|
157 |
+
|
158 |
+
|
159 |
+
def write_embeddings_to_disk(vector_store: str, uuids: Iterable[str], signal_items: Iterable[Item],
|
160 |
+
output_dir: str) -> None:
|
161 |
+
"""Write a set of embeddings to disk."""
|
162 |
+
|
163 |
+
def embedding_predicate(input: Any) -> bool:
|
164 |
+
return (isinstance(input, list) and len(input) > 0 and isinstance(input[0], dict) and
|
165 |
+
EMBEDDING_KEY in input[0])
|
166 |
+
|
167 |
+
path_keys = flatten_keys(uuids, signal_items, is_primitive_predicate=embedding_predicate)
|
168 |
+
all_embeddings = cast(Iterable[Item],
|
169 |
+
deep_flatten(signal_items, is_primitive_predicate=embedding_predicate))
|
170 |
+
|
171 |
+
embedding_vectors: list[np.ndarray] = []
|
172 |
+
all_spans: list[tuple[PathKey, list[tuple[int, int]]]] = []
|
173 |
+
for path_key, embeddings in zip(path_keys, all_embeddings):
|
174 |
+
if not path_key or not embeddings:
|
175 |
+
# Sparse embeddings may not have an embedding for every key.
|
176 |
+
continue
|
177 |
+
|
178 |
+
spans: list[tuple[int, int]] = []
|
179 |
+
for e in embeddings:
|
180 |
+
span = e[VALUE_KEY]
|
181 |
+
vector = e[EMBEDDING_KEY]
|
182 |
+
# We squeeze here because embedding functions can return outer dimensions of 1.
|
183 |
+
embedding_vectors.append(vector.reshape(-1))
|
184 |
+
spans.append((span[TEXT_SPAN_START_FEATURE], span[TEXT_SPAN_END_FEATURE]))
|
185 |
+
all_spans.append((path_key, spans))
|
186 |
+
embedding_matrix = np.array(embedding_vectors)
|
187 |
+
del path_keys, all_embeddings, embedding_vectors
|
188 |
+
gc.collect()
|
189 |
+
|
190 |
+
# Write to disk.
|
191 |
+
vector_index = VectorDBIndex(vector_store)
|
192 |
+
vector_index.add(all_spans, embedding_matrix)
|
193 |
+
vector_index.save(output_dir)
|
194 |
+
|
195 |
+
del vector_index
|
196 |
+
gc.collect()
|
197 |
+
|
198 |
+
|
199 |
+
def write_items_to_parquet(items: Iterable[Item], output_dir: str, schema: Schema,
|
200 |
+
filename_prefix: str, shard_index: int,
|
201 |
+
num_shards: int) -> tuple[str, int]:
|
202 |
+
"""Write a set of items to a parquet file, in columnar format."""
|
203 |
+
arrow_schema = schema_to_arrow_schema(schema)
|
204 |
+
out_filename = parquet_filename(filename_prefix, shard_index, num_shards)
|
205 |
+
filepath = os.path.join(output_dir, out_filename)
|
206 |
+
f = open_file(filepath, mode='wb')
|
207 |
+
writer = ParquetWriter(schema)
|
208 |
+
writer.open(f)
|
209 |
+
debug = env('DEBUG', False)
|
210 |
+
num_items = 0
|
211 |
+
for item in items:
|
212 |
+
# Add a UUID column.
|
213 |
+
if UUID_COLUMN not in item:
|
214 |
+
item[UUID_COLUMN] = secrets.token_urlsafe(nbytes=12) # 16 base64 characters.
|
215 |
+
if debug:
|
216 |
+
try:
|
217 |
+
_validate(item, arrow_schema)
|
218 |
+
except Exception as e:
|
219 |
+
raise ValueError(f'Error validating item: {json.dumps(item)}') from e
|
220 |
+
writer.write(item)
|
221 |
+
num_items += 1
|
222 |
+
writer.close()
|
223 |
+
f.close()
|
224 |
+
return out_filename, num_items
|
225 |
+
|
226 |
+
|
227 |
+
def _validate(item: Item, schema: pa.Schema) -> None:
|
228 |
+
# Try to parse the item using the inferred schema.
|
229 |
+
try:
|
230 |
+
pa.RecordBatch.from_pylist([item], schema=schema)
|
231 |
+
except pa.ArrowTypeError:
|
232 |
+
log('Failed to parse arrow item using the arrow schema.')
|
233 |
+
log('Item:')
|
234 |
+
log(pprint.pformat(item, indent=2))
|
235 |
+
log('Arrow schema:')
|
236 |
+
log(schema)
|
237 |
+
raise # Re-raise the same exception, same stacktrace.
|
238 |
+
|
239 |
+
|
240 |
+
def parquet_filename(prefix: str, shard_index: int, num_shards: int) -> str:
|
241 |
+
"""Return the filename for a parquet file."""
|
242 |
+
return f'{prefix}-{shard_index:05d}-of-{num_shards:05d}.parquet'
|
243 |
+
|
244 |
+
|
245 |
+
def _flatten_keys(uuid: str, nested_input: Iterable, location: list[int],
|
246 |
+
is_primitive_predicate: Callable[[object], bool]) -> Iterator[VectorKey]:
|
247 |
+
if is_primitive_predicate(nested_input) or is_primitive(nested_input) or isinstance(
|
248 |
+
nested_input, dict):
|
249 |
+
yield (uuid, *location)
|
250 |
+
return
|
251 |
+
|
252 |
+
for i, input in enumerate(nested_input):
|
253 |
+
yield from _flatten_keys(uuid, input, [*location, i], is_primitive_predicate)
|
254 |
+
|
255 |
+
|
256 |
+
def flatten_keys(
|
257 |
+
uuids: Iterable[str],
|
258 |
+
nested_input: Iterable,
|
259 |
+
is_primitive_predicate: Callable[[object],
|
260 |
+
bool] = is_primitive) -> Iterator[Optional[VectorKey]]:
|
261 |
+
"""Flatten the uuid keys of a nested input."""
|
262 |
+
for uuid, input in zip(uuids, nested_input):
|
263 |
+
if input is None:
|
264 |
+
yield None
|
265 |
+
continue
|
266 |
+
yield from _flatten_keys(uuid, input, [], is_primitive_predicate)
|
267 |
+
|
268 |
+
|
269 |
+
Tin = TypeVar('Tin')
|
270 |
+
Tout = TypeVar('Tout')
|
271 |
+
|
272 |
+
|
273 |
+
def sparse_to_dense_compute(
|
274 |
+
sparse_input: Iterator[Optional[Tin]],
|
275 |
+
func: Callable[[Iterable[Tin]], Iterable[Tout]]) -> Iterator[Optional[Tout]]:
|
276 |
+
"""Densifies the input before calling the provided `func` and sparsifies the output."""
|
277 |
+
locations: list[int] = []
|
278 |
+
total_size: int = 0
|
279 |
+
|
280 |
+
def densify(x: Iterator[Optional[Tin]]) -> Iterator[Tin]:
|
281 |
+
nonlocal locations, total_size
|
282 |
+
for i, value in enumerate(x):
|
283 |
+
total_size += 1
|
284 |
+
if value is not None:
|
285 |
+
locations.append(i)
|
286 |
+
yield value
|
287 |
+
|
288 |
+
dense_input = densify(sparse_input)
|
289 |
+
dense_output = iter(func(dense_input))
|
290 |
+
index = 0
|
291 |
+
|
292 |
+
location_index = 0
|
293 |
+
|
294 |
+
while True:
|
295 |
+
try:
|
296 |
+
out = next(dense_output)
|
297 |
+
out_index = locations[location_index]
|
298 |
+
while index < out_index:
|
299 |
+
yield None
|
300 |
+
index += 1
|
301 |
+
yield out
|
302 |
+
location_index += 1
|
303 |
+
index += 1
|
304 |
+
except StopIteration:
|
305 |
+
while index < total_size:
|
306 |
+
yield None
|
307 |
+
index += 1
|
308 |
+
return
|
lilac/data/duckdb_utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utils for duckdb."""
|
2 |
+
import os
|
3 |
+
|
4 |
+
import duckdb
|
5 |
+
|
6 |
+
from ..env import data_path, env
|
7 |
+
|
8 |
+
|
9 |
+
def duckdb_setup(con: duckdb.DuckDBPyConnection) -> str:
|
10 |
+
"""Setup DuckDB. This includes setting up the extensions directory and GCS access."""
|
11 |
+
con.execute(f"""
|
12 |
+
SET extension_directory='{os.path.join(data_path(), '.duckdb')}';
|
13 |
+
""")
|
14 |
+
|
15 |
+
con.install_extension('httpfs')
|
16 |
+
con.load_extension('httpfs')
|
17 |
+
|
18 |
+
if env('GCS_REGION'):
|
19 |
+
return f"""
|
20 |
+
SET s3_region='{env('GCS_REGION')}';
|
21 |
+
SET s3_access_key_id='{env('GCS_ACCESS_KEY')}';
|
22 |
+
SET s3_secret_access_key='{env('GCS_SECRET_KEY')}';
|
23 |
+
SET s3_endpoint='storage.googleapis.com';
|
24 |
+
"""
|
25 |
+
return ''
|
lilac/data_loader.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A data loader standalone binary. This should only be run as a script to load a dataset.
|
2 |
+
|
3 |
+
To run the source loader as a binary directly:
|
4 |
+
|
5 |
+
poetry run python -m lilac.data_loader \
|
6 |
+
--dataset_name=movies_dataset \
|
7 |
+
--output_dir=./data/ \
|
8 |
+
--config_path=./datasets/the_movies_dataset.json
|
9 |
+
"""
|
10 |
+
import os
|
11 |
+
import pathlib
|
12 |
+
import uuid
|
13 |
+
from typing import Iterable, Optional, Union
|
14 |
+
|
15 |
+
import pandas as pd
|
16 |
+
|
17 |
+
from .data.dataset import Dataset
|
18 |
+
from .data.dataset_utils import write_items_to_parquet
|
19 |
+
from .db_manager import get_dataset
|
20 |
+
from .env import data_path
|
21 |
+
from .schema import (
|
22 |
+
MANIFEST_FILENAME,
|
23 |
+
PARQUET_FILENAME_PREFIX,
|
24 |
+
UUID_COLUMN,
|
25 |
+
Field,
|
26 |
+
Item,
|
27 |
+
Schema,
|
28 |
+
SourceManifest,
|
29 |
+
field,
|
30 |
+
is_float,
|
31 |
+
)
|
32 |
+
from .sources.source import Source
|
33 |
+
from .tasks import TaskStepId, progress
|
34 |
+
from .utils import get_dataset_output_dir, log, open_file
|
35 |
+
|
36 |
+
|
37 |
+
def create_dataset(
|
38 |
+
namespace: str,
|
39 |
+
dataset_name: str,
|
40 |
+
source_config: Source,
|
41 |
+
) -> Dataset:
|
42 |
+
"""Load a dataset from a given source configuration."""
|
43 |
+
process_source(data_path(), namespace, dataset_name, source_config)
|
44 |
+
return get_dataset(namespace, dataset_name)
|
45 |
+
|
46 |
+
|
47 |
+
def process_source(base_dir: Union[str, pathlib.Path],
|
48 |
+
namespace: str,
|
49 |
+
dataset_name: str,
|
50 |
+
source: Source,
|
51 |
+
task_step_id: Optional[TaskStepId] = None) -> tuple[str, int]:
|
52 |
+
"""Process a source."""
|
53 |
+
output_dir = get_dataset_output_dir(base_dir, namespace, dataset_name)
|
54 |
+
|
55 |
+
source.setup()
|
56 |
+
source_schema = source.source_schema()
|
57 |
+
items = source.process()
|
58 |
+
|
59 |
+
# Add UUIDs and fix NaN in string columns.
|
60 |
+
items = normalize_items(items, source_schema.fields)
|
61 |
+
|
62 |
+
# Add progress.
|
63 |
+
items = progress(
|
64 |
+
items,
|
65 |
+
task_step_id=task_step_id,
|
66 |
+
estimated_len=source_schema.num_items,
|
67 |
+
step_description=f'Reading from source {source.name}...')
|
68 |
+
|
69 |
+
# Filter out the `None`s after progress.
|
70 |
+
items = (item for item in items if item is not None)
|
71 |
+
|
72 |
+
data_schema = Schema(fields={**source_schema.fields, UUID_COLUMN: field('string')})
|
73 |
+
filepath, num_items = write_items_to_parquet(
|
74 |
+
items=items,
|
75 |
+
output_dir=output_dir,
|
76 |
+
schema=data_schema,
|
77 |
+
filename_prefix=PARQUET_FILENAME_PREFIX,
|
78 |
+
shard_index=0,
|
79 |
+
num_shards=1)
|
80 |
+
|
81 |
+
filenames = [os.path.basename(filepath)]
|
82 |
+
manifest = SourceManifest(files=filenames, data_schema=data_schema, images=None)
|
83 |
+
with open_file(os.path.join(output_dir, MANIFEST_FILENAME), 'w') as f:
|
84 |
+
f.write(manifest.json(indent=2, exclude_none=True))
|
85 |
+
log(f'Dataset "{dataset_name}" written to {output_dir}')
|
86 |
+
|
87 |
+
return output_dir, num_items
|
88 |
+
|
89 |
+
|
90 |
+
def normalize_items(items: Iterable[Item], fields: dict[str, Field]) -> Item:
|
91 |
+
"""Sanitize items by removing NaNs and NaTs."""
|
92 |
+
replace_nan_fields = [
|
93 |
+
field_name for field_name, field in fields.items() if field.dtype and not is_float(field.dtype)
|
94 |
+
]
|
95 |
+
for item in items:
|
96 |
+
if item is None:
|
97 |
+
yield item
|
98 |
+
continue
|
99 |
+
|
100 |
+
# Add row uuid if it doesn't exist.
|
101 |
+
if UUID_COLUMN not in item:
|
102 |
+
item[UUID_COLUMN] = uuid.uuid4().hex
|
103 |
+
|
104 |
+
# Fix NaN values.
|
105 |
+
for field_name in replace_nan_fields:
|
106 |
+
item_value = item.get(field_name)
|
107 |
+
if item_value and pd.isna(item_value):
|
108 |
+
item[field_name] = None
|
109 |
+
|
110 |
+
yield item
|
lilac/db_manager.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Manages mapping the dataset name to the database instance."""
|
2 |
+
import os
|
3 |
+
import threading
|
4 |
+
from typing import Type
|
5 |
+
|
6 |
+
from .data.dataset import Dataset
|
7 |
+
|
8 |
+
_DEFAULT_DATASET_CLS: Type[Dataset]
|
9 |
+
|
10 |
+
_CACHED_DATASETS: dict[str, Dataset] = {}
|
11 |
+
|
12 |
+
_db_lock = threading.Lock()
|
13 |
+
|
14 |
+
|
15 |
+
def get_dataset(namespace: str, dataset_name: str) -> Dataset:
|
16 |
+
"""Get the dataset instance."""
|
17 |
+
if not _DEFAULT_DATASET_CLS:
|
18 |
+
raise ValueError('Default dataset class not set.')
|
19 |
+
cache_key = f'{namespace}/{dataset_name}'
|
20 |
+
# https://docs.pytest.org/en/latest/example/simple.html#pytest-current-test-environment-variable
|
21 |
+
inside_test = 'PYTEST_CURRENT_TEST' in os.environ
|
22 |
+
with _db_lock:
|
23 |
+
if cache_key not in _CACHED_DATASETS or inside_test:
|
24 |
+
_CACHED_DATASETS[cache_key] = _DEFAULT_DATASET_CLS(
|
25 |
+
namespace=namespace, dataset_name=dataset_name)
|
26 |
+
return _CACHED_DATASETS[cache_key]
|
27 |
+
|
28 |
+
|
29 |
+
def remove_dataset_from_cache(namespace: str, dataset_name: str) -> None:
|
30 |
+
"""Remove the dataset from the db manager cache."""
|
31 |
+
cache_key = f'{namespace}/{dataset_name}'
|
32 |
+
with _db_lock:
|
33 |
+
if cache_key in _CACHED_DATASETS:
|
34 |
+
del _CACHED_DATASETS[cache_key]
|
35 |
+
|
36 |
+
|
37 |
+
# TODO(nsthorat): Make this a registry once we have multiple dataset implementations. This breaks a
|
38 |
+
# circular dependency.
|
39 |
+
def set_default_dataset_cls(dataset_cls: Type[Dataset]) -> None:
|
40 |
+
"""Set the default dataset class."""
|
41 |
+
global _DEFAULT_DATASET_CLS
|
42 |
+
_DEFAULT_DATASET_CLS = dataset_cls
|
lilac/embeddings/__init__.py
ADDED
File without changes
|
lilac/embeddings/cohere.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Cohere embeddings."""
|
2 |
+
from typing import TYPE_CHECKING, Iterable, cast
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from typing_extensions import override
|
6 |
+
|
7 |
+
from ..env import env
|
8 |
+
from ..schema import Item, RichData
|
9 |
+
from ..signals.signal import TextEmbeddingSignal
|
10 |
+
from ..signals.splitters.chunk_splitter import split_text
|
11 |
+
from .embedding import compute_split_embeddings
|
12 |
+
|
13 |
+
if TYPE_CHECKING:
|
14 |
+
from cohere import Client
|
15 |
+
|
16 |
+
NUM_PARALLEL_REQUESTS = 10
|
17 |
+
COHERE_BATCH_SIZE = 96
|
18 |
+
|
19 |
+
|
20 |
+
class Cohere(TextEmbeddingSignal):
|
21 |
+
"""Computes embeddings using Cohere's embedding API.
|
22 |
+
|
23 |
+
<br>**Important**: This will send data to an external server!
|
24 |
+
|
25 |
+
<br>To use this signal, you must get a Cohere API key from
|
26 |
+
[cohere.com/embed](https://cohere.com/embed) and add it to your .env.local.
|
27 |
+
|
28 |
+
<br>For details on pricing, see: https://cohere.com/pricing.
|
29 |
+
"""
|
30 |
+
|
31 |
+
name = 'cohere'
|
32 |
+
display_name = 'Cohere Embeddings'
|
33 |
+
|
34 |
+
_model: 'Client'
|
35 |
+
|
36 |
+
@override
|
37 |
+
def setup(self) -> None:
|
38 |
+
"""Validate that the api key and python package exists in environment."""
|
39 |
+
api_key = env('COHERE_API_KEY')
|
40 |
+
if not api_key:
|
41 |
+
raise ValueError('`COHERE_API_KEY` environment variable not set.')
|
42 |
+
try:
|
43 |
+
import cohere
|
44 |
+
self._model = cohere.Client(api_key, max_retries=10)
|
45 |
+
except ImportError:
|
46 |
+
raise ImportError('Could not import the "cohere" python package. '
|
47 |
+
'Please install it with `pip install cohere`.')
|
48 |
+
|
49 |
+
@override
|
50 |
+
def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
|
51 |
+
"""Compute embeddings for the given documents."""
|
52 |
+
|
53 |
+
def embed_fn(texts: list[str]) -> list[np.ndarray]:
|
54 |
+
return self._model.embed(texts, truncate='END').embeddings
|
55 |
+
|
56 |
+
docs = cast(Iterable[str], docs)
|
57 |
+
split_fn = split_text if self._split else None
|
58 |
+
yield from compute_split_embeddings(
|
59 |
+
docs, COHERE_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)
|
lilac/embeddings/default_vector_stores.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Registers all vector stores."""
|
2 |
+
from .vector_store import register_vector_store
|
3 |
+
from .vector_store_hnsw import HNSWVectorStore
|
4 |
+
from .vector_store_numpy import NumpyVectorStore
|
5 |
+
|
6 |
+
|
7 |
+
def register_default_vector_stores() -> None:
|
8 |
+
"""Register all the default vector stores."""
|
9 |
+
register_vector_store(HNSWVectorStore)
|
10 |
+
register_vector_store(NumpyVectorStore)
|
lilac/embeddings/embedding.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Embedding registry."""
|
2 |
+
from concurrent.futures import ThreadPoolExecutor
|
3 |
+
from typing import Callable, Generator, Iterable, Iterator, Optional, Union, cast
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from pydantic import StrictStr
|
7 |
+
from sklearn.preprocessing import normalize
|
8 |
+
|
9 |
+
from ..schema import (
|
10 |
+
EMBEDDING_KEY,
|
11 |
+
TEXT_SPAN_END_FEATURE,
|
12 |
+
TEXT_SPAN_START_FEATURE,
|
13 |
+
VALUE_KEY,
|
14 |
+
Item,
|
15 |
+
RichData,
|
16 |
+
SpanVector,
|
17 |
+
lilac_embedding,
|
18 |
+
)
|
19 |
+
from ..signals.signal import TextEmbeddingSignal, get_signal_by_type
|
20 |
+
from ..signals.splitters.chunk_splitter import TextChunk
|
21 |
+
from ..utils import chunks
|
22 |
+
|
23 |
+
EmbeddingId = Union[StrictStr, TextEmbeddingSignal]
|
24 |
+
|
25 |
+
EmbedFn = Callable[[Iterable[RichData]], Iterator[list[SpanVector]]]
|
26 |
+
|
27 |
+
|
28 |
+
def get_embed_fn(embedding_name: str, split: bool) -> EmbedFn:
|
29 |
+
"""Return a function that returns the embedding matrix for the given embedding signal."""
|
30 |
+
embedding_cls = get_signal_by_type(embedding_name, TextEmbeddingSignal)
|
31 |
+
embedding = embedding_cls(split=split)
|
32 |
+
embedding.setup()
|
33 |
+
|
34 |
+
def _embed_fn(data: Iterable[RichData]) -> Iterator[list[SpanVector]]:
|
35 |
+
items = embedding.compute(data)
|
36 |
+
|
37 |
+
for item in items:
|
38 |
+
if not item:
|
39 |
+
raise ValueError('Embedding signal returned None.')
|
40 |
+
|
41 |
+
yield [{
|
42 |
+
'vector': item_val[EMBEDDING_KEY].reshape(-1),
|
43 |
+
'span':
|
44 |
+
(item_val[VALUE_KEY][TEXT_SPAN_START_FEATURE], item_val[VALUE_KEY][TEXT_SPAN_END_FEATURE])
|
45 |
+
} for item_val in item]
|
46 |
+
|
47 |
+
return _embed_fn
|
48 |
+
|
49 |
+
|
50 |
+
def compute_split_embeddings(docs: Iterable[str],
|
51 |
+
batch_size: int,
|
52 |
+
embed_fn: Callable[[list[str]], list[np.ndarray]],
|
53 |
+
split_fn: Optional[Callable[[str], list[TextChunk]]] = None,
|
54 |
+
num_parallel_requests: int = 1) -> Generator[Item, None, None]:
|
55 |
+
"""Compute text embeddings in batches of chunks, using the provided splitter and embedding fn."""
|
56 |
+
pool = ThreadPoolExecutor()
|
57 |
+
|
58 |
+
def _splitter(doc: str) -> list[TextChunk]:
|
59 |
+
if not doc:
|
60 |
+
return []
|
61 |
+
if split_fn:
|
62 |
+
return split_fn(doc)
|
63 |
+
else:
|
64 |
+
# Return a single chunk that spans the entire document.
|
65 |
+
return [(doc, (0, len(doc)))]
|
66 |
+
|
67 |
+
num_docs = 0
|
68 |
+
|
69 |
+
def _flat_split_batch_docs(docs: Iterable[str]) -> Generator[tuple[int, TextChunk], None, None]:
|
70 |
+
"""Split a batch of documents into chunks and yield them."""
|
71 |
+
nonlocal num_docs
|
72 |
+
for i, doc in enumerate(docs):
|
73 |
+
num_docs += 1
|
74 |
+
chunks = _splitter(doc)
|
75 |
+
for chunk in chunks:
|
76 |
+
yield (i, chunk)
|
77 |
+
|
78 |
+
doc_chunks = _flat_split_batch_docs(docs)
|
79 |
+
items_to_yield: Optional[list[Item]] = None
|
80 |
+
current_index = 0
|
81 |
+
|
82 |
+
mega_batch_size = batch_size * num_parallel_requests
|
83 |
+
|
84 |
+
for batch in chunks(doc_chunks, mega_batch_size):
|
85 |
+
texts = [text for _, (text, _) in batch]
|
86 |
+
embeddings: list[np.ndarray] = []
|
87 |
+
|
88 |
+
for x in list(pool.map(lambda x: embed_fn(x), chunks(texts, batch_size))):
|
89 |
+
embeddings.extend(x)
|
90 |
+
matrix = cast(np.ndarray, normalize(np.array(embeddings, dtype=np.float32)))
|
91 |
+
# np.split returns a shallow copy of each embedding so we don't increase the mem footprint.
|
92 |
+
embeddings_batch = cast(list[np.ndarray], np.split(matrix, matrix.shape[0]))
|
93 |
+
for (index, (_, (start, end))), embedding in zip(batch, embeddings_batch):
|
94 |
+
embedding = embedding.reshape(-1)
|
95 |
+
if index == current_index:
|
96 |
+
if items_to_yield is None:
|
97 |
+
items_to_yield = []
|
98 |
+
items_to_yield.append(lilac_embedding(start, end, embedding))
|
99 |
+
else:
|
100 |
+
yield items_to_yield
|
101 |
+
current_index += 1
|
102 |
+
while current_index < index:
|
103 |
+
yield None
|
104 |
+
current_index += 1
|
105 |
+
items_to_yield = [lilac_embedding(start, end, embedding)]
|
106 |
+
|
107 |
+
while current_index < num_docs:
|
108 |
+
yield items_to_yield
|
109 |
+
items_to_yield = None
|
110 |
+
current_index += 1
|
lilac/embeddings/gte.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Gegeral Text Embeddings (GTE) model. Open-source model, designed to run on device."""
|
2 |
+
from typing import TYPE_CHECKING, Iterable, cast
|
3 |
+
|
4 |
+
from typing_extensions import override
|
5 |
+
|
6 |
+
from ..schema import Item, RichData
|
7 |
+
from ..signals.signal import TextEmbeddingSignal
|
8 |
+
from ..signals.splitters.chunk_splitter import split_text
|
9 |
+
from .embedding import compute_split_embeddings
|
10 |
+
from .transformer_utils import get_model
|
11 |
+
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
pass
|
14 |
+
|
15 |
+
# See https://huggingface.co/spaces/mteb/leaderboard for leaderboard of models.
|
16 |
+
GTE_SMALL = 'thenlper/gte-small'
|
17 |
+
GTE_BASE = 'thenlper/gte-base'
|
18 |
+
|
19 |
+
# Maps a tuple of model name and device to the optimal batch size, found empirically.
|
20 |
+
_OPTIMAL_BATCH_SIZES: dict[str, dict[str, int]] = {
|
21 |
+
GTE_SMALL: {
|
22 |
+
'': 64, # Default batch size.
|
23 |
+
'mps': 256,
|
24 |
+
},
|
25 |
+
GTE_BASE: {
|
26 |
+
'': 64, # Default batch size.
|
27 |
+
'mps': 128,
|
28 |
+
}
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
class GTESmall(TextEmbeddingSignal):
|
33 |
+
"""Computes Gegeral Text Embeddings (GTE).
|
34 |
+
|
35 |
+
<br>This embedding runs on-device. See the [model card](https://huggingface.co/thenlper/gte-small)
|
36 |
+
for details.
|
37 |
+
"""
|
38 |
+
|
39 |
+
name = 'gte-small'
|
40 |
+
display_name = 'Gegeral Text Embeddings (small)'
|
41 |
+
|
42 |
+
_model_name = GTE_SMALL
|
43 |
+
|
44 |
+
@override
|
45 |
+
def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
|
46 |
+
"""Call the embedding function."""
|
47 |
+
batch_size, model = get_model(self._model_name, _OPTIMAL_BATCH_SIZES[self._model_name])
|
48 |
+
embed_fn = model.encode
|
49 |
+
split_fn = split_text if self._split else None
|
50 |
+
docs = cast(Iterable[str], docs)
|
51 |
+
yield from compute_split_embeddings(docs, batch_size, embed_fn=embed_fn, split_fn=split_fn)
|
52 |
+
|
53 |
+
|
54 |
+
class GTEBase(GTESmall):
|
55 |
+
"""Computes Gegeral Text Embeddings (GTE).
|
56 |
+
|
57 |
+
<br>This embedding runs on-device. See the [model card](https://huggingface.co/thenlper/gte-base)
|
58 |
+
for details.
|
59 |
+
"""
|
60 |
+
name = 'gte-base'
|
61 |
+
display_name = 'Gegeral Text Embeddings (base)'
|
62 |
+
|
63 |
+
_model_name = GTE_BASE
|
lilac/embeddings/openai.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""OpenAI embeddings."""
|
2 |
+
from typing import TYPE_CHECKING, Any, Iterable, cast
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
6 |
+
from typing_extensions import override
|
7 |
+
|
8 |
+
from ..env import env
|
9 |
+
from ..schema import Item, RichData
|
10 |
+
from ..signals.signal import TextEmbeddingSignal
|
11 |
+
from ..signals.splitters.chunk_splitter import split_text
|
12 |
+
from .embedding import compute_split_embeddings
|
13 |
+
|
14 |
+
if TYPE_CHECKING:
|
15 |
+
import openai
|
16 |
+
|
17 |
+
NUM_PARALLEL_REQUESTS = 10
|
18 |
+
OPENAI_BATCH_SIZE = 128
|
19 |
+
EMBEDDING_MODEL = 'text-embedding-ada-002'
|
20 |
+
|
21 |
+
|
22 |
+
class OpenAI(TextEmbeddingSignal):
|
23 |
+
"""Computes embeddings using OpenAI's embedding API.
|
24 |
+
|
25 |
+
<br>**Important**: This will send data to an external server!
|
26 |
+
|
27 |
+
<br>To use this signal, you must get an OpenAI API key from
|
28 |
+
[platform.openai.com](https://platform.openai.com/) and add it to your .env.local.
|
29 |
+
|
30 |
+
<br>For details on pricing, see: https://openai.com/pricing.
|
31 |
+
"""
|
32 |
+
|
33 |
+
name = 'openai'
|
34 |
+
display_name = 'OpenAI Embeddings'
|
35 |
+
|
36 |
+
_model: type['openai.Embedding']
|
37 |
+
|
38 |
+
@override
|
39 |
+
def setup(self) -> None:
|
40 |
+
api_key = env('OPENAI_API_KEY')
|
41 |
+
if not api_key:
|
42 |
+
raise ValueError('`OPENAI_API_KEY` environment variable not set.')
|
43 |
+
try:
|
44 |
+
import openai
|
45 |
+
openai.api_key = api_key
|
46 |
+
self._model = openai.Embedding
|
47 |
+
except ImportError:
|
48 |
+
raise ImportError('Could not import the "openai" python package. '
|
49 |
+
'Please install it with `pip install openai`.')
|
50 |
+
|
51 |
+
@override
|
52 |
+
def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
|
53 |
+
"""Compute embeddings for the given documents."""
|
54 |
+
|
55 |
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(10))
|
56 |
+
def embed_fn(texts: list[str]) -> list[np.ndarray]:
|
57 |
+
|
58 |
+
# Replace newlines, which can negatively affect performance.
|
59 |
+
# See https://github.com/search?q=repo%3Aopenai%2Fopenai-python+replace+newlines&type=code
|
60 |
+
texts = [text.replace('\n', ' ') for text in texts]
|
61 |
+
|
62 |
+
response: Any = self._model.create(input=texts, model=EMBEDDING_MODEL)
|
63 |
+
return [np.array(embedding['embedding'], dtype=np.float32) for embedding in response['data']]
|
64 |
+
|
65 |
+
docs = cast(Iterable[str], docs)
|
66 |
+
split_fn = split_text if self._split else None
|
67 |
+
yield from compute_split_embeddings(
|
68 |
+
docs, OPENAI_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)
|
lilac/embeddings/palm.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""PaLM embeddings."""
|
2 |
+
from typing import TYPE_CHECKING, Iterable, cast
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
6 |
+
from typing_extensions import override
|
7 |
+
|
8 |
+
from ..env import env
|
9 |
+
from ..schema import Item, RichData
|
10 |
+
from ..signals.signal import TextEmbeddingSignal
|
11 |
+
from ..signals.splitters.chunk_splitter import split_text
|
12 |
+
from .embedding import compute_split_embeddings
|
13 |
+
|
14 |
+
if TYPE_CHECKING:
|
15 |
+
import google.generativeai as palm
|
16 |
+
|
17 |
+
PALM_BATCH_SIZE = 1 # PaLM API only supports batch size 1.
|
18 |
+
NUM_PARALLEL_REQUESTS = 256 # Because batch size is 1, we can send many requests in parallel.
|
19 |
+
EMBEDDING_MODEL = 'models/embedding-gecko-001'
|
20 |
+
|
21 |
+
|
22 |
+
class PaLM(TextEmbeddingSignal):
|
23 |
+
"""Computes embeddings using PaLM's embedding API.
|
24 |
+
|
25 |
+
<br>**Important**: This will send data to an external server!
|
26 |
+
|
27 |
+
<br>To use this signal, you must get a PaLM API key from
|
28 |
+
[makersuite.google.com](https://makersuite.google.com/app/apikey) and add it to your .env.local.
|
29 |
+
"""
|
30 |
+
|
31 |
+
name = 'palm'
|
32 |
+
display_name = 'PaLM Embeddings'
|
33 |
+
|
34 |
+
_model: 'palm.generate_embeddings'
|
35 |
+
|
36 |
+
@override
|
37 |
+
def setup(self) -> None:
|
38 |
+
api_key = env('PALM_API_KEY')
|
39 |
+
if not api_key:
|
40 |
+
raise ValueError('`PALM_API_KEY` environment variable not set.')
|
41 |
+
try:
|
42 |
+
import google.generativeai as palm
|
43 |
+
palm.configure(api_key=api_key)
|
44 |
+
self._model = palm.generate_embeddings
|
45 |
+
except ImportError:
|
46 |
+
raise ImportError('Could not import the "google.generativeai" python package. '
|
47 |
+
'Please install it with `pip install google-generativeai`.')
|
48 |
+
|
49 |
+
@override
|
50 |
+
def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
|
51 |
+
"""Compute embeddings for the given documents."""
|
52 |
+
|
53 |
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(10))
|
54 |
+
def embed_fn(texts: list[str]) -> list[np.ndarray]:
|
55 |
+
assert len(texts) == 1, 'PaLM API only supports batch size 1.'
|
56 |
+
response = self._model(model=EMBEDDING_MODEL, text=texts[0])
|
57 |
+
return [np.array(response['embedding'], dtype=np.float32)]
|
58 |
+
|
59 |
+
docs = cast(Iterable[str], docs)
|
60 |
+
split_fn = split_text if self._split else None
|
61 |
+
yield from compute_split_embeddings(
|
62 |
+
docs, PALM_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)
|
lilac/embeddings/sbert.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Sentence-BERT embeddings. Open-source models, designed to run on device."""
|
2 |
+
from typing import Iterable, cast
|
3 |
+
|
4 |
+
from typing_extensions import override
|
5 |
+
|
6 |
+
from ..schema import Item, RichData
|
7 |
+
from ..signals.signal import TextEmbeddingSignal
|
8 |
+
from ..signals.splitters.chunk_splitter import split_text
|
9 |
+
from .embedding import compute_split_embeddings
|
10 |
+
from .transformer_utils import get_model
|
11 |
+
|
12 |
+
# The `all-mpnet-base-v2` model provides the best quality, while `all-MiniLM-L6-v2`` is 5 times
|
13 |
+
# faster and still offers good quality. See https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models/
|
14 |
+
MINI_LM_MODEL = 'all-MiniLM-L6-v2'
|
15 |
+
|
16 |
+
# Maps a tuple of model name and device to the optimal batch size, found empirically.
|
17 |
+
_OPTIMAL_BATCH_SIZES: dict[str, dict[str, int]] = {
|
18 |
+
MINI_LM_MODEL: {
|
19 |
+
'': 64, # Default batch size.
|
20 |
+
'mps': 256,
|
21 |
+
}
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
class SBERT(TextEmbeddingSignal):
|
26 |
+
"""Computes embeddings using Sentence-BERT library."""
|
27 |
+
|
28 |
+
name = 'sbert'
|
29 |
+
display_name = 'SBERT Embeddings'
|
30 |
+
|
31 |
+
@override
|
32 |
+
def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
|
33 |
+
"""Call the embedding function."""
|
34 |
+
batch_size, model = get_model(MINI_LM_MODEL, _OPTIMAL_BATCH_SIZES[MINI_LM_MODEL])
|
35 |
+
embed_fn = model.encode
|
36 |
+
split_fn = split_text if self._split else None
|
37 |
+
docs = cast(Iterable[str], docs)
|
38 |
+
yield from compute_split_embeddings(docs, batch_size, embed_fn=embed_fn, split_fn=split_fn)
|
lilac/embeddings/transformer_utils.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utils for transformer embeddings."""
|
2 |
+
|
3 |
+
import functools
|
4 |
+
import os
|
5 |
+
from typing import TYPE_CHECKING, Optional
|
6 |
+
|
7 |
+
from ..env import data_path
|
8 |
+
from ..utils import log
|
9 |
+
|
10 |
+
if TYPE_CHECKING:
|
11 |
+
from sentence_transformers import SentenceTransformer
|
12 |
+
|
13 |
+
|
14 |
+
def get_model(model_name: str,
|
15 |
+
optimal_batch_sizes: dict[str, int] = {}) -> tuple[int, 'SentenceTransformer']:
|
16 |
+
"""Get a transformer model and the optimal batch size for it."""
|
17 |
+
try:
|
18 |
+
import torch.backends.mps
|
19 |
+
from sentence_transformers import SentenceTransformer
|
20 |
+
except ImportError:
|
21 |
+
raise ImportError('Could not import the "sentence_transformers" python package. '
|
22 |
+
'Please install it with `pip install sentence-transformers`.')
|
23 |
+
preferred_device: Optional[str] = None
|
24 |
+
if torch.backends.mps.is_available():
|
25 |
+
preferred_device = 'mps'
|
26 |
+
elif not torch.backends.mps.is_built():
|
27 |
+
log('MPS not available because the current PyTorch install was not built with MPS enabled.')
|
28 |
+
|
29 |
+
@functools.cache
|
30 |
+
def _get_model(model_name: str) -> 'SentenceTransformer':
|
31 |
+
return SentenceTransformer(
|
32 |
+
model_name, device=preferred_device, cache_folder=os.path.join(data_path(), '.cache'))
|
33 |
+
|
34 |
+
batch_size = optimal_batch_sizes[preferred_device or '']
|
35 |
+
return batch_size, _get_model(model_name)
|
lilac/embeddings/vector_store.py
ADDED
@@ -0,0 +1,200 @@
"""Interface for storing vectors."""

import abc
import os
import pickle
from typing import Iterable, Optional, Type

import numpy as np

from ..schema import SpanVector, VectorKey
from ..utils import open_file


class VectorStore(abc.ABC):
  """Interface for storing and retrieving vectors."""

  # The global name of the vector store.
  name: str

  @abc.abstractmethod
  def save(self, base_path: str) -> None:
    """Save the store to disk."""
    pass

  @abc.abstractmethod
  def load(self, base_path: str) -> None:
    """Load the store from disk."""
    pass

  @abc.abstractmethod
  def size(self) -> int:
    """Return the number of vectors in the store."""
    pass

  @abc.abstractmethod
  def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
    """Add or edit the given keyed embeddings to the store.

    If the keys already exist they will be overwritten, acting as an "upsert".

    Args:
      keys: The keys to add the embeddings for.
      embeddings: The embeddings to add. This should be a 2D matrix with the same length as keys.
    """
    pass

  @abc.abstractmethod
  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
    """Return the embeddings for given keys.

    Args:
      keys: The keys to return the embeddings for. If None, return all embeddings.

    Returns:
      The embeddings for the given keys.
    """
    pass

  def topk(self,
           query: np.ndarray,
           k: int,
           keys: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]:
    """Return the top k most similar vectors.

    Args:
      query: The query vector.
      k: The number of results to return.
      keys: Optional keys to restrict the search to.

    Returns:
      A list of (key, score) tuples.
    """
    raise NotImplementedError


PathKey = VectorKey

_SPANS_PICKLE_NAME = 'spans.pkl'


class VectorDBIndex:
  """Stores and retrieves span vectors.

  This wraps a regular vector store by adding a mapping from path keys, such as (uuid1, 0),
  to span keys, such as (uuid1, 0, 0), which denotes the first span in the (uuid1, 0) text document.
  """

  def __init__(self, vector_store: str) -> None:
    self._vector_store: VectorStore = get_vector_store_cls(vector_store)()
    # Map a path key to spans for that path.
    self._id_to_spans: dict[PathKey, list[tuple[int, int]]] = {}

  def load(self, base_path: str) -> None:
    """Load the vector index from disk."""
    assert not self._id_to_spans, 'Cannot load into a non-empty index.'
    with open_file(os.path.join(base_path, _SPANS_PICKLE_NAME), 'rb') as f:
      self._id_to_spans.update(pickle.load(f))
    self._vector_store.load(os.path.join(base_path, self._vector_store.name))

  def save(self, base_path: str) -> None:
    """Save the vector index to disk."""
    assert self._id_to_spans, 'Cannot save an empty index.'
    with open_file(os.path.join(base_path, _SPANS_PICKLE_NAME), 'wb') as f:
      pickle.dump(list(self._id_to_spans.items()), f)
    self._vector_store.save(os.path.join(base_path, self._vector_store.name))

  def add(self, spans: list[tuple[PathKey, list[tuple[int, int]]]], embeddings: np.ndarray) -> None:
    """Add the given spans and embeddings.

    Args:
      spans: The spans to initialize the index with.
      embeddings: The embeddings to initialize the index with.
    """
    assert not self._id_to_spans, 'Cannot add to a non-empty index.'
    self._id_to_spans.update(spans)
    vector_keys = [(*path_key, i) for path_key, spans in spans for i in range(len(spans))]
    assert len(vector_keys) == len(embeddings), (
      f'Number of spans ({len(vector_keys)}) and embeddings ({len(embeddings)}) must match.')
    self._vector_store.add(vector_keys, embeddings)

  def get_vector_store(self) -> VectorStore:
    """Return the underlying vector store."""
    return self._vector_store

  def get(self, keys: Iterable[PathKey]) -> Iterable[list[SpanVector]]:
    """Return the spans with vectors for each key in `keys`.

    Args:
      keys: The keys to return the vectors for.

    Returns:
      The span vectors for the given keys.
    """
    all_spans: list[list[tuple[int, int]]] = []
    vector_keys: list[VectorKey] = []
    for path_key in keys:
      spans = self._id_to_spans[path_key]
      all_spans.append(spans)
      vector_keys.extend([(*path_key, i) for i in range(len(spans))])

    all_vectors = self._vector_store.get(vector_keys)
    offset = 0
    for spans in all_spans:
      vectors = all_vectors[offset:offset + len(spans)]
      yield [{'span': span, 'vector': vector} for span, vector in zip(spans, vectors)]
      offset += len(spans)

  def topk(self,
           query: np.ndarray,
           k: int,
           path_keys: Optional[Iterable[PathKey]] = None) -> list[tuple[PathKey, float]]:
    """Return the top k most similar vectors.

    Args:
      query: The query vector.
      k: The number of results to return.
      path_keys: Optional key prefixes to restrict the search to.

    Returns:
      A list of (key, score) tuples.
    """
    span_keys: Optional[list[VectorKey]] = None
    if path_keys is not None:
      span_keys = [
        (*path_key, i) for path_key in path_keys for i in range(len(self._id_to_spans[path_key]))
      ]
    span_k = k
    path_key_scores: dict[PathKey, float] = {}
    total_num_span_keys = self._vector_store.size()
    while (len(path_key_scores) < k and span_k < total_num_span_keys and
           (not span_keys or span_k < len(span_keys))):
      span_k += k
      vector_key_scores = self._vector_store.topk(query, span_k, span_keys)
      for (*path_key_list, _), score in vector_key_scores:
        path_key = tuple(path_key_list)
        if path_key not in path_key_scores:
          path_key_scores[path_key] = score

    return list(path_key_scores.items())[:k]


VECTOR_STORE_REGISTRY: dict[str, Type[VectorStore]] = {}


def register_vector_store(vector_store_cls: Type[VectorStore]) -> None:
  """Register a vector store in the global registry."""
  if vector_store_cls.name in VECTOR_STORE_REGISTRY:
    raise ValueError(f'Vector store "{vector_store_cls.name}" has already been registered!')

  VECTOR_STORE_REGISTRY[vector_store_cls.name] = vector_store_cls


def get_vector_store_cls(vector_store_name: str) -> Type[VectorStore]:
  """Return a registered vector store given the name in the registry."""
  return VECTOR_STORE_REGISTRY[vector_store_name]


def clear_vector_store_registry() -> None:
  """Clear the vector store registry."""
  VECTOR_STORE_REGISTRY.clear()
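A minimal sketch of how `VectorDBIndex` maps path keys to span vectors. It assumes the numpy store defined later in this diff has been registered under the name 'numpy' (normally done once at startup elsewhere in the package); the document keys and random embeddings are made up for illustration.

import numpy as np

from lilac.embeddings.vector_store import VectorDBIndex, register_vector_store
from lilac.embeddings.vector_store_numpy import NumpyVectorStore

register_vector_store(NumpyVectorStore)  # Assumed to happen once at startup.

# Two documents: ('doc1',) has two text spans, ('doc2',) has one.
spans = [(('doc1',), [(0, 5), (6, 11)]), (('doc2',), [(0, 8)])]
embeddings = np.random.rand(3, 16).astype(np.float32)  # One row per span.

index = VectorDBIndex('numpy')
index.add(spans, embeddings)

# Spans and vectors come back grouped per path key.
[doc1_vectors] = list(index.get([('doc1',)]))
print(len(doc1_vectors))  # 2 spans for doc1.

# Top documents for a query vector, aggregated over their spans.
print(index.topk(np.random.rand(16).astype(np.float32), k=2))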
lilac/embeddings/vector_store_hnsw.py
ADDED
@@ -0,0 +1,106 @@
"""HNSW vector store."""

import multiprocessing
from typing import Iterable, Optional, Set, cast

import hnswlib
import numpy as np
import pandas as pd
from typing_extensions import override

from ..schema import VectorKey
from ..utils import DebugTimer
from .vector_store import VectorStore

_HNSW_SUFFIX = '.hnswlib.bin'
_LOOKUP_SUFFIX = '.lookup.pkl'


class HNSWVectorStore(VectorStore):
  """HNSW-backed vector store."""

  name = 'hnsw'

  def __init__(self) -> None:
    # Maps a `VectorKey` to a row index in `_embeddings`.
    self._key_to_label: Optional[pd.Series] = None
    self._index: Optional[hnswlib.Index] = None

  @override
  def save(self, base_path: str) -> None:
    assert self._key_to_label is not None and self._index is not None, (
      'The vector store has no embeddings. Call load() or add() first.')
    self._index.save_index(base_path + _HNSW_SUFFIX)
    self._key_to_label.to_pickle(base_path + _LOOKUP_SUFFIX)

  @override
  def load(self, base_path: str) -> None:
    self._key_to_label = pd.read_pickle(base_path + _LOOKUP_SUFFIX)
    dim = int(self._key_to_label.name)
    index = hnswlib.Index(space='ip', dim=dim)
    index.set_ef(10)
    index.set_num_threads(multiprocessing.cpu_count())
    index.load_index(base_path + _HNSW_SUFFIX)
    self._index = index

  @override
  def size(self) -> int:
    assert self._index is not None, (
      'The vector store has no embeddings. Call load() or add() first.')
    return self._index.get_current_count()

  @override
  def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
    assert self._index is None, (
      'Embeddings already exist in this store. Upsert is not yet supported.')

    if len(keys) != embeddings.shape[0]:
      raise ValueError(
        f'Length of keys ({len(keys)}) does not match number of embeddings {embeddings.shape[0]}.')

    dim = embeddings.shape[1]
    with DebugTimer('hnswlib index creation'):
      index = hnswlib.Index(space='ip', dim=dim)
      index.set_ef(10)
      index.set_num_threads(multiprocessing.cpu_count())
      index.init_index(max_elements=len(keys), ef_construction=50, M=16)

      # Cast to float32 since dot product with float32 is 40-50x faster than float16 and 2.5x faster
      # than float64.
      embeddings = embeddings.astype(np.float32)
      row_indices = np.arange(len(keys), dtype=np.int32)
      self._key_to_label = pd.Series(row_indices, index=keys, dtype=np.int32)
      self._key_to_label.name = str(dim)
      index.add_items(embeddings, row_indices)
    self._index = index

  @override
  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
    assert self._index is not None and self._key_to_label is not None, (
      'No embeddings exist in this store.')
    if not keys:
      return np.array(self._index.get_items(self._key_to_label.values), dtype=np.float32)
    locs = self._key_to_label.loc[cast(list[str], keys)].values
    return np.array(self._index.get_items(locs), dtype=np.float32)

  @override
  def topk(self,
           query: np.ndarray,
           k: int,
           keys: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]:
    assert self._index is not None and self._key_to_label is not None, (
      'No embeddings exist in this store.')
    labels: Set[int] = set()
    if keys is not None:
      labels = set(self._key_to_label.loc[cast(list[str], keys)].tolist())
      k = min(k, len(labels))

    def filter_func(label: int) -> bool:
      return label in labels

    query = np.expand_dims(query.astype(np.float32), axis=0)
    locs, dists = self._index.knn_query(query, k=k, filter=filter_func if labels else None)
    locs = locs[0]
    dists = dists[0]
    topk_keys = self._key_to_label.index.values[locs]
    return [(key, 1 - dist) for key, dist in zip(topk_keys, dists)]
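The index is built with hnswlib's inner-product space and the score is reported as `1 - dist`; for that score to behave like cosine similarity, callers would need to pass unit-normalized embeddings. That is an assumption about the surrounding code, not something this file enforces. A sketch of the normalization step:

import numpy as np

embeddings = np.random.rand(100, 384).astype(np.float32)  # Placeholder vectors.
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
unit_embeddings = embeddings / np.maximum(norms, 1e-12)    # Avoid division by zero.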
lilac/embeddings/vector_store_numpy.py
ADDED
@@ -0,0 +1,92 @@
"""NumpyVectorStore class for storing vectors in numpy arrays."""

from typing import Iterable, Optional, cast

import numpy as np
import pandas as pd
from typing_extensions import override

from ..schema import VectorKey
from .vector_store import VectorStore

_EMBEDDINGS_SUFFIX = '.matrix.npy'
_LOOKUP_SUFFIX = '.lookup.pkl'


class NumpyVectorStore(VectorStore):
  """Stores vectors as in-memory np arrays."""
  name = 'numpy'

  def __init__(self) -> None:
    self._embeddings: Optional[np.ndarray] = None
    # Maps a `VectorKey` to a row index in `_embeddings`.
    self._key_to_index: Optional[pd.Series] = None

  @override
  def size(self) -> int:
    assert self._embeddings is not None, (
      'The vector store has no embeddings. Call load() or add() first.')
    return len(self._embeddings)

  @override
  def save(self, base_path: str) -> None:
    assert self._embeddings is not None and self._key_to_index is not None, (
      'The vector store has no embeddings. Call load() or add() first.')
    np.save(base_path + _EMBEDDINGS_SUFFIX, self._embeddings, allow_pickle=False)
    self._key_to_index.to_pickle(base_path + _LOOKUP_SUFFIX)

  @override
  def load(self, base_path: str) -> None:
    self._embeddings = np.load(base_path + _EMBEDDINGS_SUFFIX, allow_pickle=False)
    self._key_to_index = pd.read_pickle(base_path + _LOOKUP_SUFFIX)

  @override
  def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
    if self._embeddings or self._key_to_index:
      raise ValueError('Embeddings already exist in this store. Upsert is not yet supported.')

    if len(keys) != embeddings.shape[0]:
      raise ValueError(
        f'Length of keys ({len(keys)}) does not match number of embeddings {embeddings.shape[0]}.')

    # Cast to float32 since dot product with float32 is 40-50x faster than float16 and 2.5x faster
    # than float64.
    self._embeddings = embeddings.astype(np.float32)
    row_indices = np.arange(len(embeddings), dtype=np.uint32)
    self._key_to_index = pd.Series(row_indices, index=keys, dtype=np.uint32)

  @override
  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
    assert self._embeddings is not None and self._key_to_index is not None, (
      'The vector store has no embeddings. Call load() or add() first.')
    if not keys:
      return self._embeddings
    locs = self._key_to_index.loc[cast(list[str], keys)]
    return self._embeddings.take(locs, axis=0)

  @override
  def topk(self,
           query: np.ndarray,
           k: int,
           keys: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]:
    assert self._embeddings is not None and self._key_to_index is not None, (
      'The vector store has no embeddings. Call load() or add() first.')
    if keys is not None:
      row_indices = self._key_to_index.loc[cast(list[str], keys)]
      embeddings = self._embeddings.take(row_indices, axis=0)
      keys = list(keys)
    else:
      keys, embeddings = cast(list[VectorKey], self._key_to_index.index.tolist()), self._embeddings

    query = query.astype(embeddings.dtype)
    similarities: np.ndarray = np.dot(embeddings, query).reshape(-1)
    k = min(k, len(similarities))

    # We do a partition + sort only top K to save time: O(n + klogk) instead of O(nlogn).
    indices = np.argpartition(similarities, -k)[-k:]
    # Indices sorted by value from largest to smallest.
    indices = indices[np.argsort(similarities[indices])][::-1]

    topk_similarities = similarities[indices]
    topk_keys = [keys[idx] for idx in indices]
    return list(zip(topk_keys, topk_similarities))
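The partial-sort trick used in `topk` above can be seen in isolation: `np.argpartition` finds the k largest entries in O(n), and only those k are then fully sorted. A small self-contained check with made-up scores:

import numpy as np

similarities = np.array([0.1, 0.9, 0.4, 0.7, 0.2])
k = 3
indices = np.argpartition(similarities, -k)[-k:]             # Unordered indices of the 3 largest.
indices = indices[np.argsort(similarities[indices])][::-1]   # Order them largest-first.
print(indices, similarities[indices])                        # [1 3 2] [0.9 0.7 0.4]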
lilac/env.py
ADDED
@@ -0,0 +1,63 @@
"""Load environment variables from .env file."""
import os
from typing import Any, Literal, Optional, Union, cast

from dotenv import load_dotenv

EnvironmentKeys = Union[Literal['LILAC_DATA_PATH'],
                        # Authentication on the demo.
                        Literal['LILAC_AUTH_ENABLED'], Literal['GOOGLE_CLIENT_ID'],
                        Literal['GOOGLE_CLIENT_SECRET'], Literal['LILAC_OAUTH_SECRET_KEY'],
                        # DuckDB accessing GCS.
                        Literal['GCS_REGION'], Literal['GCS_ACCESS_KEY'], Literal['GCS_SECRET_KEY'],
                        # Embedding API keys.
                        Literal['OPENAI_API_KEY'], Literal['COHERE_API_KEY'],
                        Literal['PALM_API_KEY'],
                        # HuggingFace demos.
                        Literal['HF_USERNAME'], Literal['HF_STAGING_DEMO_REPO'],
                        Literal['SPACE_ID'], Literal['HF_ACCESS_TOKEN'],
                        # DuckDB
                        Literal['DUCKDB_USE_VIEWS'],
                        # Debugging
                        Literal['DEBUG'], Literal['DISABLE_LOGS']]


def _init_env() -> None:
  in_test = os.environ.get('LILAC_TEST', None)
  # Load the .env files into the environment in order of highest to lowest priority.

  if not in_test:  # Skip local environment variables when testing.
    load_dotenv('.env.local')
  load_dotenv('.env.demo')
  load_dotenv('.env')

  if os.environ.get('LILAC_AUTH_ENABLED', None):
    if not os.environ.get('GOOGLE_CLIENT_ID', None) or not os.environ.get(
        'GOOGLE_CLIENT_SECRET', None):
      raise ValueError(
        'Missing `GOOGLE_CLIENT_ID` or `GOOGLE_CLIENT_SECRET` when `LILAC_AUTH_ENABLED=true`')
    SECRET_KEY = os.environ.get('LILAC_OAUTH_SECRET_KEY', None)
    if not SECRET_KEY:
      raise ValueError('Missing `LILAC_OAUTH_SECRET_KEY` when `LILAC_AUTH_ENABLED=true`')
  if os.environ.get('LILAC_AUTH_ENABLED', None):
    if not os.environ.get('GOOGLE_CLIENT_ID', None) or not os.environ.get(
        'GOOGLE_CLIENT_SECRET', None):
      raise ValueError(
        'Missing `GOOGLE_CLIENT_ID` or `GOOGLE_CLIENT_SECRET` when `LILAC_AUTH_ENABLED=true`')
    SECRET_KEY = os.environ.get('LILAC_OAUTH_SECRET_KEY', None)
    if not SECRET_KEY:
      raise ValueError('Missing `LILAC_OAUTH_SECRET_KEY` when `LILAC_AUTH_ENABLED=true`')


def env(key: EnvironmentKeys, default: Optional[Any] = None) -> Any:
  """Return the value of an environment variable."""
  return os.environ.get(key, default)


def data_path() -> str:
  """Return the base path for data."""
  return cast(str, env('LILAC_DATA_PATH', './data'))


# Initialize the environment at import time.
_init_env()
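A small sketch of reading configuration through this module; the override value shown is an assumption for illustration.

import os

from lilac.env import data_path, env

os.environ['LILAC_DATA_PATH'] = './demo_data'  # Hypothetical override.
print(data_path())                             # './demo_data'
print(env('DUCKDB_USE_VIEWS', 0))              # Falls back to the default when unset.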
lilac/load.py
ADDED
@@ -0,0 +1,214 @@
"""A script to load a dataset or set of datasets from a config for a Lilac instance.

Usage:

poetry run python -m lilac.load \
  --output_dir=demo_data \
  --config_path=demo.yml
"""

import gc
import json
import os
import pathlib
import shutil

import click
import dask
import psutil
import yaml
from distributed import Client

from .config import Config, EmbeddingConfig, SignalConfig
from .data_loader import process_source
from .db_manager import get_dataset
from .schema import UUID_COLUMN
from .tasks import TaskManager, TaskStepId
from .utils import DebugTimer, get_datasets_dir, list_datasets


@click.command()
@click.option(
  '--output_dir', required=True, type=str, help='The output directory to write files to.')
@click.option(
  '--config_path',
  required=True,
  type=str,
  help='The path to a json or yml file describing the configuration. '
  'The file contents should be an instance of `lilac.Config`.')
@click.option(
  '--overwrite',
  help='When True, runs all data from scratch, overwriting existing data. When False, only '
  'loads new datasets, embeddings, and signals.',
  type=bool,
  is_flag=True,
  default=False)
def load_command(output_dir: str, config_path: str, overwrite: bool) -> None:
  """Run the source loader as a binary."""
  load(output_dir, config_path, overwrite)


def load(output_dir: str, config_path: str, overwrite: bool) -> None:
  """Run the source loader as a binary."""
  old_data_path = os.environ.get('LILAC_DATA_PATH')
  os.environ['LILAC_DATA_PATH'] = output_dir
  # Turn off debug logging.
  if 'DEBUG' in os.environ:
    del os.environ['DEBUG']

  config_ext = pathlib.Path(config_path).suffix
  if config_ext in ['.yml', '.yaml']:
    with open(config_path, 'r') as f:
      config_dict = yaml.safe_load(f)
  elif config_ext in ['.json']:
    with open(config_path, 'r') as f:
      config_dict = json.load(f)
  else:
    raise ValueError(f'Unsupported config file extension: {config_ext}')

  config = Config(**config_dict)

  # Explicitly create a dask client in sync mode.
  dask.config.set({'distributed.worker.daemon': False})
  total_memory_gb = psutil.virtual_memory().total / (1024**3)
  task_manager = TaskManager(Client(memory_limit=f'{total_memory_gb} GB'))

  if overwrite:
    shutil.rmtree(get_datasets_dir(output_dir), ignore_errors=True)

  existing_datasets = [f'{d.namespace}/{d.dataset_name}' for d in list_datasets(output_dir)]

  print()
  print('*** Load datasets ***')
  if overwrite:
    datasets_to_load = config.datasets
  else:
    datasets_to_load = [
      d for d in config.datasets if f'{d.namespace}/{d.name}' not in existing_datasets
    ]
    skipped_datasets = [
      d for d in config.datasets if f'{d.namespace}/{d.name}' in existing_datasets
    ]
    print('Skipping loaded datasets:', ', '.join([d.name for d in skipped_datasets]))

  with DebugTimer(f'Loading datasets: {", ".join([d.name for d in datasets_to_load])}'):
    for d in datasets_to_load:
      shutil.rmtree(os.path.join(output_dir, d.name), ignore_errors=True)
      task_id = task_manager.task_id(f'Load dataset {d.namespace}/{d.name}')
      task_manager.execute(task_id, process_source, output_dir, d.namespace, d.name, d.source,
                           (task_id, 0))

    task_manager.wait()

  print()
  total_num_rows = 0
  for d in datasets_to_load:
    num_rows = get_dataset(d.namespace, d.name).select_rows([UUID_COLUMN], limit=1).total_num_rows
    print(f'{d.namespace}/{d.name} loaded with {num_rows:,} rows.')
    gc.collect()
    total_num_rows += num_rows

  print(f'Done loading {len(datasets_to_load)} datasets with {total_num_rows:,} rows.')

  print('*** Dataset settings ***')
  for d in config.datasets:
    if d.settings:
      dataset = get_dataset(d.namespace, d.name)
      dataset.update_settings(d.settings)

  print()
  print('*** Compute embeddings ***')
  with DebugTimer('Loading embeddings'):
    for d in config.datasets:
      # If embeddings are explicitly set, use only those.
      embeddings = d.embeddings or []
      # If embeddings are not explicitly set, use the media paths and preferred embedding from
      # settings.
      if not embeddings:
        if d.settings and d.settings.ui:
          for path in d.settings.ui.media_paths or []:
            if d.settings.preferred_embedding:
              embeddings.append(
                EmbeddingConfig(path=path, embedding=d.settings.preferred_embedding))
      for e in embeddings:
        task_id = task_manager.task_id(f'Compute embedding {e.embedding} on {e.path}')
        task_manager.execute(task_id, _compute_embedding, d.namespace, d.name, e, output_dir,
                             overwrite, (task_id, 0))
    task_manager.wait()

  print()
  print('*** Compute signals ***')
  with DebugTimer('Computing signals'):
    for d in config.datasets:
      # If signals are explicitly set, use only those.
      signals = d.signals or []
      # If signals are not explicitly set, use the media paths and config.signals.
      if not signals:
        if d.settings and d.settings.ui:
          for path in d.settings.ui.media_paths or []:
            for signal in config.signals or []:
              signals.append(SignalConfig(path=path, signal=signal))

      for s in signals:
        task_id = task_manager.task_id(f'Compute signal {s.signal} on {s.path}')
        task_manager.execute(task_id, _compute_signal, d.namespace, d.name, s, output_dir,
                             overwrite, (task_id, 0))
    task_manager.wait()

  print()
  print('Done!')

  if old_data_path:
    os.environ['LILAC_DATA_PATH'] = old_data_path


def _compute_signal(namespace: str, name: str, signal_config: SignalConfig, output_dir: str,
                    overwrite: bool, task_step_id: TaskStepId) -> None:
  os.environ['LILAC_DATA_PATH'] = output_dir
  # Turn off debug logging.
  if 'DEBUG' in os.environ:
    del os.environ['DEBUG']

  compute_signal = False
  if overwrite:
    compute_signal = True

  dataset = get_dataset(namespace, name)

  if not compute_signal:
    field = dataset.manifest().data_schema.get_field(signal_config.path)
    signal_field = (field.fields or {}).get(signal_config.signal.key())
    if not signal_field or signal_field.signal != signal_config.signal.dict():
      compute_signal = True
  if compute_signal:
    dataset.compute_signal(signal_config.signal, signal_config.path, task_step_id)

  gc.collect()


def _compute_embedding(namespace: str, name: str, embedding_config: EmbeddingConfig,
                       output_dir: str, overwrite: bool, task_step_id: TaskStepId) -> None:
  os.environ['LILAC_DATA_PATH'] = output_dir
  # Turn off debug logging.
  if 'DEBUG' in os.environ:
    del os.environ['DEBUG']

  compute_embedding = False
  if overwrite:
    compute_embedding = True

  dataset = get_dataset(namespace, name)
  if not compute_embedding:
    field = dataset.manifest().data_schema.get_field(embedding_config.path)
    embedding_field = (field.fields or {}).get(embedding_config.embedding)
    if not embedding_field:
      compute_embedding = True

  if compute_embedding:
    dataset.compute_embedding(embedding_config.embedding, embedding_config.path, task_step_id)

  gc.collect()


if __name__ == '__main__':
  load_command()
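Besides the CLI entry point, `load` can also be invoked programmatically; the paths below are placeholders, and the config file must parse into a `lilac.Config` (see lilac/config.py).

from lilac.load import load

# Hypothetical paths; demo.yml must describe the datasets, embeddings and signals as a `Config`.
load(output_dir='demo_data', config_path='demo.yml', overwrite=False)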
lilac/make_openapi.py
ADDED
@@ -0,0 +1,29 @@
"""Writes the openapi.json file to the specified output.

This is meant to run as a standalone script. It lives in lilac/ so we can import the FastAPI app.
"""
import json

import click
from fastapi.openapi.utils import get_openapi

from .server import app


@click.command()
@click.option(
  '--output', required=True, type=str, help='The output filepath for the openapi.json file.')
def main(output: str) -> None:
  """Create the openapi.json file for the API to generate TypeScript stubs."""
  with open(output, 'w') as f:
    json.dump(
      get_openapi(
        title=app.title,
        version=app.version,
        openapi_version=app.openapi_version,
        description=app.description,
        routes=app.routes), f)


if __name__ == '__main__':
  main()
lilac/parquet_writer.py
ADDED
@@ -0,0 +1,70 @@
"""A Parquet file writer that wraps the pyarrow writer."""
from typing import IO, Optional

import pyarrow as pa
import pyarrow.parquet as pq

from .schema import Item, Schema, schema_to_arrow_schema


class ParquetWriter:
  """A writer to parquet."""

  def __init__(self,
               schema: Schema,
               codec: str = 'snappy',
               row_group_buffer_size: int = 128 * 1024 * 1024,
               record_batch_size: int = 10_000):
    self._schema = schema_to_arrow_schema(schema)
    self._codec = codec
    self._row_group_buffer_size = row_group_buffer_size
    self._buffer: list[list[Optional[Item]]] = [[] for _ in range(len(self._schema.names))]
    self._buffer_size = record_batch_size
    self._record_batches: list[pa.RecordBatch] = []
    self._record_batches_byte_size = 0
    self.writer: pq.ParquetWriter = None

  def open(self, file_handle: IO) -> None:
    """Open the destination file for writing."""
    self.writer = pq.ParquetWriter(file_handle, self._schema, compression=self._codec)

  def write(self, record: Item) -> None:
    """Write the record to the destination file."""
    if len(self._buffer[0]) >= self._buffer_size:
      self._flush_buffer()

    if self._record_batches_byte_size >= self._row_group_buffer_size:
      self._write_batches()

    # Reorder the data in columnar format.
    for i, n in enumerate(self._schema.names):
      self._buffer[i].append(record.get(n))

  def close(self) -> None:
    """Flushes the write buffer and closes the destination file."""
    if len(self._buffer[0]) > 0:
      self._flush_buffer()
    if self._record_batches_byte_size > 0:
      self._write_batches()

    self.writer.close()

  def _write_batches(self) -> None:
    table = pa.Table.from_batches(self._record_batches, schema=self._schema)
    self._record_batches = []
    self._record_batches_byte_size = 0
    self.writer.write_table(table)

  def _flush_buffer(self) -> None:
    arrays: list[pa.array] = [[] for _ in range(len(self._schema.names))]
    for x, y in enumerate(self._buffer):
      arrays[x] = pa.array(y, type=self._schema.types[x])
      self._buffer[x] = []
    rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
    self._record_batches.append(rb)
    size = 0
    for x in arrays:
      for b in x.buffers():  # type: ignore
        if b is not None:
          size = size + b.size
    self._record_batches_byte_size = self._record_batches_byte_size + size
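A usage sketch of the writer; `my_schema` stands in for a `lilac.schema.Schema` built elsewhere, and the output path and records are placeholders matching that schema's fields.

from lilac.parquet_writer import ParquetWriter

# my_schema: Schema = ...  # A hypothetical lilac Schema whose arrow conversion defines the columns.
writer = ParquetWriter(my_schema)
with open('rows.parquet', 'wb') as f:
  writer.open(f)
  for record in [{'uuid': '1', 'text': 'hello'}, {'uuid': '2', 'text': 'world'}]:
    writer.write(record)  # Buffered into record batches, flushed by the size thresholds above.
  writer.close()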
lilac/router_concept.py
ADDED
@@ -0,0 +1,209 @@
"""Router for the concept database."""

from typing import Annotated, Iterable, Optional, cast

from fastapi import APIRouter, HTTPException
from fastapi.params import Depends
from openai_function_call import OpenAISchema
from pydantic import BaseModel, Field

from .auth import UserInfo, get_session_user
from .concepts.concept import DRAFT_MAIN, Concept, ConceptMetrics, DraftId, draft_examples
from .concepts.db_concept import DISK_CONCEPT_DB, DISK_CONCEPT_MODEL_DB, ConceptInfo, ConceptUpdate
from .env import env
from .router_utils import RouteErrorHandler, server_compute_concept
from .schema import RichData, SignalInputType
from .signals.concept_scorer import ConceptScoreSignal

router = APIRouter(route_class=RouteErrorHandler)


@router.get('/', response_model_exclude_none=True)
def get_concepts(
    user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> list[ConceptInfo]:
  """List the concepts."""
  return DISK_CONCEPT_DB.list(user)


@router.get('/{namespace}/{concept_name}', response_model_exclude_none=True)
def get_concept(namespace: str,
                concept_name: str,
                draft: Optional[DraftId] = DRAFT_MAIN,
                user: Annotated[Optional[UserInfo], Depends(get_session_user)] = None) -> Concept:
  """Get a concept from a database."""
  concept = DISK_CONCEPT_DB.get(namespace, concept_name, user)
  if not concept:
    raise HTTPException(
      status_code=404,
      detail=f'Concept "{namespace}/{concept_name}" was not found or user does not have access.')

  # Only return the examples from the draft.
  concept.data = draft_examples(concept, draft or DRAFT_MAIN)

  return concept


class CreateConceptOptions(BaseModel):
  """Options for creating a concept."""
  # Namespace of the concept.
  namespace: str
  # Name of the concept.
  name: str
  # Input type (modality) of the concept.
  type: SignalInputType
  description: Optional[str] = None


@router.post('/create', response_model_exclude_none=True)
def create_concept(options: CreateConceptOptions,
                   user: Annotated[Optional[UserInfo],
                                   Depends(get_session_user)]) -> Concept:
  """Create a concept in the database."""
  return DISK_CONCEPT_DB.create(options.namespace, options.name, options.type, options.description,
                                user)


@router.post('/{namespace}/{concept_name}', response_model_exclude_none=True)
def edit_concept(namespace: str, concept_name: str, change: ConceptUpdate,
                 user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> Concept:
  """Edit a concept in the database."""
  return DISK_CONCEPT_DB.edit(namespace, concept_name, change, user)


@router.delete('/{namespace}/{concept_name}')
def delete_concept(namespace: str, concept_name: str,
                   user: Annotated[Optional[UserInfo],
                                   Depends(get_session_user)]) -> None:
  """Deletes the concept from the database."""
  DISK_CONCEPT_DB.remove(namespace, concept_name, user)


class MergeConceptDraftOptions(BaseModel):
  """Merge a draft into main."""
  draft: DraftId


@router.post('/{namespace}/{concept_name}/merge_draft', response_model_exclude_none=True)
def merge_concept_draft(namespace: str, concept_name: str, options: MergeConceptDraftOptions,
                        user: Annotated[Optional[UserInfo],
                                        Depends(get_session_user)]) -> Concept:
  """Merge a draft in the concept into main."""
  return DISK_CONCEPT_DB.merge_draft(namespace, concept_name, options.draft, user)


class ScoreExample(BaseModel):
  """Example to score along a specific concept."""
  text: Optional[str] = None
  img: Optional[bytes] = None


class ScoreBody(BaseModel):
  """Request body for the score endpoint."""
  examples: list[ScoreExample]
  draft: str = DRAFT_MAIN


class ConceptModelInfo(BaseModel):
  """Information about a concept model."""
  namespace: str
  concept_name: str
  embedding_name: str
  version: int
  metrics: Optional[ConceptMetrics] = None


@router.get('/{namespace}/{concept_name}/model')
def get_concept_models(
    namespace: str,
    concept_name: str,
    user: Annotated[Optional[UserInfo],
                    Depends(get_session_user)] = None) -> list[ConceptModelInfo]:
  """Get a concept model from a database."""
  concept = DISK_CONCEPT_DB.get(namespace, concept_name, user)
  if not concept:
    raise HTTPException(
      status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')
  models = DISK_CONCEPT_MODEL_DB.get_models(namespace, concept_name, user)

  for m in models:
    DISK_CONCEPT_MODEL_DB.sync(m, user)

  return [
    ConceptModelInfo(
      namespace=m.namespace,
      concept_name=m.concept_name,
      embedding_name=m.embedding_name,
      version=m.version,
      metrics=m.get_metrics(concept)) for m in models
  ]


@router.get('/{namespace}/{concept_name}/model/{embedding_name}')
def get_concept_model(
    namespace: str,
    concept_name: str,
    embedding_name: str,
    user: Annotated[Optional[UserInfo], Depends(get_session_user)] = None) -> ConceptModelInfo:
  """Get a concept model from a database."""
  concept = DISK_CONCEPT_DB.get(namespace, concept_name, user)
  if not concept:
    raise HTTPException(
      status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')

  model = DISK_CONCEPT_MODEL_DB.get(namespace, concept_name, embedding_name, user=user)
  if not model:
    model = DISK_CONCEPT_MODEL_DB.create(namespace, concept_name, embedding_name, user=user)
  DISK_CONCEPT_MODEL_DB.sync(model)
  model_info = ConceptModelInfo(
    namespace=model.namespace,
    concept_name=model.concept_name,
    embedding_name=model.embedding_name,
    version=model.version,
    metrics=model.get_metrics(concept))
  return model_info


@router.post(
  '/{namespace}/{concept_name}/model/{embedding_name}/score', response_model_exclude_none=True)
def score(namespace: str, concept_name: str, embedding_name: str, body: ScoreBody,
          user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> list[list[dict]]:
  """Score examples along the specified concept."""
  concept_scorer = ConceptScoreSignal(
    namespace=namespace, concept_name=concept_name, embedding=embedding_name)
  return cast(
    list[list[dict]],
    server_compute_concept(concept_scorer, cast(Iterable[RichData],
                                                [e.text for e in body.examples]), user))


class Examples(OpenAISchema):
  """Generated text examples."""
  examples: list[str] = Field(..., description='List of generated examples')


@router.get('/generate_examples')
def generate_examples(description: str) -> list[str]:
  """Generate positive examples for a given concept using an LLM model."""
  try:
    import openai
  except ImportError:
    raise ImportError('Could not import the "openai" python package. '
                      'Please install it with `pip install openai`.')

  openai.api_key = env('OPENAI_API_KEY')
  completion = openai.ChatCompletion.create(
    model='gpt-3.5-turbo-0613',
    functions=[Examples.openai_schema],
    messages=[
      {
        'role': 'system',
        'content': 'You must call the `Examples` function with the generated examples',
      },
      {
        'role': 'user',
        'content': f'Write 5 diverse, unnumbered, and concise examples of "{description}"',
      },
    ],
  )
  result = Examples.from_response(completion)
  return result.examples
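For reference, the request body that the `score` endpoint expects can be built from the models above; the example text is arbitrary and the import assumes the router module's dependencies are installed.

from lilac.router_concept import ScoreBody, ScoreExample

body = ScoreBody(examples=[ScoreExample(text='This movie was fantastic!')])
print(body.json())  # Serializes the examples along with the default draft.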
lilac/router_data_loader.py
ADDED
@@ -0,0 +1,80 @@
"""The source loader runner which loads data into parquet files for the app.

To run the source loader as a binary directly:

poetry run python -m lilac.datasets.loader \
  --dataset_name=$DATASET \
  --output_dir=./data/ \
  --config_path=./datasets/the_movies_dataset.json
"""
from typing import Any

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

from .auth import get_user_access
from .data_loader import process_source
from .env import data_path
from .router_utils import RouteErrorHandler
from .sources.source_registry import get_source_cls, registered_sources
from .tasks import TaskId, task_manager

REQUEST_TIMEOUT_SEC = 30 * 60  # 30 mins.

router = APIRouter(route_class=RouteErrorHandler)


class ProcessSourceRequest(BaseModel):
  """The interface to the /process_source endpoint."""
  username: str
  dataset_name: str


class SourcesList(BaseModel):
  """The list of available sources."""
  sources: list[str]


@router.get('/')
def get_sources() -> SourcesList:
  """Get the list of available sources."""
  sources = registered_sources()
  return SourcesList(sources=list(sources.keys()))


@router.get('/{source_name}')
def get_source_schema(source_name: str) -> dict[str, Any]:
  """Get the fields for a source."""
  source_cls = get_source_cls(source_name)
  return source_cls.schema()


class LoadDatasetOptions(BaseModel):
  """Options for loading a dataset."""
  namespace: str
  dataset_name: str
  config: dict[str, Any]


class LoadDatasetResponse(BaseModel):
  """Response of the load dataset endpoint."""
  task_id: TaskId


@router.post('/{source_name}/load')
async def load(source_name: str, options: LoadDatasetOptions,
               request: Request) -> LoadDatasetResponse:
  """Load a dataset."""
  if not get_user_access().create_dataset:
    raise HTTPException(401, 'User does not have access to load a dataset.')

  source_cls = get_source_cls(source_name)
  source = source_cls(**options.config)

  task_id = task_manager().task_id(
    name=f'[{options.namespace}/{options.dataset_name}] Load dataset',
    description=f'Loader: {source.name}. \n Config: {source}')
  task_manager().execute(task_id, process_source, data_path(), options.namespace,
                         options.dataset_name, source, (task_id, 0))

  return LoadDatasetResponse(task_id=task_id)
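A sketch of the options the `/load` endpoint takes; the namespace, dataset name, and source config keys are illustrative and depend on the chosen source.

from lilac.router_data_loader import LoadDatasetOptions

options = LoadDatasetOptions(
  namespace='local',
  dataset_name='my_dataset',
  config={'dataset_name': 'imdb'})  # Hypothetical config for a HuggingFace-style source.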
lilac/router_dataset.py
ADDED
@@ -0,0 +1,303 @@
"""Router for the dataset database."""
from typing import Annotated, Optional, Sequence, Union, cast
from urllib.parse import unquote

from fastapi import APIRouter, HTTPException, Response
from fastapi.params import Depends
from fastapi.responses import ORJSONResponse
from pydantic import BaseModel, validator

from .auth import UserInfo, get_session_user, get_user_access
from .data.dataset import BinaryOp
from .data.dataset import Column as DBColumn
from .data.dataset import DatasetManifest, DatasetSettings, FeatureListValue, FeatureValue
from .data.dataset import Filter as PyFilter
from .data.dataset import (
  GroupsSortBy,
  ListOp,
  Search,
  SelectGroupsResult,
  SelectRowsSchemaResult,
  SortOrder,
  StatsResult,
  UnaryOp,
)
from .db_manager import get_dataset, remove_dataset_from_cache
from .env import data_path
from .router_utils import RouteErrorHandler
from .schema import Bin, Path, normalize_path
from .signals.concept_labels import ConceptLabelsSignal
from .signals.concept_scorer import ConceptScoreSignal
from .signals.semantic_similarity import SemanticSimilaritySignal
from .signals.signal import Signal, TextEmbeddingSignal, TextSignal, resolve_signal
from .signals.substring_search import SubstringSignal
from .tasks import TaskId, task_manager
from .utils import DatasetInfo, list_datasets

router = APIRouter(route_class=RouteErrorHandler)


@router.get('/', response_model_exclude_none=True)
def get_datasets() -> list[DatasetInfo]:
  """List the datasets."""
  return list_datasets(data_path())


class WebManifest(BaseModel):
  """Information about a dataset."""
  dataset_manifest: DatasetManifest


@router.get('/{namespace}/{dataset_name}')
def get_manifest(namespace: str, dataset_name: str) -> WebManifest:
  """Get the web manifest for the dataset."""
  dataset = get_dataset(namespace, dataset_name)
  res = WebManifest(dataset_manifest=dataset.manifest())
  # Avoids the error that Signal abstract class is not serializable.
  return cast(WebManifest, ORJSONResponse(res.dict(exclude_none=True)))


class ComputeSignalOptions(BaseModel):
  """The request for the compute signal endpoint."""
  signal: Signal

  # The leaf path to compute the signal on.
  leaf_path: Path

  @validator('signal', pre=True)
  def parse_signal(cls, signal: dict) -> Signal:
    """Parse a signal to its specific subclass instance."""
    return resolve_signal(signal)


@router.delete('/{namespace}/{dataset_name}')
def delete_dataset(namespace: str, dataset_name: str) -> None:
  """Delete the dataset."""
  if not get_user_access().dataset.delete_dataset:
    raise HTTPException(401, 'User does not have access to delete this dataset.')

  dataset = get_dataset(namespace, dataset_name)
  dataset.delete()
  remove_dataset_from_cache(namespace, dataset_name)


class ComputeSignalResponse(BaseModel):
  """Response of the compute signal column endpoint."""
  task_id: TaskId


@router.post('/{namespace}/{dataset_name}/compute_signal')
def compute_signal(namespace: str, dataset_name: str,
                   options: ComputeSignalOptions) -> ComputeSignalResponse:
  """Compute a signal for a dataset."""
  if not get_user_access().dataset.compute_signals:
    raise HTTPException(401, 'User does not have access to compute signals over this dataset.')

  def _task_compute_signal(namespace: str, dataset_name: str, options_dict: dict,
                           task_id: TaskId) -> None:
    # NOTE: We manually call .dict() to avoid the dask serializer, which doesn't call the underlying
    # pydantic serializer.
    options = ComputeSignalOptions(**options_dict)
    dataset = get_dataset(namespace, dataset_name)
    dataset.compute_signal(options.signal, options.leaf_path, task_step_id=(task_id, 0))

  path_str = '.'.join(map(str, options.leaf_path))
  task_id = task_manager().task_id(
    name=f'[{namespace}/{dataset_name}] Compute signal "{options.signal.name}" on "{path_str}"',
    description=f'Config: {options.signal}')
  task_manager().execute(task_id, _task_compute_signal, namespace, dataset_name, options.dict(),
                         task_id)

  return ComputeSignalResponse(task_id=task_id)


class DeleteSignalOptions(BaseModel):
  """The request for the delete signal endpoint."""
  # The signal path holding the data from the signal.
  signal_path: Path


class DeleteSignalResponse(BaseModel):
  """Response of the compute signal column endpoint."""
  completed: bool


@router.delete('/{namespace}/{dataset_name}/delete_signal')
def delete_signal(namespace: str, dataset_name: str,
                  options: DeleteSignalOptions) -> DeleteSignalResponse:
  """Delete a signal from a dataset."""
  if not get_user_access().dataset.delete_signals:
    raise HTTPException(401, 'User does not have access to delete this signal.')

  dataset = get_dataset(namespace, dataset_name)
  dataset.delete_signal(options.signal_path)
  return DeleteSignalResponse(completed=True)


class GetStatsOptions(BaseModel):
  """The request for the get stats endpoint."""
  leaf_path: Path


@router.post('/{namespace}/{dataset_name}/stats')
def get_stats(namespace: str, dataset_name: str, options: GetStatsOptions) -> StatsResult:
  """Get the stats for the dataset."""
  dataset = get_dataset(namespace, dataset_name)
  return dataset.stats(options.leaf_path)


class BinaryFilter(BaseModel):
  """A filter on a column."""
  path: Path
  op: BinaryOp
  value: FeatureValue


class UnaryFilter(BaseModel):
  """A filter on a column."""
  path: Path
  op: UnaryOp
  value: None = None


class ListFilter(BaseModel):
  """A filter on a column."""
  path: Path
  op: ListOp
  value: FeatureListValue


Filter = Union[BinaryFilter, UnaryFilter, ListFilter]

AllSignalTypes = Union[ConceptScoreSignal, ConceptLabelsSignal, SubstringSignal,
                       SemanticSimilaritySignal, TextEmbeddingSignal, TextSignal, Signal]


# We override the `Column` class so we can add explicitly all signal types for better OpenAPI spec.
class Column(DBColumn):
  """A column in the dataset."""
  signal_udf: Optional[AllSignalTypes] = None


class SelectRowsOptions(BaseModel):
  """The request for the select rows endpoint."""
  columns: Optional[Sequence[Union[Path, Column]]] = None
  searches: Optional[Sequence[Search]] = None
  filters: Optional[Sequence[Filter]] = None
  sort_by: Optional[Sequence[Path]] = None
  sort_order: Optional[SortOrder] = SortOrder.DESC
  limit: Optional[int] = None
  offset: Optional[int] = None
  combine_columns: Optional[bool] = None


class SelectRowsSchemaOptions(BaseModel):
  """The request for the select rows schema endpoint."""
  columns: Optional[Sequence[Union[Path, Column]]] = None
  searches: Optional[Sequence[Search]] = None
  sort_by: Optional[Sequence[Path]] = None
  sort_order: Optional[SortOrder] = SortOrder.DESC
  combine_columns: Optional[bool] = None


class SelectRowsResponse(BaseModel):
  """The response for the select rows endpoint."""
  rows: list[dict]
  total_num_rows: int


@router.get('/{namespace}/{dataset_name}/select_rows_download', response_model=None)
def select_rows_download(
    namespace: str, dataset_name: str, url_safe_options: str,
    user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> list[dict]:
  """Select rows from the dataset database and downloads them."""
  options = SelectRowsOptions.parse_raw(unquote(url_safe_options))
  return select_rows(namespace, dataset_name, options, user).rows


@router.post('/{namespace}/{dataset_name}/select_rows', response_model_exclude_none=True)
def select_rows(
    namespace: str, dataset_name: str, options: SelectRowsOptions,
    user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> SelectRowsResponse:
  """Select rows from the dataset database."""
  dataset = get_dataset(namespace, dataset_name)

  sanitized_filters = [
    PyFilter(path=normalize_path(f.path), op=f.op, value=f.value) for f in (options.filters or [])
  ]

  res = dataset.select_rows(
    columns=options.columns,
    searches=options.searches or [],
    filters=sanitized_filters,
    sort_by=options.sort_by,
    sort_order=options.sort_order,
    limit=options.limit,
    offset=options.offset,
    combine_columns=options.combine_columns or False,
    user=user)

  return SelectRowsResponse(rows=list(res), total_num_rows=res.total_num_rows)


@router.post('/{namespace}/{dataset_name}/select_rows_schema', response_model_exclude_none=True)
def select_rows_schema(namespace: str, dataset_name: str,
                       options: SelectRowsSchemaOptions) -> SelectRowsSchemaResult:
  """Select rows from the dataset database."""
  dataset = get_dataset(namespace, dataset_name)
  return dataset.select_rows_schema(
    columns=options.columns,
    searches=options.searches or [],
    sort_by=options.sort_by,
    sort_order=options.sort_order,
    combine_columns=options.combine_columns or False)


class SelectGroupsOptions(BaseModel):
  """The request for the select groups endpoint."""
  leaf_path: Path
  filters: Optional[Sequence[Filter]] = None
  sort_by: Optional[GroupsSortBy] = GroupsSortBy.COUNT
  sort_order: Optional[SortOrder] = SortOrder.DESC
  limit: Optional[int] = 100
  bins: Optional[list[Bin]] = None


@router.post('/{namespace}/{dataset_name}/select_groups')
def select_groups(namespace: str, dataset_name: str,
                  options: SelectGroupsOptions) -> SelectGroupsResult:
  """Select groups from the dataset database."""
  dataset = get_dataset(namespace, dataset_name)
  sanitized_filters = [
    PyFilter(path=normalize_path(f.path), op=f.op, value=f.value) for f in (options.filters or [])
  ]
  return dataset.select_groups(options.leaf_path, sanitized_filters, options.sort_by,
                               options.sort_order, options.limit, options.bins)


@router.get('/{namespace}/{dataset_name}/media')
def get_media(namespace: str, dataset_name: str, item_id: str, leaf_path: str) -> Response:
  """Get the media for the dataset."""
  dataset = get_dataset(namespace, dataset_name)
|
282 |
+
path = tuple(leaf_path.split('.'))
|
283 |
+
result = dataset.media(item_id, path)
|
284 |
+
# Return the response via HTTP.
|
285 |
+
return Response(content=result.data)
|
286 |
+
|
287 |
+
|
288 |
+
@router.get('/{namespace}/{dataset_name}/settings')
|
289 |
+
def get_settings(namespace: str, dataset_name: str) -> DatasetSettings:
|
290 |
+
"""Get the media for the dataset."""
|
291 |
+
dataset = get_dataset(namespace, dataset_name)
|
292 |
+
return dataset.settings()
|
293 |
+
|
294 |
+
|
295 |
+
@router.post('/{namespace}/{dataset_name}/settings', response_model_exclude_none=True)
|
296 |
+
def update_settings(namespace: str, dataset_name: str, settings: DatasetSettings) -> None:
|
297 |
+
"""Get the media for the dataset."""
|
298 |
+
if not get_user_access().dataset.compute_signals:
|
299 |
+
raise HTTPException(401, 'User does not have access to update the settings of this dataset.')
|
300 |
+
|
301 |
+
dataset = get_dataset(namespace, dataset_name)
|
302 |
+
dataset.update_settings(settings)
|
303 |
+
return None
|
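For reference, the select_rows endpoint above accepts a JSON body matching SelectRowsOptions. A minimal client-side sketch follows; the host, port, '/api/v1/datasets' route prefix, dataset name, and the 'equals' operator are assumptions, not taken from this diff.

import requests

# Assumed local server and dataset; adjust to your deployment.
BASE = 'http://localhost:5432/api/v1/datasets'

body = {
  'columns': ['text'],
  # 'equals' is a placeholder for a valid BinaryOp value.
  'filters': [{'path': ['label'], 'op': 'equals', 'value': 'spam'}],
  'sort_by': ['label'],
  'sort_order': 'DESC',
  'limit': 10,
}
res = requests.post(f'{BASE}/local/my_dataset/select_rows', json=body).json()
print(res['total_num_rows'], res['rows'])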
lilac/router_google_login.py
ADDED
@@ -0,0 +1,60 @@
"""Router for Google OAuth2 login."""

from urllib.parse import urlparse, urlunparse

from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import APIRouter, Request, Response
from fastapi.responses import HTMLResponse
from starlette.config import Config
from starlette.responses import RedirectResponse

from .auth import UserInfo
from .env import env
from .router_utils import RouteErrorHandler

router = APIRouter(route_class=RouteErrorHandler)

if env('LILAC_AUTH_ENABLED'):
  oauth = OAuth(
    Config(
      environ={
        'GOOGLE_CLIENT_ID': env('GOOGLE_CLIENT_ID'),
        'GOOGLE_CLIENT_SECRET': env('GOOGLE_CLIENT_SECRET')
      }))
  oauth.register(
    name='google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
  )


@router.get('/login')
async def login(request: Request, origin_url: str) -> RedirectResponse:
  """Redirects to the Google OAuth login page."""
  auth_path = urlunparse(urlparse(origin_url)._replace(path='/google/auth'))
  return await oauth.google.authorize_redirect(request, auth_path)


@router.get('/auth')
async def auth(request: Request) -> Response:
  """Handles the Google OAuth callback."""
  try:
    token = await oauth.google.authorize_access_token(request)
  except OAuthError as error:
    return HTMLResponse(f'<h1>{error}</h1>')
  userinfo = token['userinfo']
  request.session['user'] = UserInfo(
    id=str(userinfo['sub']),
    email=userinfo['email'],
    name=userinfo['name'],
    given_name=userinfo['given_name'],
    family_name=userinfo['family_name']).dict()

  return RedirectResponse(url='/')


@router.get('/logout')
def logout(request: Request) -> RedirectResponse:
  """Logs the user out."""
  request.session.pop('user', None)
  return RedirectResponse(url='/')
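The `request.session` reads and writes above require a session layer on the application. A minimal sketch of wiring this router into a FastAPI app with Starlette's SessionMiddleware; the secret key and '/google' mount prefix are placeholders, not taken from this diff.

from fastapi import FastAPI
from starlette.middleware.sessions import SessionMiddleware

from lilac.router_google_login import router as google_login_router

app = FastAPI()
# SessionMiddleware backs `request.session`; the secret key here is a placeholder.
app.add_middleware(SessionMiddleware, secret_key='change-me')
# The '/google' prefix matches the '/google/auth' callback path built in login().
app.include_router(google_login_router, prefix='/google')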
lilac/router_signal.py
ADDED
@@ -0,0 +1,105 @@
"""Router for the signal registry."""

import math
from typing import Annotated, Any, Optional

from fastapi import APIRouter, Depends
from pydantic import BaseModel, validator

from .auth import UserInfo, get_session_user
from .router_utils import RouteErrorHandler, server_compute_concept
from .schema import Field, SignalInputType
from .signals.concept_scorer import ConceptScoreSignal
from .signals.signal import SIGNAL_REGISTRY, Signal, TextEmbeddingSignal, resolve_signal

router = APIRouter(route_class=RouteErrorHandler)

EMBEDDING_SORT_PRIORITIES = ['gte-small', 'gte-base', 'openai', 'sbert']


class SignalInfo(BaseModel):
  """Information about a signal."""
  name: str
  input_type: SignalInputType
  json_schema: dict[str, Any]


@router.get('/', response_model_exclude_none=True)
def get_signals() -> list[SignalInfo]:
  """List the signals."""
  return [
    SignalInfo(name=s.name, input_type=s.input_type, json_schema=s.schema())
    for s in SIGNAL_REGISTRY.values()
    if not issubclass(s, TextEmbeddingSignal)
  ]


@router.get('/embeddings', response_model_exclude_none=True)
def get_embeddings() -> list[SignalInfo]:
  """List the embeddings."""
  embedding_infos = [
    SignalInfo(name=s.name, input_type=s.input_type, json_schema=s.schema())
    for s in SIGNAL_REGISTRY.values()
    if issubclass(s, TextEmbeddingSignal)
  ]

  # Sort the embedding infos by priority.
  embedding_infos = sorted(
    embedding_infos,
    key=lambda s: EMBEDDING_SORT_PRIORITIES.index(s.name)
    if s.name in EMBEDDING_SORT_PRIORITIES else math.inf)

  return embedding_infos


class SignalComputeOptions(BaseModel):
  """The request for the standalone compute signal endpoint."""
  signal: Signal
  # The inputs to compute.
  inputs: list[str]

  @validator('signal', pre=True)
  def parse_signal(cls, signal: dict) -> Signal:
    """Parse a signal to its specific subclass instance."""
    return resolve_signal(signal)


class SignalComputeResponse(BaseModel):
  """The response for the standalone compute signal endpoint."""
  items: list[Optional[Any]]


@router.post('/compute', response_model_exclude_none=True)
def compute(
    options: SignalComputeOptions,
    user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> SignalComputeResponse:
  """Compute a signal over a set of inputs."""
  signal = options.signal
  if isinstance(signal, ConceptScoreSignal):
    result = server_compute_concept(signal, options.inputs, user)
  else:
    signal.setup()
    result = list(signal.compute(options.inputs))
  return SignalComputeResponse(items=result)


class SignalSchemaOptions(BaseModel):
  """The request for the signal schema endpoint."""
  signal: Signal

  @validator('signal', pre=True)
  def parse_signal(cls, signal: dict) -> Signal:
    """Parse a signal to its specific subclass instance."""
    return resolve_signal(signal)


class SignalSchemaResponse(BaseModel):
  """The response for the signal schema endpoint."""
  fields: Field


@router.post('/schema', response_model_exclude_none=True)
def schema(options: SignalSchemaOptions) -> SignalSchemaResponse:
  """Get the schema for a signal."""
  signal = options.signal
  return SignalSchemaResponse(fields=signal.fields())
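The /compute endpoint parses the `signal` field into a concrete Signal subclass via resolve_signal, so a request body only needs the registered signal name plus its parameters. A hedged sketch of such a request; the server URL, '/api/v1/signals' prefix, 'signal_name' key, and 'lang_detection' signal are assumptions, not taken from this diff.

import requests

# Assumed mount point and signal; adjust to your deployment and registry.
payload = {
  'signal': {'signal_name': 'lang_detection'},
  'inputs': ['Hello world', 'Bonjour le monde'],
}
res = requests.post('http://localhost:5432/api/v1/signals/compute', json=payload)
print(res.json()['items'])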
lilac/router_tasks.py
ADDED
@@ -0,0 +1,14 @@
"""Router for tasks."""

from fastapi import APIRouter

from .router_utils import RouteErrorHandler
from .tasks import TaskManifest, task_manager

router = APIRouter(route_class=RouteErrorHandler)


@router.get('/')
async def get_task_manifest() -> TaskManifest:
  """Get the tasks, both completed and pending."""
  return await task_manager().manifest()
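A small sketch of polling this endpoint from a client; the server URL and '/api/v1/tasks' prefix are assumptions about how the router is mounted, since the router itself only defines GET '/'.

import requests

# Assumed mount point; adjust to your deployment.
manifest = requests.get('http://localhost:5432/api/v1/tasks/').json()
print(manifest)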
lilac/router_utils.py
ADDED
@@ -0,0 +1,54 @@
"""Utils for routers."""

import traceback
from typing import Callable, Iterable, Optional

from fastapi import HTTPException, Request, Response
from fastapi.routing import APIRoute

from .auth import UserInfo
from .concepts.db_concept import DISK_CONCEPT_DB, DISK_CONCEPT_MODEL_DB
from .schema import Item, RichData
from .signals.concept_scorer import ConceptScoreSignal


class RouteErrorHandler(APIRoute):
  """Custom APIRoute that handles application errors and exceptions."""

  def get_route_handler(self) -> Callable:
    """Get the route handler."""
    original_route_handler = super().get_route_handler()

    async def custom_route_handler(request: Request) -> Response:
      try:
        return await original_route_handler(request)
      except Exception as ex:
        if isinstance(ex, HTTPException):
          raise ex

        print('Route error:', request.url)
        print(ex)
        print(traceback.format_exc())

        # Wrap the error into a pretty 500 exception.
        raise HTTPException(status_code=500, detail=traceback.format_exc()) from ex

    return custom_route_handler


def server_compute_concept(signal: ConceptScoreSignal, examples: Iterable[RichData],
                           user: Optional[UserInfo]) -> list[Optional[Item]]:
  """Compute a concept from the REST endpoints."""
  # TODO(nsthorat): Move this to the setup() method in the concept_scorer.
  concept = DISK_CONCEPT_DB.get(signal.namespace, signal.concept_name, user)
  if not concept:
    raise HTTPException(
      status_code=404, detail=f'Concept "{signal.namespace}/{signal.concept_name}" was not found')
  model = DISK_CONCEPT_MODEL_DB.get(
    signal.namespace, signal.concept_name, signal.embedding, user=user)
  if model is None:
    model = DISK_CONCEPT_MODEL_DB.create(
      signal.namespace, signal.concept_name, signal.embedding, user=user)
  DISK_CONCEPT_MODEL_DB.sync(model, user)
  texts = [example or '' for example in examples]
  return list(signal.compute(texts))
lilac/schema.py
ADDED
@@ -0,0 +1,600 @@
"""Item: an individual entry in the dataset."""

import csv
import io
from collections import deque
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Union, cast

import numpy as np
import pyarrow as pa
from pydantic import BaseModel, StrictInt, StrictStr, validator
from typing_extensions import TypedDict

MANIFEST_FILENAME = 'manifest.json'
PARQUET_FILENAME_PREFIX = 'data'

# We choose `__rowid__` inspired by the standard `rowid` pseudocolumn in DBs:
# https://docs.oracle.com/cd/B19306_01/server.102/b14200/pseudocolumns008.htm
UUID_COLUMN = '__rowid__'
PATH_WILDCARD = '*'
VALUE_KEY = '__value__'
SIGNAL_METADATA_KEY = '__metadata__'
TEXT_SPAN_START_FEATURE = 'start'
TEXT_SPAN_END_FEATURE = 'end'

EMBEDDING_KEY = 'embedding'

# Python doesn't work with recursive types. These types provide some notion of type-safety.
Scalar = Union[bool, datetime, int, float, str, bytes]
Item = Any

# Contains a string field name, a wildcard for repeateds, or a specific integer index for
# repeateds. This path represents a path to a particular column.
# Examples:
#   ['article', 'field'] represents {'article': {'field': VALUES}}
#   ['article', '*', 'field'] represents {'article': [{'field': VALUES}, {'field': VALUES}]}
#   ['article', '0', 'field'] represents {'article': {'field': VALUES}}
PathTuple = tuple[StrictStr, ...]
Path = Union[PathTuple, StrictStr]

PathKeyedItem = tuple[Path, Item]

# These fields are for Python only and not written to a schema.
RichData = Union[str, bytes]
VectorKey = tuple[Union[StrictStr, StrictInt], ...]
PathKey = VectorKey


class DataType(str, Enum):
  """Enum holding the dtype for a field."""
  STRING = 'string'
  # Contains {start, end} offset integers with a reference_column.
  STRING_SPAN = 'string_span'
  BOOLEAN = 'boolean'

  # Ints.
  INT8 = 'int8'
  INT16 = 'int16'
  INT32 = 'int32'
  INT64 = 'int64'
  UINT8 = 'uint8'
  UINT16 = 'uint16'
  UINT32 = 'uint32'
  UINT64 = 'uint64'

  # Floats.
  FLOAT16 = 'float16'
  FLOAT32 = 'float32'
  FLOAT64 = 'float64'

  ### Time ###
  # Time of day (no time zone).
  TIME = 'time'
  # Calendar date (year, month, day), no time zone.
  DATE = 'date'
  # An "Instant" stored as number of microseconds (µs) since 1970-01-01 00:00:00+00 (UTC time zone).
  TIMESTAMP = 'timestamp'
  # Time span, stored as microseconds.
  INTERVAL = 'interval'

  BINARY = 'binary'

  EMBEDDING = 'embedding'

  NULL = 'null'

  def __repr__(self) -> str:
    return self.value


class SignalInputType(str, Enum):
  """Enum holding the signal input type."""
  TEXT = 'text'
  TEXT_EMBEDDING = 'text_embedding'
  IMAGE = 'image'

  def __repr__(self) -> str:
    return self.value


SIGNAL_TYPE_TO_VALID_DTYPES: dict[SignalInputType, list[DataType]] = {
  SignalInputType.TEXT: [DataType.STRING, DataType.STRING_SPAN],
  SignalInputType.IMAGE: [DataType.BINARY],
}


def signal_type_supports_dtype(input_type: SignalInputType, dtype: DataType) -> bool:
  """Returns True if the signal compute type supports the dtype."""
  return dtype in SIGNAL_TYPE_TO_VALID_DTYPES[input_type]


Bin = tuple[str, Optional[Union[float, int]], Optional[Union[float, int]]]


class Field(BaseModel):
  """Holds information for a field in the schema."""
  repeated_field: Optional['Field'] = None
  fields: Optional[dict[str, 'Field']] = None
  dtype: Optional[DataType] = None
  # Defined as the serialized signal when this field is the root result of a signal.
  signal: Optional[dict[str, Any]] = None
  # Maps a named bin to a tuple of (start, end) values.
  bins: Optional[list[Bin]] = None
  categorical: Optional[bool] = None

  @validator('fields')
  def either_fields_or_repeated_field_is_defined(
      cls, fields: Optional[dict[str, 'Field']],
      values: dict[str, Any]) -> Optional[dict[str, 'Field']]:
    """Error if both `fields` and `repeated_field` are defined."""
    if not fields:
      return fields
    if values.get('repeated_field'):
      raise ValueError('Both "fields" and "repeated_field" should not be defined')
    if VALUE_KEY in fields:
      raise ValueError(f'{VALUE_KEY} is a reserved field name.')
    return fields

  @validator('dtype', always=True)
  def infer_default_dtype(cls, dtype: Optional[DataType],
                          values: dict[str, Any]) -> Optional[DataType]:
    """Infers the default value for dtype if not explicitly provided."""
    if dtype and values.get('repeated_field'):
      raise ValueError('dtype and repeated_field cannot both be defined.')
    if not values.get('repeated_field') and not values.get('fields') and not dtype:
      raise ValueError('One of "fields", "repeated_field", or "dtype" should be defined')
    return dtype

  @validator('bins')
  def validate_bins(cls, bins: list[Bin]) -> list[Bin]:
    """Validate the bins."""
    if len(bins) < 2:
      raise ValueError('Please specify at least two bins.')
    _, first_start, _ = bins[0]
    if first_start is not None:
      raise ValueError('The first bin should have a `None` start value.')
    _, _, last_end = bins[-1]
    if last_end is not None:
      raise ValueError('The last bin should have a `None` end value.')
    for i, (_, start, _) in enumerate(bins):
      if i == 0:
        continue
      prev_bin = bins[i - 1]
      _, _, prev_end = prev_bin
      if start != prev_end:
        raise ValueError(
          f'Bin {i} start ({start}) should be equal to the previous bin end {prev_end}.')
    return bins

  @validator('categorical')
  def validate_categorical(cls, categorical: bool, values: dict[str, Any]) -> bool:
    """Validate the categorical field."""
    if categorical and is_float(values['dtype']):
      raise ValueError('Categorical fields cannot be float dtypes.')
    return categorical

  def __str__(self) -> str:
    return _str_field(self, indent=0)

  def __repr__(self) -> str:
    return f' {self.__class__.__name__}::{self.json(exclude_none=True, indent=2)}'


class Schema(BaseModel):
  """Database schema."""
  fields: dict[str, Field]
  # Cached leafs.
  _leafs: Optional[dict[PathTuple, Field]] = None

  class Config:
    arbitrary_types_allowed = True
    underscore_attrs_are_private = True

  @property
  def leafs(self) -> dict[PathTuple, Field]:
    """Return all the leaf fields in the schema. A leaf is defined as a node that contains a value.

    NOTE: Leafs may contain children. Leafs can be found as any node that has a dtype defined.
    """
    if self._leafs:
      return self._leafs
    result: dict[PathTuple, Field] = {}
    q: deque[tuple[PathTuple, Field]] = deque([((), Field(fields=self.fields))])
    while q:
      path, field = q.popleft()
      if field.dtype:
        # Nodes with dtypes act as leafs. They also may have children.
        result[path] = field
      if field.fields:
        for name, child_field in field.fields.items():
          child_path = (*path, name)
          q.append((child_path, child_field))
      elif field.repeated_field:
        child_path = (*path, PATH_WILDCARD)
        q.append((child_path, field.repeated_field))

    self._leafs = result
    return result

  def has_field(self, path: PathTuple) -> bool:
    """Returns if the field is found at the given path."""
    field = cast(Field, self)
    for path_part in path:
      if field.fields:
        field = cast(Field, field.fields.get(path_part))
        if not field:
          return False
      elif field.repeated_field:
        if path_part != PATH_WILDCARD:
          return False
        field = field.repeated_field
      else:
        return False
    return True

  def get_field(self, path: PathTuple) -> Field:
    """Returns the field at the given path."""
    field = cast(Field, self)
    for name in path:
      if field.fields:
        if name not in field.fields:
          raise ValueError(f'Path {path} not found in schema')
        field = field.fields[name]
      elif field.repeated_field:
        if name != PATH_WILDCARD:
          raise ValueError(f'Invalid path {path}')
        field = field.repeated_field
      else:
        raise ValueError(f'Invalid path {path}')
    return field

  def __str__(self) -> str:
    return _str_fields(self.fields, indent=0)

  def __repr__(self) -> str:
    return self.json(exclude_none=True, indent=2)


def schema(schema_like: object) -> Schema:
  """Parse a schema-like object to a Schema object."""
  field = _parse_field_like(schema_like)
  if not field.fields:
    raise ValueError('Schema must have fields')
  return Schema(fields=field.fields)


def field(
  dtype: Optional[Union[DataType, str]] = None,
  signal: Optional[dict] = None,
  fields: Optional[object] = None,
  bins: Optional[list[Bin]] = None,
  categorical: Optional[bool] = None,
) -> Field:
  """Parse a field-like object to a Field object."""
  field = _parse_field_like(fields or {}, dtype)
  if signal:
    field.signal = signal
  if dtype:
    if isinstance(dtype, str):
      dtype = DataType(dtype)
    field.dtype = dtype
  if bins:
    field.bins = bins
  if categorical is not None:
    field.categorical = categorical
  return field


class SpanVector(TypedDict):
  """A span with a vector."""
  span: tuple[int, int]
  vector: np.ndarray


def lilac_span(start: int, end: int, metadata: dict[str, Any] = {}) -> Item:
  """Creates a lilac span item, representing a pointer to a slice of text."""
  return {VALUE_KEY: {TEXT_SPAN_START_FEATURE: start, TEXT_SPAN_END_FEATURE: end}, **metadata}


def lilac_embedding(start: int, end: int, embedding: Optional[np.ndarray]) -> Item:
  """Creates a lilac embedding item, representing a vector with a pointer to a slice of text."""
  return lilac_span(start, end, {EMBEDDING_KEY: embedding})


def _parse_field_like(field_like: object, dtype: Optional[Union[DataType, str]] = None) -> Field:
  if isinstance(field_like, Field):
    return field_like
  elif isinstance(field_like, dict):
    fields: dict[str, Field] = {}
    for k, v in field_like.items():
      fields[k] = _parse_field_like(v)
    if isinstance(dtype, str):
      dtype = DataType(dtype)
    return Field(fields=fields or None, dtype=dtype)
  elif isinstance(field_like, str):
    return Field(dtype=DataType(field_like))
  elif isinstance(field_like, list):
    return Field(repeated_field=_parse_field_like(field_like[0], dtype=dtype))
  else:
    raise ValueError(f'Cannot parse field like: {field_like}')


def child_item_from_column_path(item: Item, path: Path) -> Item:
  """Return the last (child) item from a column path."""
  child_item_value = item
  for path_part in path:
    if path_part == PATH_WILDCARD:
      raise ValueError(
        'child_item_from_column_path cannot be called with a path that contains a repeated '
        f'wildcard: "{path}"')
    # path_part can either be an integer or a string for a dictionary, both of which we can
    # directly index with.
    child_path = int(path_part) if path_part.isdigit() else path_part
    child_item_value = child_item_value[child_path]
  return child_item_value


def column_paths_match(path_match: Path, specific_path: Path) -> bool:
  """Test whether two column paths match.

  Args:
    path_match: A column path that contains wildcards and sub-paths. This path will be used for
      testing the second specific path.
    specific_path: A column path that specifically identifies a field.

  Returns:
    Whether specific_path matches the path_match. This will only match when the paths are equal
    length. If a user wants to enrich everything with an array, they must use the path wildcard
    '*' in their path match.
  """
  if isinstance(path_match, str):
    path_match = (path_match,)
  if isinstance(specific_path, str):
    specific_path = (specific_path,)

  if len(path_match) != len(specific_path):
    return False

  for path_match_p, specific_path_p in zip(path_match, specific_path):
    if path_match_p == PATH_WILDCARD:
      continue

    if path_match_p != specific_path_p:
      return False

  return True


def normalize_path(path: Path) -> PathTuple:
  """Normalizes a dot-separated path, but ignores dots inside quotes, like regular SQL.

  Examples
    - 'a.b.c' will be parsed as ('a', 'b', 'c').
    - '"a.b".c' will be parsed as ('a.b', 'c').
    - '"a".b.c' will be parsed as ('a', 'b', 'c').
  """
  if isinstance(path, str):
    return tuple(next(csv.reader(io.StringIO(path), delimiter='.')))
  return path


class ImageInfo(BaseModel):
  """Info about an individual image."""
  path: Path


class SourceManifest(BaseModel):
  """The manifest that describes the dataset run, including schema and parquet files."""
  # List of parquet filepaths storing the data. The paths can be relative to `manifest.json`.
  files: list[str]
  # The data schema.
  data_schema: Schema

  # Image information for the dataset.
  images: Optional[list[ImageInfo]] = None


def _str_fields(fields: dict[str, Field], indent: int) -> str:
  prefix = ' ' * indent
  out: list[str] = []
  for name, field in fields.items():
    out.append(f'{prefix}{name}:{_str_field(field, indent=indent + 2)}')
  return '\n'.join(out)


def _str_field(field: Field, indent: int) -> str:
  if field.fields:
    prefix = '\n' if indent > 0 else ''
    return f'{prefix}{_str_fields(field.fields, indent)}'
  if field.repeated_field:
    return f' list({_str_field(field.repeated_field, indent)})'
  return f' {cast(DataType, field.dtype)}'


def dtype_to_arrow_schema(dtype: DataType) -> Union[pa.Schema, pa.DataType]:
  """Convert the dtype to an arrow dtype."""
  if dtype == DataType.STRING:
    return pa.string()
  elif dtype == DataType.BOOLEAN:
    return pa.bool_()
  elif dtype == DataType.FLOAT16:
    return pa.float16()
  elif dtype == DataType.FLOAT32:
    return pa.float32()
  elif dtype == DataType.FLOAT64:
    return pa.float64()
  elif dtype == DataType.INT8:
    return pa.int8()
  elif dtype == DataType.INT16:
    return pa.int16()
  elif dtype == DataType.INT32:
    return pa.int32()
  elif dtype == DataType.INT64:
    return pa.int64()
  elif dtype == DataType.UINT8:
    return pa.uint8()
  elif dtype == DataType.UINT16:
    return pa.uint16()
  elif dtype == DataType.UINT32:
    return pa.uint32()
  elif dtype == DataType.UINT64:
    return pa.uint64()
  elif dtype == DataType.BINARY:
    return pa.binary()
  elif dtype == DataType.TIME:
    return pa.time64()
  elif dtype == DataType.DATE:
    return pa.date64()
  elif dtype == DataType.TIMESTAMP:
    return pa.timestamp('us')
  elif dtype == DataType.INTERVAL:
    return pa.duration('us')
  elif dtype == DataType.EMBEDDING:
    # We reserve an empty column for embeddings in parquet files so they can be queried.
    # The values are *not* filled out. If parquet and duckdb support embeddings in the future, we
    # can set this dtype to the relevant pyarrow type.
    return pa.null()
  elif dtype == DataType.STRING_SPAN:
    return pa.struct({
      VALUE_KEY: pa.struct({
        TEXT_SPAN_START_FEATURE: pa.int32(),
        TEXT_SPAN_END_FEATURE: pa.int32()
      })
    })
  elif dtype == DataType.NULL:
    return pa.null()
  else:
    raise ValueError(f'Can not convert dtype "{dtype}" to arrow dtype')


def schema_to_arrow_schema(schema: Union[Schema, Field]) -> pa.Schema:
  """Convert our schema to arrow schema."""
  arrow_schema = cast(pa.Schema, _schema_to_arrow_schema_impl(schema))
  arrow_fields = {field.name: field.type for field in arrow_schema}
  return pa.schema(arrow_fields)


def _schema_to_arrow_schema_impl(schema: Union[Schema, Field]) -> Union[pa.Schema, pa.DataType]:
  """Convert a schema to an apache arrow schema."""
  if schema.fields:
    arrow_fields: dict[str, Union[pa.Schema, pa.DataType]] = {}
    for name, field in schema.fields.items():
      if name == UUID_COLUMN:
        arrow_schema = dtype_to_arrow_schema(cast(DataType, field.dtype))
      else:
        arrow_schema = _schema_to_arrow_schema_impl(field)
      arrow_fields[name] = arrow_schema

    if isinstance(schema, Schema):
      # Top-level schemas do not have __value__ fields.
      return pa.schema(arrow_fields)
    else:
      # When nodes have both dtype and children, we add __value__ alongside the fields.
      if schema.dtype:
        value_schema = dtype_to_arrow_schema(schema.dtype)
        if schema.dtype == DataType.STRING_SPAN:
          value_schema = value_schema[VALUE_KEY].type
        arrow_fields[VALUE_KEY] = value_schema

      return pa.struct(arrow_fields)

  field = cast(Field, schema)
  if field.repeated_field:
    return pa.list_(_schema_to_arrow_schema_impl(field.repeated_field))

  return dtype_to_arrow_schema(cast(DataType, field.dtype))


def arrow_dtype_to_dtype(arrow_dtype: pa.DataType) -> DataType:
  """Convert arrow dtype to our dtype."""
  # Ints.
  if arrow_dtype == pa.int8():
    return DataType.INT8
  elif arrow_dtype == pa.int16():
    return DataType.INT16
  elif arrow_dtype == pa.int32():
    return DataType.INT32
  elif arrow_dtype == pa.int64():
    return DataType.INT64
  elif arrow_dtype == pa.uint8():
    return DataType.UINT8
  elif arrow_dtype == pa.uint16():
    return DataType.UINT16
  elif arrow_dtype == pa.uint32():
    return DataType.UINT32
  elif arrow_dtype == pa.uint64():
    return DataType.UINT64
  # Floats.
  elif arrow_dtype == pa.float16():
    return DataType.FLOAT16
  elif arrow_dtype == pa.float32():
    return DataType.FLOAT32
  elif arrow_dtype == pa.float64():
    return DataType.FLOAT64
  # Time.
  elif pa.types.is_time(arrow_dtype):
    return DataType.TIME
  elif pa.types.is_date(arrow_dtype):
    return DataType.DATE
  elif pa.types.is_timestamp(arrow_dtype):
    return DataType.TIMESTAMP
  elif pa.types.is_duration(arrow_dtype):
    return DataType.INTERVAL
  # Others.
  elif arrow_dtype == pa.string():
    return DataType.STRING
  elif pa.types.is_binary(arrow_dtype) or pa.types.is_fixed_size_binary(arrow_dtype):
    return DataType.BINARY
  elif pa.types.is_boolean(arrow_dtype):
    return DataType.BOOLEAN
  elif arrow_dtype == pa.null():
    return DataType.NULL
  else:
    raise ValueError(f'Can not convert arrow dtype "{arrow_dtype}" to our dtype')


def arrow_schema_to_schema(schema: pa.Schema) -> Schema:
  """Convert arrow schema to our schema."""
  # TODO(nsthorat): Change this implementation to allow more complicated reading of arrow schemas
  # into our schema by inferring values when {__value__: value} is present in the pyarrow schema.
  # This isn't necessary today as this util is only needed by sources which do not have data in the
  # lilac format.
  return cast(Schema, _arrow_schema_to_schema_impl(schema))


def _arrow_schema_to_schema_impl(schema: Union[pa.Schema, pa.DataType]) -> Union[Schema, Field]:
  """Convert an apache arrow schema to our schema."""
  if isinstance(schema, (pa.Schema, pa.StructType)):
    fields: dict[str, Field] = {
      field.name: cast(Field, _arrow_schema_to_schema_impl(field.type)) for field in schema
    }
    return Schema(fields=fields) if isinstance(schema, pa.Schema) else Field(fields=fields)
  elif isinstance(schema, pa.ListType):
    return Field(repeated_field=cast(Field, _arrow_schema_to_schema_impl(schema.value_field.type)))
  else:
    return Field(dtype=arrow_dtype_to_dtype(schema))


def is_float(dtype: DataType) -> bool:
  """Check if a dtype is a float dtype."""
  return dtype in [DataType.FLOAT16, DataType.FLOAT32, DataType.FLOAT64]


def is_integer(dtype: DataType) -> bool:
  """Check if a dtype is an integer dtype."""
  return dtype in [
    DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64, DataType.UINT8, DataType.UINT16,
    DataType.UINT32, DataType.UINT64
  ]


def is_temporal(dtype: DataType) -> bool:
  """Check if a dtype is a temporal dtype."""
  return dtype in [DataType.TIME, DataType.DATE, DataType.TIMESTAMP, DataType.INTERVAL]


def is_ordinal(dtype: DataType) -> bool:
  """Check if a dtype is an ordinal dtype."""
  return is_float(dtype) or is_integer(dtype) or is_temporal(dtype)
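As a usage sketch, the schema() and field() helpers above accept nested dicts, strings, and lists and normalize them into Schema/Field objects. The field names below are made up for illustration, and the example assumes the package is importable as `lilac`.

from lilac.schema import DataType, field, schema, schema_to_arrow_schema

my_schema = schema({
  'text': 'string',  # Shorthand for field(dtype='string').
  'metadata': {
    'num_tokens': field(dtype=DataType.INT32),
    'tags': ['string'],  # A repeated (list) field.
  },
})

# Leafs are keyed by path tuples, e.g. ('metadata', 'tags', '*').
print(list(my_schema.leafs.keys()))

# Convert to a pyarrow schema, e.g. for writing parquet files.
print(schema_to_arrow_schema(my_schema))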