Spaces:
Running
Running
added hf inference api
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- duckdb-nsql/eval/get_manifest.py +1 -1
- duckdb-nsql/eval/predict.py +2 -2
- duckdb-nsql/manifest/.flake8 +0 -11
- duckdb-nsql/manifest/.pre-commit-config.yaml +0 -23
- duckdb-nsql/manifest/CHANGELOG.rst +0 -93
- duckdb-nsql/manifest/LICENSE +0 -201
- duckdb-nsql/manifest/Makefile +0 -27
- duckdb-nsql/manifest/README.md +0 -304
- duckdb-nsql/manifest/examples/langchain_chatgpt.ipynb +0 -455
- duckdb-nsql/manifest/examples/manifest_async.py +0 -27
- duckdb-nsql/manifest/examples/manifest_azure.ipynb +0 -149
- duckdb-nsql/manifest/examples/manifest_chatgpt.ipynb +0 -101
- duckdb-nsql/manifest/examples/manifest_connection_pool.ipynb +0 -208
- duckdb-nsql/manifest/examples/manifest_diffusers.ipynb +0 -0
- duckdb-nsql/manifest/examples/manifest_embedding.ipynb +0 -156
- duckdb-nsql/manifest/examples/manifest_google.ipynb +0 -117
- duckdb-nsql/manifest/examples/manifest_openrouter.ipynb +0 -108
- duckdb-nsql/manifest/examples/manifest_streaming.ipynb +0 -105
- duckdb-nsql/manifest/examples/manifest_together.ipynb +0 -106
- duckdb-nsql/manifest/manifest/__init__.py +0 -6
- duckdb-nsql/manifest/manifest/api/__init__.py +0 -1
- duckdb-nsql/manifest/manifest/api/app.py +0 -301
- duckdb-nsql/manifest/manifest/api/models/__init__.py +0 -1
- duckdb-nsql/manifest/manifest/api/models/diffuser.py +0 -123
- duckdb-nsql/manifest/manifest/api/models/huggingface.py +0 -671
- duckdb-nsql/manifest/manifest/api/models/model.py +0 -91
- duckdb-nsql/manifest/manifest/api/models/sentence_transformer.py +0 -113
- duckdb-nsql/manifest/manifest/api/response.py +0 -55
- duckdb-nsql/manifest/manifest/caches/__init__.py +0 -1
- duckdb-nsql/manifest/manifest/caches/array_cache.py +0 -116
- duckdb-nsql/manifest/manifest/caches/cache.py +0 -135
- duckdb-nsql/manifest/manifest/caches/noop.py +0 -47
- duckdb-nsql/manifest/manifest/caches/postgres.py +0 -131
- duckdb-nsql/manifest/manifest/caches/redis.py +0 -64
- duckdb-nsql/manifest/manifest/caches/serializers.py +0 -204
- duckdb-nsql/manifest/manifest/caches/sqlite.py +0 -65
- duckdb-nsql/manifest/manifest/clients/__init__.py +0 -1
- duckdb-nsql/manifest/manifest/clients/ai21.py +0 -125
- duckdb-nsql/manifest/manifest/clients/azureendpoint.py +0 -139
- duckdb-nsql/manifest/manifest/clients/azureopenai.py +0 -113
- duckdb-nsql/manifest/manifest/clients/azureopenai_chat.py +0 -116
- duckdb-nsql/manifest/manifest/clients/client.py +0 -699
- duckdb-nsql/manifest/manifest/clients/cohere.py +0 -125
- duckdb-nsql/manifest/manifest/clients/diffuser.py +0 -112
- duckdb-nsql/manifest/manifest/clients/dummy.py +0 -251
- duckdb-nsql/manifest/manifest/clients/google.py +0 -197
- duckdb-nsql/manifest/manifest/clients/google_chat.py +0 -155
- duckdb-nsql/manifest/manifest/clients/huggingface.py +0 -137
- duckdb-nsql/manifest/manifest/clients/huggingface_embedding.py +0 -98
- duckdb-nsql/manifest/manifest/clients/openai.py +0 -162
duckdb-nsql/eval/get_manifest.py
CHANGED
@@ -9,7 +9,7 @@ def get_manifest(
|
|
9 |
manifest_engine: str,
|
10 |
) -> Manifest:
|
11 |
"""Get manifest engine."""
|
12 |
-
if manifest_client in {"openai", "openaichat", "openai_mock", "openrouter", "azureendpoint"}:
|
13 |
manifest = Manifest(
|
14 |
client_name=manifest_client,
|
15 |
engine=manifest_engine,
|
|
|
9 |
manifest_engine: str,
|
10 |
) -> Manifest:
|
11 |
"""Get manifest engine."""
|
12 |
+
if manifest_client in {"openai", "openaichat", "openai_mock", "openrouter", "azureendpoint", "inference_api"}:
|
13 |
manifest = Manifest(
|
14 |
client_name=manifest_client,
|
15 |
engine=manifest_engine,
|
duckdb-nsql/eval/predict.py
CHANGED
@@ -213,7 +213,7 @@ def predict(
|
|
213 |
console.print(f"Running with {manifest_params} manifest.")
|
214 |
model_name = manifest_params.get("engine", manifest_params["model_name"])
|
215 |
|
216 |
-
if manifest_client in {"openai", "openaichat", "openrouter", "azureendpoint"}:
|
217 |
tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
|
218 |
else:
|
219 |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
@@ -234,7 +234,7 @@ def predict(
|
|
234 |
middleix = manifest_engine
|
235 |
elif manifest_client in {"huggingface", "ray"}:
|
236 |
middleix = Path(manifest_params.get("model_path", "")).name.replace("/", "-")
|
237 |
-
elif manifest_client in {"toma", "openrouter", "openaichat", "azureendpoint"}:
|
238 |
middleix = manifest_engine.split("/")[-1]
|
239 |
else:
|
240 |
raise ValueError(f"Unknown manifest client {manifest_client}")
|
|
|
213 |
console.print(f"Running with {manifest_params} manifest.")
|
214 |
model_name = manifest_params.get("engine", manifest_params["model_name"])
|
215 |
|
216 |
+
if manifest_client in {"openai", "openaichat", "openrouter", "azureendpoint", "inference_api"}:
|
217 |
tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
|
218 |
else:
|
219 |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
|
|
234 |
middleix = manifest_engine
|
235 |
elif manifest_client in {"huggingface", "ray"}:
|
236 |
middleix = Path(manifest_params.get("model_path", "")).name.replace("/", "-")
|
237 |
+
elif manifest_client in {"toma", "openrouter", "openaichat", "azureendpoint", "inference_api"}:
|
238 |
middleix = manifest_engine.split("/")[-1]
|
239 |
else:
|
240 |
raise ValueError(f"Unknown manifest client {manifest_client}")
|
duckdb-nsql/manifest/.flake8
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
# This is our code-style check. We currently allow the following exceptions:
|
2 |
-
# - E731: do not assign a lambda expression, use a def
|
3 |
-
# - E402: module level import not at top of file
|
4 |
-
# - W503: line break before binary operator
|
5 |
-
# - E203: whitespace before :
|
6 |
-
|
7 |
-
[flake8]
|
8 |
-
exclude = .git
|
9 |
-
max-line-length = 88
|
10 |
-
ignore = E731, E402, W503, E203, PAI100, PAI101, PAI201, PAI202, PAI203
|
11 |
-
per-file-ignores = __init__.py:F401, version.py:D100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/.pre-commit-config.yaml
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
repos:
|
2 |
-
- repo: https://github.com/pre-commit/pre-commit-hooks
|
3 |
-
rev: v3.2.0
|
4 |
-
hooks:
|
5 |
-
- id: trailing-whitespace
|
6 |
-
- id: end-of-file-fixer
|
7 |
-
- id: check-yaml
|
8 |
-
- id: check-toml
|
9 |
-
- id: check-merge-conflict
|
10 |
-
- id: check-added-large-files
|
11 |
-
- repo: https://github.com/timothycrosley/isort
|
12 |
-
rev: 5.13.2
|
13 |
-
hooks:
|
14 |
-
- id: isort
|
15 |
-
- repo: https://github.com/psf/black
|
16 |
-
rev: 22.3.0
|
17 |
-
hooks:
|
18 |
-
- id: black
|
19 |
-
language_version: python3
|
20 |
-
- repo: https://github.com/PyCQA/flake8
|
21 |
-
rev: 6.0.0
|
22 |
-
hooks:
|
23 |
-
- id: flake8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/CHANGELOG.rst
DELETED
@@ -1,93 +0,0 @@
|
|
1 |
-
0.1.10 - Unreleased
|
2 |
-
---------------------
|
3 |
-
|
4 |
-
0.1.9 - 2024-01-22
|
5 |
-
---------------------
|
6 |
-
Fixed
|
7 |
-
^^^^^
|
8 |
-
* Added trust code params HF models
|
9 |
-
* Added LRU cache to HF model param calls to avoid extra calls
|
10 |
-
* Fixed pydantic type issue HF model return
|
11 |
-
* Support for Python 3.10-3.11
|
12 |
-
|
13 |
-
0.1.8 - 2023-05-22
|
14 |
-
---------------------
|
15 |
-
Added
|
16 |
-
^^^^^
|
17 |
-
* Azure model support (completion and chat)
|
18 |
-
* Google Vertex API model support (completion and chat)
|
19 |
-
* Streaming responses for LM Completions (set stream=True)
|
20 |
-
|
21 |
-
Fixed
|
22 |
-
^^^^^
|
23 |
-
* `run` with batches now acts the same as async run except not async. We will batch requests into appropriate batchs sizes.
|
24 |
-
* Refactored client so unified preprocess and postprocess of requests and responses to better support model variants in request/response format.
|
25 |
-
|
26 |
-
0.1.7 - 2023-05-17
|
27 |
-
---------------------
|
28 |
-
Fixed
|
29 |
-
^^^^^
|
30 |
-
* `_run_chat` fixed bug where not passing in kwargs
|
31 |
-
|
32 |
-
0.1.6 - 2023-05-16
|
33 |
-
---------------------
|
34 |
-
Fixed
|
35 |
-
^^^^^
|
36 |
-
* Unified `run` and `run_chat` methods so it's just `run` now.
|
37 |
-
* LLama HF models for eval
|
38 |
-
|
39 |
-
0.1.5 - 2023-05-03
|
40 |
-
---------------------
|
41 |
-
Added
|
42 |
-
^^^^^
|
43 |
-
* Added chat input for chat models.
|
44 |
-
|
45 |
-
0.1.4 - 2023-04-24
|
46 |
-
---------------------
|
47 |
-
Added
|
48 |
-
^^^^^
|
49 |
-
* Connection pools to swap between clients
|
50 |
-
* Chunksize param for async runs
|
51 |
-
|
52 |
-
Fixed
|
53 |
-
^^^^^
|
54 |
-
* Determine cache and response by request type, not client name
|
55 |
-
* Refactor Response to use Pydantic types for Request and Response
|
56 |
-
|
57 |
-
0.1.1
|
58 |
-
---------------------
|
59 |
-
Added
|
60 |
-
^^^^^
|
61 |
-
* Async support in arun_batch
|
62 |
-
|
63 |
-
Fixed
|
64 |
-
^^^^^
|
65 |
-
* Batched runs now caches individual items
|
66 |
-
* Score prompt does not truncate outside token
|
67 |
-
|
68 |
-
Removed
|
69 |
-
^^^^^
|
70 |
-
* Deprecated chatGPT in favor of openaichat which uses OpenAI completions
|
71 |
-
* Deprecated Sessions
|
72 |
-
|
73 |
-
0.1.0 - 2022-01-31
|
74 |
-
---------------------
|
75 |
-
Added
|
76 |
-
^^^^^
|
77 |
-
* Batched inference support in `manifest.run`. No more separate `manifest.run_batch` method.
|
78 |
-
* Standard request base model for all language inputs.
|
79 |
-
* ChatGPT client. Requires CHATGPT_SESSION_KEY to be passed in.
|
80 |
-
* Diffusion model support
|
81 |
-
* Together model support
|
82 |
-
|
83 |
-
Removed
|
84 |
-
^^^^^^^
|
85 |
-
* `Prompt` class
|
86 |
-
* `OPT` client - OPT is now available in HuggingFace
|
87 |
-
|
88 |
-
0.0.1 - 2022-11-08
|
89 |
-
-------------------
|
90 |
-
First major pip release of Manifest. Install via `pip install manifest-ml`.
|
91 |
-
|
92 |
-
|
93 |
-
.. _@lorr1: https://github.com/lorr1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/LICENSE
DELETED
@@ -1,201 +0,0 @@
|
|
1 |
-
Apache License
|
2 |
-
Version 2.0, January 2004
|
3 |
-
http://www.apache.org/licenses/
|
4 |
-
|
5 |
-
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
-
|
7 |
-
1. Definitions.
|
8 |
-
|
9 |
-
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
-
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
-
|
12 |
-
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
-
the copyright owner that is granting the License.
|
14 |
-
|
15 |
-
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
-
other entities that control, are controlled by, or are under common
|
17 |
-
control with that entity. For the purposes of this definition,
|
18 |
-
"control" means (i) the power, direct or indirect, to cause the
|
19 |
-
direction or management of such entity, whether by contract or
|
20 |
-
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
-
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
-
|
23 |
-
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
-
exercising permissions granted by this License.
|
25 |
-
|
26 |
-
"Source" form shall mean the preferred form for making modifications,
|
27 |
-
including but not limited to software source code, documentation
|
28 |
-
source, and configuration files.
|
29 |
-
|
30 |
-
"Object" form shall mean any form resulting from mechanical
|
31 |
-
transformation or translation of a Source form, including but
|
32 |
-
not limited to compiled object code, generated documentation,
|
33 |
-
and conversions to other media types.
|
34 |
-
|
35 |
-
"Work" shall mean the work of authorship, whether in Source or
|
36 |
-
Object form, made available under the License, as indicated by a
|
37 |
-
copyright notice that is included in or attached to the work
|
38 |
-
(an example is provided in the Appendix below).
|
39 |
-
|
40 |
-
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
-
form, that is based on (or derived from) the Work and for which the
|
42 |
-
editorial revisions, annotations, elaborations, or other modifications
|
43 |
-
represent, as a whole, an original work of authorship. For the purposes
|
44 |
-
of this License, Derivative Works shall not include works that remain
|
45 |
-
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
-
the Work and Derivative Works thereof.
|
47 |
-
|
48 |
-
"Contribution" shall mean any work of authorship, including
|
49 |
-
the original version of the Work and any modifications or additions
|
50 |
-
to that Work or Derivative Works thereof, that is intentionally
|
51 |
-
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
-
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
-
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
-
means any form of electronic, verbal, or written communication sent
|
55 |
-
to the Licensor or its representatives, including but not limited to
|
56 |
-
communication on electronic mailing lists, source code control systems,
|
57 |
-
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
-
Licensor for the purpose of discussing and improving the Work, but
|
59 |
-
excluding communication that is conspicuously marked or otherwise
|
60 |
-
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
-
|
62 |
-
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
-
on behalf of whom a Contribution has been received by Licensor and
|
64 |
-
subsequently incorporated within the Work.
|
65 |
-
|
66 |
-
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
-
this License, each Contributor hereby grants to You a perpetual,
|
68 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
-
copyright license to reproduce, prepare Derivative Works of,
|
70 |
-
publicly display, publicly perform, sublicense, and distribute the
|
71 |
-
Work and such Derivative Works in Source or Object form.
|
72 |
-
|
73 |
-
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
-
this License, each Contributor hereby grants to You a perpetual,
|
75 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
-
(except as stated in this section) patent license to make, have made,
|
77 |
-
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
-
where such license applies only to those patent claims licensable
|
79 |
-
by such Contributor that are necessarily infringed by their
|
80 |
-
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
-
with the Work to which such Contribution(s) was submitted. If You
|
82 |
-
institute patent litigation against any entity (including a
|
83 |
-
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
-
or a Contribution incorporated within the Work constitutes direct
|
85 |
-
or contributory patent infringement, then any patent licenses
|
86 |
-
granted to You under this License for that Work shall terminate
|
87 |
-
as of the date such litigation is filed.
|
88 |
-
|
89 |
-
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
-
Work or Derivative Works thereof in any medium, with or without
|
91 |
-
modifications, and in Source or Object form, provided that You
|
92 |
-
meet the following conditions:
|
93 |
-
|
94 |
-
(a) You must give any other recipients of the Work or
|
95 |
-
Derivative Works a copy of this License; and
|
96 |
-
|
97 |
-
(b) You must cause any modified files to carry prominent notices
|
98 |
-
stating that You changed the files; and
|
99 |
-
|
100 |
-
(c) You must retain, in the Source form of any Derivative Works
|
101 |
-
that You distribute, all copyright, patent, trademark, and
|
102 |
-
attribution notices from the Source form of the Work,
|
103 |
-
excluding those notices that do not pertain to any part of
|
104 |
-
the Derivative Works; and
|
105 |
-
|
106 |
-
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
-
distribution, then any Derivative Works that You distribute must
|
108 |
-
include a readable copy of the attribution notices contained
|
109 |
-
within such NOTICE file, excluding those notices that do not
|
110 |
-
pertain to any part of the Derivative Works, in at least one
|
111 |
-
of the following places: within a NOTICE text file distributed
|
112 |
-
as part of the Derivative Works; within the Source form or
|
113 |
-
documentation, if provided along with the Derivative Works; or,
|
114 |
-
within a display generated by the Derivative Works, if and
|
115 |
-
wherever such third-party notices normally appear. The contents
|
116 |
-
of the NOTICE file are for informational purposes only and
|
117 |
-
do not modify the License. You may add Your own attribution
|
118 |
-
notices within Derivative Works that You distribute, alongside
|
119 |
-
or as an addendum to the NOTICE text from the Work, provided
|
120 |
-
that such additional attribution notices cannot be construed
|
121 |
-
as modifying the License.
|
122 |
-
|
123 |
-
You may add Your own copyright statement to Your modifications and
|
124 |
-
may provide additional or different license terms and conditions
|
125 |
-
for use, reproduction, or distribution of Your modifications, or
|
126 |
-
for any such Derivative Works as a whole, provided Your use,
|
127 |
-
reproduction, and distribution of the Work otherwise complies with
|
128 |
-
the conditions stated in this License.
|
129 |
-
|
130 |
-
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
-
any Contribution intentionally submitted for inclusion in the Work
|
132 |
-
by You to the Licensor shall be under the terms and conditions of
|
133 |
-
this License, without any additional terms or conditions.
|
134 |
-
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
-
the terms of any separate license agreement you may have executed
|
136 |
-
with Licensor regarding such Contributions.
|
137 |
-
|
138 |
-
6. Trademarks. This License does not grant permission to use the trade
|
139 |
-
names, trademarks, service marks, or product names of the Licensor,
|
140 |
-
except as required for reasonable and customary use in describing the
|
141 |
-
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
-
|
143 |
-
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
-
agreed to in writing, Licensor provides the Work (and each
|
145 |
-
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
-
implied, including, without limitation, any warranties or conditions
|
148 |
-
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
-
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
-
appropriateness of using or redistributing the Work and assume any
|
151 |
-
risks associated with Your exercise of permissions under this License.
|
152 |
-
|
153 |
-
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
-
whether in tort (including negligence), contract, or otherwise,
|
155 |
-
unless required by applicable law (such as deliberate and grossly
|
156 |
-
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
-
liable to You for damages, including any direct, indirect, special,
|
158 |
-
incidental, or consequential damages of any character arising as a
|
159 |
-
result of this License or out of the use or inability to use the
|
160 |
-
Work (including but not limited to damages for loss of goodwill,
|
161 |
-
work stoppage, computer failure or malfunction, or any and all
|
162 |
-
other commercial damages or losses), even if such Contributor
|
163 |
-
has been advised of the possibility of such damages.
|
164 |
-
|
165 |
-
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
-
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
-
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
-
or other liability obligations and/or rights consistent with this
|
169 |
-
License. However, in accepting such obligations, You may act only
|
170 |
-
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
-
of any other Contributor, and only if You agree to indemnify,
|
172 |
-
defend, and hold each Contributor harmless for any liability
|
173 |
-
incurred by, or claims asserted against, such Contributor by reason
|
174 |
-
of your accepting any such warranty or additional liability.
|
175 |
-
|
176 |
-
END OF TERMS AND CONDITIONS
|
177 |
-
|
178 |
-
APPENDIX: How to apply the Apache License to your work.
|
179 |
-
|
180 |
-
To apply the Apache License to your work, attach the following
|
181 |
-
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
-
replaced with your own identifying information. (Don't include
|
183 |
-
the brackets!) The text should be enclosed in the appropriate
|
184 |
-
comment syntax for the file format. We also recommend that a
|
185 |
-
file or class name and description of purpose be included on the
|
186 |
-
same "printed page" as the copyright notice for easier
|
187 |
-
identification within third-party archives.
|
188 |
-
|
189 |
-
Copyright [yyyy] [name of copyright owner]
|
190 |
-
|
191 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
-
you may not use this file except in compliance with the License.
|
193 |
-
You may obtain a copy of the License at
|
194 |
-
|
195 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
-
|
197 |
-
Unless required by applicable law or agreed to in writing, software
|
198 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
-
See the License for the specific language governing permissions and
|
201 |
-
limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/Makefile
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
dev:
|
2 |
-
pip install -e .[all]
|
3 |
-
pre-commit install
|
4 |
-
|
5 |
-
test: dev check
|
6 |
-
pytest tests
|
7 |
-
|
8 |
-
format:
|
9 |
-
isort --atomic manifest/ tests/ web_app/
|
10 |
-
black manifest/ tests/ web_app/
|
11 |
-
|
12 |
-
check:
|
13 |
-
isort -c manifest/ tests/ web_app/
|
14 |
-
black manifest/ tests/ web_app/ --check
|
15 |
-
flake8 manifest/ tests/ web_app/
|
16 |
-
mypy manifest/ tests/ web_app/
|
17 |
-
|
18 |
-
clean:
|
19 |
-
pip uninstall -y manifest
|
20 |
-
rm -rf src/manifest.egg-info
|
21 |
-
rm -rf build/ dist/
|
22 |
-
|
23 |
-
prune:
|
24 |
-
@bash -c "git fetch -p";
|
25 |
-
@bash -c "for branch in $(git branch -vv | grep ': gone]' | awk '{print $1}'); do git branch -d $branch; done";
|
26 |
-
|
27 |
-
.PHONY: dev test clean check prune
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/README.md
DELETED
@@ -1,304 +0,0 @@
|
|
1 |
-
# Manifest
|
2 |
-
How to make prompt programming with Foundation Models a little easier.
|
3 |
-
|
4 |
-
|
5 |
-
# Table of Contents
|
6 |
-
- [Install](#install)
|
7 |
-
- [Getting Started](#getting-started)
|
8 |
-
- [Manifest](#manifest-components)
|
9 |
-
- [Other Models Types](#other-models)
|
10 |
-
- [Local HuggingFace Models](#local-huggingface-models)
|
11 |
-
- [Chat Models](#chat-models)
|
12 |
-
- [Embedding Models](#embedding-models)
|
13 |
-
- [Road Map](#road-map)
|
14 |
-
- [Development](#development)
|
15 |
-
- [Cite](#cite)
|
16 |
-
|
17 |
-
|
18 |
-
# Install
|
19 |
-
Install:
|
20 |
-
```bash
|
21 |
-
pip install manifest-ml
|
22 |
-
```
|
23 |
-
|
24 |
-
Install with diffusion support:
|
25 |
-
```bash
|
26 |
-
pip install manifest-ml[diffusers]
|
27 |
-
```
|
28 |
-
|
29 |
-
Install with HuggingFace local model support:
|
30 |
-
```bash
|
31 |
-
pip install manifest-ml[api]
|
32 |
-
```
|
33 |
-
|
34 |
-
Dev Install:
|
35 |
-
```bash
|
36 |
-
git clone [email protected]:HazyResearch/manifest.git
|
37 |
-
cd manifest
|
38 |
-
make dev
|
39 |
-
```
|
40 |
-
|
41 |
-
# Getting Started
|
42 |
-
Running is simple to get started. If using OpenAI, set `export OPENAI_API_KEY=<OPENAIKEY>` (or pass key in through variable `client_connection`) then run
|
43 |
-
|
44 |
-
```python
|
45 |
-
from manifest import Manifest
|
46 |
-
|
47 |
-
# Start a manifest session to OpenAI - default `engine=text-davinci-003`
|
48 |
-
manifest = Manifest(
|
49 |
-
client_name = "openai",
|
50 |
-
)
|
51 |
-
manifest.run("Why is the grass green?")
|
52 |
-
```
|
53 |
-
|
54 |
-
## Examples
|
55 |
-
We have example notebook and python scripts located at [examples](examples). These show how to use different models, model types (i.e. text, diffusers, or embedding models), and async running.
|
56 |
-
|
57 |
-
# Manifest Components
|
58 |
-
Manifest is meant to be a very light weight package to help with prompt design and iteration. Three key design decisions of Manifest are
|
59 |
-
|
60 |
-
* All models are behind APIs
|
61 |
-
* Supports caching of model inputs/outputs for iteration, reproducibility, and cost saving
|
62 |
-
* Unified API to support generate, score, and embed
|
63 |
-
|
64 |
-
## Models
|
65 |
-
Manifest provides model clients for [OpenAI](https://openai.com/), [AI21](https://studio.ai21.com/), [Cohere](https://cohere.ai/), [Together](https://together.xyz/), and HuggingFace (see [below](#huggingface-models) for how to use locally hosted HuggingFace models). You can toggle between the models by changing `client_name` and `client_connection`. For example, if a HuggingFace model is loaded locally, run
|
66 |
-
```python
|
67 |
-
manifest = Manifest(
|
68 |
-
client_name = "huggingface",
|
69 |
-
client_connection = "http://127.0.0.1:5000",
|
70 |
-
)
|
71 |
-
```
|
72 |
-
If you want to use Cohere, run
|
73 |
-
```python
|
74 |
-
manifest = Manifest(
|
75 |
-
client_name = "cohere",
|
76 |
-
client_connection = <COHERE_API_KEY>,
|
77 |
-
)
|
78 |
-
```
|
79 |
-
You can also just set `export COHERE_API_KEY=<COHERE_API_KEY>` and not use `client_connection`.
|
80 |
-
|
81 |
-
If you want to use AI21 Labs, run
|
82 |
-
```python
|
83 |
-
manifest = Manifest(
|
84 |
-
client_name = "ai21",
|
85 |
-
client_connection = <AI21_API_KEY>,
|
86 |
-
)
|
87 |
-
```
|
88 |
-
|
89 |
-
You can see the model details and possible model inputs to `run()` via
|
90 |
-
```python
|
91 |
-
print(manifest.client_pool.get_current_client().get_model_params())
|
92 |
-
print(manifest.client_pool.get_current_client().get_model_inputs())
|
93 |
-
```
|
94 |
-
|
95 |
-
## Global Cache
|
96 |
-
We support having queries and results stored in a global cache that can be shared across users. We treat inputs and outputs as key value pairs and support SQLite or Redis backends. To start with global caching using SQLite, run
|
97 |
-
|
98 |
-
```python
|
99 |
-
manifest = Manifest(
|
100 |
-
client_name = "openai",
|
101 |
-
cache_name = "sqlite",
|
102 |
-
cache_connection = "mycache.sqlite",
|
103 |
-
)
|
104 |
-
```
|
105 |
-
The cache will be saved in `mycache.sqlite`.
|
106 |
-
|
107 |
-
We also support Redis backend.
|
108 |
-
```python
|
109 |
-
manifest = Manifest(
|
110 |
-
client_name = "openai",
|
111 |
-
cache_name = "redis",
|
112 |
-
cache_connection = "localhost:6379"
|
113 |
-
)
|
114 |
-
```
|
115 |
-
As a hint, if you want to get Redis running, see the `docker run` command below under development.
|
116 |
-
|
117 |
-
## Running Queries
|
118 |
-
Once you have a session open, you can write and develop prompts.
|
119 |
-
|
120 |
-
```python
|
121 |
-
result = manifest.run("Hello, my name is Laurel")
|
122 |
-
```
|
123 |
-
|
124 |
-
You can also run over multiple examples if supported by the client.
|
125 |
-
```python
|
126 |
-
results = manifest.run(["Where are the cats?", "Where are the dogs?"])
|
127 |
-
```
|
128 |
-
|
129 |
-
We support async queries as well via
|
130 |
-
```python
|
131 |
-
import asyncio
|
132 |
-
results = asyncio.run(manifest.arun_batch(["Where are the cats?", "Where are the dogs?"]))
|
133 |
-
```
|
134 |
-
|
135 |
-
If something doesn't go right, you can also ask to get a raw manifest Response.
|
136 |
-
```python
|
137 |
-
result_object = manifest.run(["Where are the cats?", "Where are the dogs?"], return_response=True)
|
138 |
-
print(result_object.get_request_obj())
|
139 |
-
print(result_object.is_cached())
|
140 |
-
print(result_object.get_response_obj())
|
141 |
-
```
|
142 |
-
|
143 |
-
By default, we do not truncate results based on a stop token. You can change this by either passing a new stop token to a Manifest session or to a `run`.
|
144 |
-
```python
|
145 |
-
result = manifest.run(prompt, "Laurel", stop_token="and")
|
146 |
-
```
|
147 |
-
|
148 |
-
If you want to change default parameters to a model, we pass those as `kwargs` to the client.
|
149 |
-
```python
|
150 |
-
result = manifest.run(prompt, "Laurel", max_tokens=50)
|
151 |
-
```
|
152 |
-
|
153 |
-
## Streaming Queries
|
154 |
-
Manifest also supports streaming the model response back, assuming it's supported by the underlying client. When calling `run`, pass `stream=True` to get a streaming iterator in response.
|
155 |
-
|
156 |
-
```python
|
157 |
-
result_iterator = manifest.run("Tell me a story. Once upon a time", max_tokens=100, stream=True)
|
158 |
-
for res_text in result_iterator:
|
159 |
-
print(res_text)
|
160 |
-
```
|
161 |
-
Streaming responses are only supported for single string queries (not batch mode) for text completion models.
|
162 |
-
|
163 |
-
## Model Pools
|
164 |
-
Manifest supports querying multiple models with different schedulers. This is very much a work in progress effort, but Manifest will round robin select (or randomly select) the clients you want. You can use the same client multiple times with different connection strings (e.g. different API keys), or you can mix and match. The only requirement is that all clients are the same request type. I.e. you can't have a pool of generation models and embedding models.
|
165 |
-
|
166 |
-
To query between a local model and OpenAI,
|
167 |
-
```python
|
168 |
-
from manifest.connections.client_pool import ClientConnection
|
169 |
-
from manifest import Manifest
|
170 |
-
|
171 |
-
client_connection1 = ClientConnection(
|
172 |
-
client_name="huggingface",
|
173 |
-
client_connection="http://127.0.0.1:5000",
|
174 |
-
)
|
175 |
-
client_connection2 = ClientConnection(client_name="openai", engine="text-ada-001")
|
176 |
-
manifest = Manifest(
|
177 |
-
client_pool=[client_connection1, client_connection2],
|
178 |
-
cache_name="sqlite",
|
179 |
-
client_connection=sqlite_cache,
|
180 |
-
)
|
181 |
-
manifest.run(...)
|
182 |
-
```
|
183 |
-
|
184 |
-
The speed benefit comes in with async batched runs. When calling `arun_batch` with a list of prompts, Manifest supports a `chunk_size` param. This will break the prompts into `chunk_size` chunks to spread across the client pool. By default `chunk_size` is `-1` which means only one client will get all the prompts to run asynchronously. You must set `chunk_size > 1` to distribute across the pool. There is a further `batch_size` param which control the individual client `batch_size` to send to the model.
|
185 |
-
|
186 |
-
```python
|
187 |
-
responses = asyncio.run(manifest.arun_batch(prompts, max_tokens=30, chunk_size=20))
|
188 |
-
```
|
189 |
-
|
190 |
-
# Other Models
|
191 |
-
|
192 |
-
## Local Huggingface Models
|
193 |
-
To use a HuggingFace generative model, in `manifest/api` we have a Flask application that hosts the models for you.
|
194 |
-
|
195 |
-
In a separate terminal or Tmux/Screen session, to load 6B parameters models, run
|
196 |
-
```bash
|
197 |
-
python3 -m manifest.api.app \
|
198 |
-
--model_type huggingface \
|
199 |
-
--model_name_or_path EleutherAI/gpt-j-6B \
|
200 |
-
--device 0
|
201 |
-
```
|
202 |
-
You will see the Flask session start and output a URL `http://127.0.0.1:5000`. Pass this in to Manifest. If you want to use a different port, set the `FLASK_PORT` environment variable.
|
203 |
-
|
204 |
-
```python
|
205 |
-
manifest = Manifest(
|
206 |
-
client_name = "huggingface",
|
207 |
-
client_connection = "http://127.0.0.1:5000",
|
208 |
-
)
|
209 |
-
```
|
210 |
-
|
211 |
-
If you have a custom model you trained, pass the model path to `--model_name_or_path`.
|
212 |
-
|
213 |
-
To help load larger models, we also support using `parallelize()` from HF, [accelerate](https://huggingface.co/docs/accelerate/index), [bitsandbytes](https://github.com/TimDettmers/bitsandbytes), and [deepspeed](https://github.com/microsoft/DeepSpeed). You will need to install these packages first via `pip install manifest-ml[api]`. We list the commands to load larger models below.
|
214 |
-
|
215 |
-
* T0pp
|
216 |
-
```bash
|
217 |
-
python3 -m manifest.api.app \
|
218 |
-
--model_type huggingface \
|
219 |
-
--model_name_or_path bigscience/T0pp \
|
220 |
-
--use_hf_parallelize
|
221 |
-
```
|
222 |
-
|
223 |
-
* NeoX 20B (requires at least 60GB of GPU memory)
|
224 |
-
```bash
|
225 |
-
python3 -m manifest.api.app \
|
226 |
-
--model_type huggingface \
|
227 |
-
--model_name_or_path EleutherAI/gpt-neox-20b \
|
228 |
-
--use_accelerate_multigpu \
|
229 |
-
--percent_max_gpu_mem_reduction 0.75
|
230 |
-
```
|
231 |
-
* Bloom 175B (requires at least 240GB of GPU memory)
|
232 |
-
```bash
|
233 |
-
python3 -m manifest.api.app \
|
234 |
-
--model_type huggingface \
|
235 |
-
--model_name_or_path bigscience/bloom \
|
236 |
-
--use_bitsandbytes \
|
237 |
-
--percent_max_gpu_mem_reduction 0.85
|
238 |
-
```
|
239 |
-
|
240 |
-
## Chat Models
|
241 |
-
Manifest has specific support for executing against chat models in the more standard "system" / "user" dialogue. To pass in a dialogue history to Manifest, use the `run` command with a list of dictionary inputs with `role` and `content` keys using an associated chat model such as `openaichat`.
|
242 |
-
|
243 |
-
```python
|
244 |
-
manifest = Manifest(client_name="openaichat")
|
245 |
-
dialogue = [
|
246 |
-
{"role": "system", "content": "You are a helpful assistant who also responds in rhymes"},
|
247 |
-
{"role": "user", "content": "What is the date?"},
|
248 |
-
]
|
249 |
-
res = manifest.run(dialogue, max_tokens=100)
|
250 |
-
```
|
251 |
-
|
252 |
-
## Embedding Models
|
253 |
-
Manifest also supports getting embeddings from models and available APIs. We do this all through changing the `client_name` argument. You still use `run` and `abatch_run`.
|
254 |
-
|
255 |
-
To use OpenAI's embedding models, simply run
|
256 |
-
```python
|
257 |
-
manifest = Manifest(client_name="openaiembedding")
|
258 |
-
embedding_as_np = manifest.run("Get me an embedding for a bunny")
|
259 |
-
```
|
260 |
-
|
261 |
-
As explained above, you can load local HuggingFace models that give you embeddings, too. If you want to use a standard generative model, load the model as above use use `client_name="huggingfaceembedding"`. If you want to use a standard embedding model, like those from SentenceTransformers, load your local model via
|
262 |
-
```bash
|
263 |
-
python3 -m manifest.api.app \
|
264 |
-
--model_type sentence_transformers \
|
265 |
-
--model_name_or_path all-mpnet-base-v2 \
|
266 |
-
--device 0
|
267 |
-
```
|
268 |
-
|
269 |
-
# Road Map
|
270 |
-
Here's what's coming up next
|
271 |
-
- [ ] Clients
|
272 |
-
- [ ] HuggingFace Hub
|
273 |
-
- [x] Azure OpenAI
|
274 |
-
- [x] Google Vertex
|
275 |
-
- [ ] Anthropic
|
276 |
-
- [x] Streaming Support Completions
|
277 |
-
- [ ] Streaming Support Chat Models
|
278 |
-
- [ ] Data Types
|
279 |
-
- [ ] Diffusion Models
|
280 |
-
- [x] Orchestration
|
281 |
-
- [x] Connection pools
|
282 |
-
- [ ] Local Inference
|
283 |
-
- [ ] FlexGen
|
284 |
-
|
285 |
-
# Development
|
286 |
-
Before submitting a PR, run
|
287 |
-
```bash
|
288 |
-
export REDIS_PORT="6379" # or whatever PORT local redis is running for those tests
|
289 |
-
cd <REDIS_PATH>
|
290 |
-
docker run -d -p 127.0.0.1:${REDIS_PORT}:6379 -v `pwd`:`pwd` -w `pwd` --name manifest_redis_test redis
|
291 |
-
make test
|
292 |
-
```
|
293 |
-
|
294 |
-
# Cite
|
295 |
-
Please cite Manifest if you used it for any publications. Thanks!!
|
296 |
-
```
|
297 |
-
@misc{orr2022manifest,
|
298 |
-
author = {Orr, Laurel},
|
299 |
-
title = {Manifest},
|
300 |
-
year = {2022},
|
301 |
-
publisher = {GitHub},
|
302 |
-
howpublished = {\url{https://github.com/HazyResearch/manifest}},
|
303 |
-
}
|
304 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/examples/langchain_chatgpt.ipynb
DELETED
@@ -1,455 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"attachments": {},
|
5 |
-
"cell_type": "markdown",
|
6 |
-
"id": "b253f4d5",
|
7 |
-
"metadata": {},
|
8 |
-
"source": [
|
9 |
-
"# ChatGPT Clone using TOMA GPT-JT-6B\n",
|
10 |
-
"(adopted from ChatGPT Clone [notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/chains/chatgpt_clone.ipynb))"
|
11 |
-
]
|
12 |
-
},
|
13 |
-
{
|
14 |
-
"cell_type": "code",
|
15 |
-
"execution_count": 1,
|
16 |
-
"id": "b0302886",
|
17 |
-
"metadata": {},
|
18 |
-
"outputs": [
|
19 |
-
{
|
20 |
-
"name": "stdout",
|
21 |
-
"output_type": "stream",
|
22 |
-
"text": [
|
23 |
-
"env: TOMA_URL=https://staging.together.xyz/api\n"
|
24 |
-
]
|
25 |
-
}
|
26 |
-
],
|
27 |
-
"source": [
|
28 |
-
"%env TOMA_URL=https://staging.together.xyz/api"
|
29 |
-
]
|
30 |
-
},
|
31 |
-
{
|
32 |
-
"attachments": {},
|
33 |
-
"cell_type": "markdown",
|
34 |
-
"id": "93a18ea6",
|
35 |
-
"metadata": {},
|
36 |
-
"source": [
|
37 |
-
"Make sure you have langchain installed and manifest. For the most recent versions, run\n",
|
38 |
-
"```\n",
|
39 |
-
"pip install git+https://github.com/hwchase17/langchain.git\n",
|
40 |
-
"pip install git+https://github.com/HazyResearch/manifest.git\n",
|
41 |
-
"```"
|
42 |
-
]
|
43 |
-
},
|
44 |
-
{
|
45 |
-
"cell_type": "code",
|
46 |
-
"execution_count": 35,
|
47 |
-
"id": "a99acd89",
|
48 |
-
"metadata": {},
|
49 |
-
"outputs": [
|
50 |
-
{
|
51 |
-
"name": "stdout",
|
52 |
-
"output_type": "stream",
|
53 |
-
"text": [
|
54 |
-
"\n",
|
55 |
-
"\n",
|
56 |
-
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
57 |
-
"Prompt after formatting:\n",
|
58 |
-
"\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
|
59 |
-
"\n",
|
60 |
-
"\n",
|
61 |
-
"Input: Classes are \"positive\" and \"negative\". For example given\n",
|
62 |
-
"Input: I love this product!\n",
|
63 |
-
"Output: positive.\n",
|
64 |
-
"I think this movie was one of the worst of the year. Script was boring!\n",
|
65 |
-
"Output:\u001b[0m\n",
|
66 |
-
"\n",
|
67 |
-
"\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
|
68 |
-
"negative.\n"
|
69 |
-
]
|
70 |
-
}
|
71 |
-
],
|
72 |
-
"source": [
|
73 |
-
"from manifest import Manifest\n",
|
74 |
-
"from langchain.llms.manifest import ManifestWrapper\n",
|
75 |
-
"from langchain import ConversationChain, LLMChain, PromptTemplate\n",
|
76 |
-
"from langchain.chains.conversation.memory import ConversationalBufferWindowMemory\n",
|
77 |
-
"\n",
|
78 |
-
"\n",
|
79 |
-
"template = \"\"\"I am a classification model. It will try to classify your input.\n",
|
80 |
-
"\n",
|
81 |
-
"{history}\n",
|
82 |
-
"Input: {human_input}\n",
|
83 |
-
"Output:\"\"\"\n",
|
84 |
-
"\n",
|
85 |
-
"prompt = PromptTemplate(\n",
|
86 |
-
" input_variables=[\"history\", \"human_input\"], \n",
|
87 |
-
" template=template\n",
|
88 |
-
")\n",
|
89 |
-
"\n",
|
90 |
-
"manifest = Manifest(\n",
|
91 |
-
" client_name=\"toma\",\n",
|
92 |
-
" engine=\"Together-gpt-JT-6B-v1\",\n",
|
93 |
-
" max_tokens=150,\n",
|
94 |
-
" top_p=0.9,\n",
|
95 |
-
" top_k=40,\n",
|
96 |
-
" stop_sequences=[\"\\n\"],\n",
|
97 |
-
")\n",
|
98 |
-
"\n",
|
99 |
-
"chatgpt_chain = LLMChain(\n",
|
100 |
-
" llm=ManifestWrapper(client=manifest), \n",
|
101 |
-
" prompt=prompt, \n",
|
102 |
-
" verbose=True, \n",
|
103 |
-
" memory=ConversationalBufferWindowMemory(k=8),\n",
|
104 |
-
")\n",
|
105 |
-
"\n",
|
106 |
-
"output = chatgpt_chain.predict(human_input=\"Classes are \\\"positive\\\" and \\\"negative\\\". For example given\\nInput: I love this product!\\nOutput: positive.\\nI think this movie was one of the worst of the year. Script was boring!\")\n",
|
107 |
-
"print(output)"
|
108 |
-
]
|
109 |
-
},
|
110 |
-
{
|
111 |
-
"cell_type": "code",
|
112 |
-
"execution_count": 36,
|
113 |
-
"id": "4ef711d6",
|
114 |
-
"metadata": {},
|
115 |
-
"outputs": [
|
116 |
-
{
|
117 |
-
"name": "stdout",
|
118 |
-
"output_type": "stream",
|
119 |
-
"text": [
|
120 |
-
"\n",
|
121 |
-
"\n",
|
122 |
-
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
123 |
-
"Prompt after formatting:\n",
|
124 |
-
"\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
|
125 |
-
"\n",
|
126 |
-
"Human: Classes are \"positive\" and \"negative\". For example given\n",
|
127 |
-
"Input: I love this product!\n",
|
128 |
-
"Output: positive.\n",
|
129 |
-
"I think this movie was one of the worst of the year. Script was boring!\n",
|
130 |
-
"AI: negative.\n",
|
131 |
-
"Input: So awesome! I wish I could have gone\n",
|
132 |
-
"Output:\u001b[0m\n",
|
133 |
-
"\n",
|
134 |
-
"\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
|
135 |
-
"positive.\n"
|
136 |
-
]
|
137 |
-
}
|
138 |
-
],
|
139 |
-
"source": [
|
140 |
-
"output = chatgpt_chain.predict(human_input=\"So awesome! I wish I could have gone\")\n",
|
141 |
-
"print(output)"
|
142 |
-
]
|
143 |
-
},
|
144 |
-
{
|
145 |
-
"cell_type": "code",
|
146 |
-
"execution_count": 37,
|
147 |
-
"id": "a5d6dac2",
|
148 |
-
"metadata": {},
|
149 |
-
"outputs": [
|
150 |
-
{
|
151 |
-
"name": "stdout",
|
152 |
-
"output_type": "stream",
|
153 |
-
"text": [
|
154 |
-
"\n",
|
155 |
-
"\n",
|
156 |
-
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
157 |
-
"Prompt after formatting:\n",
|
158 |
-
"\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
|
159 |
-
"\n",
|
160 |
-
"Human: Classes are \"positive\" and \"negative\". For example given\n",
|
161 |
-
"Input: I love this product!\n",
|
162 |
-
"Output: positive.\n",
|
163 |
-
"I think this movie was one of the worst of the year. Script was boring!\n",
|
164 |
-
"AI: negative.\n",
|
165 |
-
"Human: So awesome! I wish I could have gone\n",
|
166 |
-
"AI: positive.\n",
|
167 |
-
"Input: Hate it.\n",
|
168 |
-
"Output:\u001b[0m\n",
|
169 |
-
"\n",
|
170 |
-
"\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
|
171 |
-
"negative.\n"
|
172 |
-
]
|
173 |
-
}
|
174 |
-
],
|
175 |
-
"source": [
|
176 |
-
"output = chatgpt_chain.predict(human_input=\"Hate it.\")\n",
|
177 |
-
"print(output)"
|
178 |
-
]
|
179 |
-
},
|
180 |
-
{
|
181 |
-
"cell_type": "code",
|
182 |
-
"execution_count": 43,
|
183 |
-
"id": "b9283077",
|
184 |
-
"metadata": {},
|
185 |
-
"outputs": [
|
186 |
-
{
|
187 |
-
"name": "stdout",
|
188 |
-
"output_type": "stream",
|
189 |
-
"text": [
|
190 |
-
"\n",
|
191 |
-
"\n",
|
192 |
-
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
193 |
-
"Prompt after formatting:\n",
|
194 |
-
"\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
|
195 |
-
"\n",
|
196 |
-
"\n",
|
197 |
-
"Input: Classes are fruits \"apple\", \"banana\", \"orange\", \"pear\". For example given\n",
|
198 |
-
"Input: This fruit rippens off of the tree.\n",
|
199 |
-
"Output: banana.\n",
|
200 |
-
"Often comes in bosc and bartlett varieties.\n",
|
201 |
-
"Output:\u001b[0m\n",
|
202 |
-
"\n",
|
203 |
-
"\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
|
204 |
-
"apple.\n"
|
205 |
-
]
|
206 |
-
}
|
207 |
-
],
|
208 |
-
"source": [
|
209 |
-
"chatgpt_chain.memory.clear()\n",
|
210 |
-
"output = chatgpt_chain.predict(human_input=\"Classes are fruits \\\"apple\\\", \\\"banana\\\", \\\"orange\\\", \\\"pear\\\". For example given\\nInput: This fruit rippens off of the tree.\\nOutput: banana.\\nOften comes in bosc and bartlett varieties.\")\n",
|
211 |
-
"print(output)"
|
212 |
-
]
|
213 |
-
},
|
214 |
-
{
|
215 |
-
"cell_type": "code",
|
216 |
-
"execution_count": 44,
|
217 |
-
"id": "cd0a23d9",
|
218 |
-
"metadata": {
|
219 |
-
"scrolled": true
|
220 |
-
},
|
221 |
-
"outputs": [
|
222 |
-
{
|
223 |
-
"name": "stdout",
|
224 |
-
"output_type": "stream",
|
225 |
-
"text": [
|
226 |
-
"\n",
|
227 |
-
"\n",
|
228 |
-
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
229 |
-
"Prompt after formatting:\n",
|
230 |
-
"\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
|
231 |
-
"\n",
|
232 |
-
"Human: Classes are fruits \"apple\", \"banana\", \"orange\", \"pear\". For example given\n",
|
233 |
-
"Input: This fruit rippens off of the tree.\n",
|
234 |
-
"Output: banana.\n",
|
235 |
-
"Often comes in bosc and bartlett varieties.\n",
|
236 |
-
"AI: apple.\n",
|
237 |
-
"Input: Often associated with monkeys\n",
|
238 |
-
"Output:\u001b[0m\n",
|
239 |
-
"\n",
|
240 |
-
"\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
|
241 |
-
"banana.\n"
|
242 |
-
]
|
243 |
-
}
|
244 |
-
],
|
245 |
-
"source": [
|
246 |
-
"output = chatgpt_chain.predict(human_input=\"Often associated with monkeys\")\n",
|
247 |
-
"print(output)"
|
248 |
-
]
|
249 |
-
},
|
250 |
-
{
|
251 |
-
"cell_type": "code",
|
252 |
-
"execution_count": 45,
|
253 |
-
"id": "90db6eb2",
|
254 |
-
"metadata": {},
|
255 |
-
"outputs": [
|
256 |
-
{
|
257 |
-
"name": "stdout",
|
258 |
-
"output_type": "stream",
|
259 |
-
"text": [
|
260 |
-
"\n",
|
261 |
-
"\n",
|
262 |
-
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
263 |
-
"Prompt after formatting:\n",
|
264 |
-
"\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
|
265 |
-
"\n",
|
266 |
-
"Human: Classes are fruits \"apple\", \"banana\", \"orange\", \"pear\". For example given\n",
|
267 |
-
"Input: This fruit rippens off of the tree.\n",
|
268 |
-
"Output: banana.\n",
|
269 |
-
"Often comes in bosc and bartlett varieties.\n",
|
270 |
-
"AI: apple.\n",
|
271 |
-
"Human: Often associated with monkeys\n",
|
272 |
-
"AI: banana.\n",
|
273 |
-
"Input: Is the color red and often delicious.\n",
|
274 |
-
"Output:\u001b[0m\n",
|
275 |
-
"\n",
|
276 |
-
"\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
|
277 |
-
"apple.\n"
|
278 |
-
]
|
279 |
-
}
|
280 |
-
],
|
281 |
-
"source": [
|
282 |
-
"output = chatgpt_chain.predict(human_input=\"Is the color red and often delicious.\")\n",
|
283 |
-
"print(output)"
|
284 |
-
]
|
285 |
-
},
|
286 |
-
{
|
287 |
-
"cell_type": "code",
|
288 |
-
"execution_count": 48,
|
289 |
-
"id": "c3806f89",
|
290 |
-
"metadata": {},
|
291 |
-
"outputs": [
|
292 |
-
{
|
293 |
-
"name": "stdout",
|
294 |
-
"output_type": "stream",
|
295 |
-
"text": [
|
296 |
-
"\n",
|
297 |
-
"\n",
|
298 |
-
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
299 |
-
"Prompt after formatting:\n",
|
300 |
-
"\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
|
301 |
-
"\n",
|
302 |
-
"\n",
|
303 |
-
"Input: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
|
304 |
-
"Input: The color of a school bus.\n",
|
305 |
-
"Output: yellow.\n",
|
306 |
-
"Is the color of the sky\n",
|
307 |
-
"Output:\u001b[0m\n",
|
308 |
-
"\n",
|
309 |
-
"\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
|
310 |
-
"blue.\n"
|
311 |
-
]
|
312 |
-
}
|
313 |
-
],
|
314 |
-
"source": [
|
315 |
-
"chatgpt_chain.memory.clear()\n",
|
316 |
-
"output = chatgpt_chain.predict(human_input=\"Classes are colors \\\"red\\\", \\\"green\\\", \\\"blue\\\", \\\"yellow\\\". For example given\\nInput: The color of a school bus.\\nOutput: yellow.\\nIs the color of the sky\")\n",
|
317 |
-
"print(output)"
|
318 |
-
]
|
319 |
-
},
|
320 |
-
{
|
321 |
-
"cell_type": "code",
|
322 |
-
"execution_count": 49,
|
323 |
-
"id": "f508f597",
|
324 |
-
"metadata": {},
|
325 |
-
"outputs": [
|
326 |
-
{
|
327 |
-
"name": "stdout",
|
328 |
-
"output_type": "stream",
|
329 |
-
"text": [
|
330 |
-
"\n",
|
331 |
-
"\n",
|
332 |
-
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
333 |
-
"Prompt after formatting:\n",
|
334 |
-
"\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
|
335 |
-
"\n",
|
336 |
-
"Human: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
|
337 |
-
"Input: The color of a school bus.\n",
|
338 |
-
"Output: yellow.\n",
|
339 |
-
"Is the color of the sky\n",
|
340 |
-
"AI: blue.\n",
|
341 |
-
"Input: Color of a banana.\n",
|
342 |
-
"Output:\u001b[0m\n",
|
343 |
-
"\n",
|
344 |
-
"\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
|
345 |
-
"yellow.\n"
|
346 |
-
]
|
347 |
-
}
|
348 |
-
],
|
349 |
-
"source": [
|
350 |
-
"output = chatgpt_chain.predict(human_input=\"Color of a banana.\")\n",
|
351 |
-
"print(output)"
|
352 |
-
]
|
353 |
-
},
|
354 |
-
{
|
355 |
-
"cell_type": "code",
|
356 |
-
"execution_count": 50,
|
357 |
-
"id": "cbd607f4",
|
358 |
-
"metadata": {},
|
359 |
-
"outputs": [
|
360 |
-
{
|
361 |
-
"name": "stdout",
|
362 |
-
"output_type": "stream",
|
363 |
-
"text": [
|
364 |
-
"\n",
|
365 |
-
"\n",
|
366 |
-
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
367 |
-
"Prompt after formatting:\n",
|
368 |
-
"\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
|
369 |
-
"\n",
|
370 |
-
"Human: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
|
371 |
-
"Input: The color of a school bus.\n",
|
372 |
-
"Output: yellow.\n",
|
373 |
-
"Is the color of the sky\n",
|
374 |
-
"AI: blue.\n",
|
375 |
-
"Human: Color of a banana.\n",
|
376 |
-
"AI: yellow.\n",
|
377 |
-
"Input: When someone is sick they are this color.\n",
|
378 |
-
"Output:\u001b[0m\n",
|
379 |
-
"\n",
|
380 |
-
"\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
|
381 |
-
"green.\n"
|
382 |
-
]
|
383 |
-
}
|
384 |
-
],
|
385 |
-
"source": [
|
386 |
-
"output = chatgpt_chain.predict(human_input=\"When someone is sick they are this color.\")\n",
|
387 |
-
"print(output)"
|
388 |
-
]
|
389 |
-
},
|
390 |
-
{
|
391 |
-
"cell_type": "code",
|
392 |
-
"execution_count": 51,
|
393 |
-
"id": "d33e0e28",
|
394 |
-
"metadata": {},
|
395 |
-
"outputs": [
|
396 |
-
{
|
397 |
-
"name": "stdout",
|
398 |
-
"output_type": "stream",
|
399 |
-
"text": [
|
400 |
-
"\n",
|
401 |
-
"\n",
|
402 |
-
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
|
403 |
-
"Prompt after formatting:\n",
|
404 |
-
"\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
|
405 |
-
"\n",
|
406 |
-
"Human: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
|
407 |
-
"Input: The color of a school bus.\n",
|
408 |
-
"Output: yellow.\n",
|
409 |
-
"Is the color of the sky\n",
|
410 |
-
"AI: blue.\n",
|
411 |
-
"Human: Color of a banana.\n",
|
412 |
-
"AI: yellow.\n",
|
413 |
-
"Human: When someone is sick they are this color.\n",
|
414 |
-
"AI: green.\n",
|
415 |
-
"Input: Color of anger.\n",
|
416 |
-
"Output:\u001b[0m\n",
|
417 |
-
"\n",
|
418 |
-
"\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
|
419 |
-
"red.\n"
|
420 |
-
]
|
421 |
-
}
|
422 |
-
],
|
423 |
-
"source": [
|
424 |
-
"output = chatgpt_chain.predict(human_input=\"Color of anger.\")\n",
|
425 |
-
"print(output)"
|
426 |
-
]
|
427 |
-
}
|
428 |
-
],
|
429 |
-
"metadata": {
|
430 |
-
"kernelspec": {
|
431 |
-
"display_name": "bootleg",
|
432 |
-
"language": "python",
|
433 |
-
"name": "python3"
|
434 |
-
},
|
435 |
-
"language_info": {
|
436 |
-
"codemirror_mode": {
|
437 |
-
"name": "ipython",
|
438 |
-
"version": 3
|
439 |
-
},
|
440 |
-
"file_extension": ".py",
|
441 |
-
"mimetype": "text/x-python",
|
442 |
-
"name": "python",
|
443 |
-
"nbconvert_exporter": "python",
|
444 |
-
"pygments_lexer": "ipython3",
|
445 |
-
"version": "3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:36:06) \n[Clang 11.1.0 ]"
|
446 |
-
},
|
447 |
-
"vscode": {
|
448 |
-
"interpreter": {
|
449 |
-
"hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
|
450 |
-
}
|
451 |
-
}
|
452 |
-
},
|
453 |
-
"nbformat": 4,
|
454 |
-
"nbformat_minor": 5
|
455 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/examples/manifest_async.py
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
import asyncio
|
2 |
-
import time
|
3 |
-
|
4 |
-
from manifest import Manifest
|
5 |
-
|
6 |
-
|
7 |
-
def main():
|
8 |
-
|
9 |
-
manifest = Manifest(
|
10 |
-
client_name="openaichat",
|
11 |
-
)
|
12 |
-
|
13 |
-
print("Running in serial")
|
14 |
-
prompts = [f"Tell me something interesting about {i}" for i in range(50)]
|
15 |
-
st = time.time()
|
16 |
-
for pmt in prompts:
|
17 |
-
_ = manifest.run(pmt)
|
18 |
-
print(f"For loop: {time.time() - st :.2f}")
|
19 |
-
|
20 |
-
print("Running with async")
|
21 |
-
st = time.time()
|
22 |
-
_ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30))
|
23 |
-
print(f"Async loop: {time.time() - st :.2f}")
|
24 |
-
|
25 |
-
|
26 |
-
if __name__ == "__main__":
|
27 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/examples/manifest_azure.ipynb
DELETED
@@ -1,149 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"%load_ext autoreload\n",
|
10 |
-
"%autoreload 2"
|
11 |
-
]
|
12 |
-
},
|
13 |
-
{
|
14 |
-
"cell_type": "code",
|
15 |
-
"execution_count": null,
|
16 |
-
"metadata": {},
|
17 |
-
"outputs": [],
|
18 |
-
"source": [
|
19 |
-
"AZURE_KEY = \"API_KEY::URL\"\n",
|
20 |
-
"OPENAI_KEY = \"sk-XXX\""
|
21 |
-
]
|
22 |
-
},
|
23 |
-
{
|
24 |
-
"attachments": {},
|
25 |
-
"cell_type": "markdown",
|
26 |
-
"metadata": {},
|
27 |
-
"source": [
|
28 |
-
"## Use Azure and OpenAI models"
|
29 |
-
]
|
30 |
-
},
|
31 |
-
{
|
32 |
-
"cell_type": "code",
|
33 |
-
"execution_count": null,
|
34 |
-
"metadata": {},
|
35 |
-
"outputs": [],
|
36 |
-
"source": [
|
37 |
-
"from manifest import Manifest\n",
|
38 |
-
"from manifest.connections.client_pool import ClientConnection\n",
|
39 |
-
"from pathlib import Path\n",
|
40 |
-
"\n",
|
41 |
-
"cache_path = Path(\"manifest.db\")\n",
|
42 |
-
"if cache_path.exists():\n",
|
43 |
-
" cache_path.unlink()\n",
|
44 |
-
"\n",
|
45 |
-
"\n",
|
46 |
-
"azure = ClientConnection(\n",
|
47 |
-
" client_name=\"azureopenai\",\n",
|
48 |
-
" client_connection=AZURE_KEY,\n",
|
49 |
-
" engine=\"text-davinci-003\",\n",
|
50 |
-
")\n",
|
51 |
-
"\n",
|
52 |
-
"manifest = Manifest(client_pool=[azure], \n",
|
53 |
-
" cache_name=\"sqlite\",\n",
|
54 |
-
" cache_connection=\"manifest.db\"\n",
|
55 |
-
")\n",
|
56 |
-
"\n",
|
57 |
-
"\n",
|
58 |
-
"openai = ClientConnection(\n",
|
59 |
-
" client_name=\"openai\",\n",
|
60 |
-
" client_connection=OPENAI_KEY,\n",
|
61 |
-
" engine=\"text-davinci-003\",\n",
|
62 |
-
")\n",
|
63 |
-
"\n",
|
64 |
-
"manifest_openai_nocache = Manifest(client_pool=[openai])\n",
|
65 |
-
"\n",
|
66 |
-
"manifest_openai = Manifest(client_pool=[openai], \n",
|
67 |
-
" cache_name=\"sqlite\",\n",
|
68 |
-
" cache_connection=\"manifest.db\"\n",
|
69 |
-
")"
|
70 |
-
]
|
71 |
-
},
|
72 |
-
{
|
73 |
-
"cell_type": "code",
|
74 |
-
"execution_count": null,
|
75 |
-
"metadata": {},
|
76 |
-
"outputs": [],
|
77 |
-
"source": [
|
78 |
-
"# Show caches are the same\n",
|
79 |
-
"text = \"What is the meaning of life?\"\n",
|
80 |
-
"res = manifest.run(text, max_tokens=100, temperature=0.7, return_response=True)\n",
|
81 |
-
"print(res.get_response())\n",
|
82 |
-
"print(res.is_cached())\n",
|
83 |
-
"res2 = manifest_openai.run(text, max_tokens=100, temperature=0.7, return_response=True)\n",
|
84 |
-
"print(res2.is_cached())\n",
|
85 |
-
"\n",
|
86 |
-
"assert res2.get_response() == res.get_response()"
|
87 |
-
]
|
88 |
-
},
|
89 |
-
{
|
90 |
-
"cell_type": "code",
|
91 |
-
"execution_count": null,
|
92 |
-
"metadata": {},
|
93 |
-
"outputs": [],
|
94 |
-
"source": [
|
95 |
-
"azure_chat = ClientConnection(\n",
|
96 |
-
" client_name=\"azureopenaichat\",\n",
|
97 |
-
" client_connection=AZURE_KEY,\n",
|
98 |
-
" engine=\"gpt-3.5-turbo\",\n",
|
99 |
-
")\n",
|
100 |
-
"\n",
|
101 |
-
"manifest = Manifest(client_pool=[azure_chat])"
|
102 |
-
]
|
103 |
-
},
|
104 |
-
{
|
105 |
-
"cell_type": "code",
|
106 |
-
"execution_count": null,
|
107 |
-
"metadata": {},
|
108 |
-
"outputs": [],
|
109 |
-
"source": [
|
110 |
-
"print(manifest.run(\"What do you think is the best food?\", max_tokens=100))\n",
|
111 |
-
"\n",
|
112 |
-
"chat_dict = [\n",
|
113 |
-
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
|
114 |
-
" {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
|
115 |
-
" {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
|
116 |
-
" {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
|
117 |
-
"]\n",
|
118 |
-
"print(manifest.run(chat_dict, max_tokens=100))"
|
119 |
-
]
|
120 |
-
}
|
121 |
-
],
|
122 |
-
"metadata": {
|
123 |
-
"kernelspec": {
|
124 |
-
"display_name": "manifest",
|
125 |
-
"language": "python",
|
126 |
-
"name": "python3"
|
127 |
-
},
|
128 |
-
"language_info": {
|
129 |
-
"codemirror_mode": {
|
130 |
-
"name": "ipython",
|
131 |
-
"version": 3
|
132 |
-
},
|
133 |
-
"file_extension": ".py",
|
134 |
-
"mimetype": "text/x-python",
|
135 |
-
"name": "python",
|
136 |
-
"nbconvert_exporter": "python",
|
137 |
-
"pygments_lexer": "ipython3",
|
138 |
-
"version": "3.10.4"
|
139 |
-
},
|
140 |
-
"orig_nbformat": 4,
|
141 |
-
"vscode": {
|
142 |
-
"interpreter": {
|
143 |
-
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
|
144 |
-
}
|
145 |
-
}
|
146 |
-
},
|
147 |
-
"nbformat": 4,
|
148 |
-
"nbformat_minor": 2
|
149 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/examples/manifest_chatgpt.ipynb
DELETED
@@ -1,101 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"%load_ext autoreload\n",
|
10 |
-
"%autoreload 2"
|
11 |
-
]
|
12 |
-
},
|
13 |
-
{
|
14 |
-
"cell_type": "code",
|
15 |
-
"execution_count": null,
|
16 |
-
"metadata": {},
|
17 |
-
"outputs": [],
|
18 |
-
"source": [
|
19 |
-
"OPENAI_KEY = \"sk-XXX\""
|
20 |
-
]
|
21 |
-
},
|
22 |
-
{
|
23 |
-
"attachments": {},
|
24 |
-
"cell_type": "markdown",
|
25 |
-
"metadata": {},
|
26 |
-
"source": [
|
27 |
-
"## Use ChatOpenAI\n",
|
28 |
-
"\n",
|
29 |
-
"Set you `OPENAI_API_KEY` environment variable."
|
30 |
-
]
|
31 |
-
},
|
32 |
-
{
|
33 |
-
"cell_type": "code",
|
34 |
-
"execution_count": null,
|
35 |
-
"metadata": {},
|
36 |
-
"outputs": [],
|
37 |
-
"source": [
|
38 |
-
"from manifest import Manifest\n",
|
39 |
-
"from manifest.connections.client_pool import ClientConnection\n",
|
40 |
-
"\n",
|
41 |
-
"openai_chat = ClientConnection(\n",
|
42 |
-
" client_name=\"openaichat\",\n",
|
43 |
-
" client_connection=OPENAI_KEY,\n",
|
44 |
-
" engine=\"gpt-3.5-turbo\"\n",
|
45 |
-
")\n",
|
46 |
-
"\n",
|
47 |
-
"manifest = Manifest(client_pool=[openai_chat])"
|
48 |
-
]
|
49 |
-
},
|
50 |
-
{
|
51 |
-
"cell_type": "code",
|
52 |
-
"execution_count": null,
|
53 |
-
"metadata": {},
|
54 |
-
"outputs": [],
|
55 |
-
"source": [
|
56 |
-
"# Simple question\n",
|
57 |
-
"chat_dict = [\n",
|
58 |
-
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
|
59 |
-
" {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
|
60 |
-
" {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
|
61 |
-
" {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
|
62 |
-
"]\n",
|
63 |
-
"print(manifest.run(chat_dict, max_tokens=100))"
|
64 |
-
]
|
65 |
-
},
|
66 |
-
{
|
67 |
-
"cell_type": "code",
|
68 |
-
"execution_count": null,
|
69 |
-
"metadata": {},
|
70 |
-
"outputs": [],
|
71 |
-
"source": []
|
72 |
-
}
|
73 |
-
],
|
74 |
-
"metadata": {
|
75 |
-
"kernelspec": {
|
76 |
-
"display_name": "manifest",
|
77 |
-
"language": "python",
|
78 |
-
"name": "python3"
|
79 |
-
},
|
80 |
-
"language_info": {
|
81 |
-
"codemirror_mode": {
|
82 |
-
"name": "ipython",
|
83 |
-
"version": 3
|
84 |
-
},
|
85 |
-
"file_extension": ".py",
|
86 |
-
"mimetype": "text/x-python",
|
87 |
-
"name": "python",
|
88 |
-
"nbconvert_exporter": "python",
|
89 |
-
"pygments_lexer": "ipython3",
|
90 |
-
"version": "3.10.4"
|
91 |
-
},
|
92 |
-
"orig_nbformat": 4,
|
93 |
-
"vscode": {
|
94 |
-
"interpreter": {
|
95 |
-
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
|
96 |
-
}
|
97 |
-
}
|
98 |
-
},
|
99 |
-
"nbformat": 4,
|
100 |
-
"nbformat_minor": 2
|
101 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/examples/manifest_connection_pool.ipynb
DELETED
@@ -1,208 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": 2,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"%load_ext autoreload\n",
|
10 |
-
"%autoreload 2"
|
11 |
-
]
|
12 |
-
},
|
13 |
-
{
|
14 |
-
"cell_type": "code",
|
15 |
-
"execution_count": 1,
|
16 |
-
"metadata": {},
|
17 |
-
"outputs": [],
|
18 |
-
"source": [
|
19 |
-
"OPENAI_KEY1 = \"sk-XXX\"\n",
|
20 |
-
"OPENAI_KEY2 = \"sk-XX\""
|
21 |
-
]
|
22 |
-
},
|
23 |
-
{
|
24 |
-
"attachments": {},
|
25 |
-
"cell_type": "markdown",
|
26 |
-
"metadata": {},
|
27 |
-
"source": [
|
28 |
-
"## Use OpenAI\n",
|
29 |
-
"\n",
|
30 |
-
"Set you `OPENAI_API_KEY` environment variable."
|
31 |
-
]
|
32 |
-
},
|
33 |
-
{
|
34 |
-
"cell_type": "code",
|
35 |
-
"execution_count": 2,
|
36 |
-
"metadata": {},
|
37 |
-
"outputs": [],
|
38 |
-
"source": [
|
39 |
-
"from manifest import Manifest\n",
|
40 |
-
"from manifest.connections.client_pool import ClientConnection\n",
|
41 |
-
"\n",
|
42 |
-
"openai_ada = ClientConnection(\n",
|
43 |
-
" client_name=\"openai\",\n",
|
44 |
-
" client_connection=OPENAI_KEY1,\n",
|
45 |
-
" engine=\"text-ada-001\"\n",
|
46 |
-
")\n",
|
47 |
-
"\n",
|
48 |
-
"openai_curie = ClientConnection(\n",
|
49 |
-
" client_name=\"openai\",\n",
|
50 |
-
" client_connection=OPENAI_KEY2,\n",
|
51 |
-
" engine=\"text-curie-001\"\n",
|
52 |
-
")\n",
|
53 |
-
"\n",
|
54 |
-
"manifest = Manifest(client_pool=[openai_ada, openai_curie], client_pool_schedule=\"round_robin\")"
|
55 |
-
]
|
56 |
-
},
|
57 |
-
{
|
58 |
-
"cell_type": "code",
|
59 |
-
"execution_count": 3,
|
60 |
-
"metadata": {},
|
61 |
-
"outputs": [
|
62 |
-
{
|
63 |
-
"name": "stdout",
|
64 |
-
"output_type": "stream",
|
65 |
-
"text": [
|
66 |
-
"0\n",
|
67 |
-
"I am a model.\n",
|
68 |
-
"1\n",
|
69 |
-
"I am a MacBook Pro with a retina\n"
|
70 |
-
]
|
71 |
-
}
|
72 |
-
],
|
73 |
-
"source": [
|
74 |
-
"res = manifest.run(\"What model are you?\", temperature=0.0)\n",
|
75 |
-
"print(manifest.client_pool.current_client_id)\n",
|
76 |
-
"print(res)\n",
|
77 |
-
"res = manifest.run(\"What model are you?\", temperature=0.0)\n",
|
78 |
-
"print(manifest.client_pool.current_client_id)\n",
|
79 |
-
"print(res)"
|
80 |
-
]
|
81 |
-
},
|
82 |
-
{
|
83 |
-
"attachments": {},
|
84 |
-
"cell_type": "markdown",
|
85 |
-
"metadata": {},
|
86 |
-
"source": [
|
87 |
-
"## With Async"
|
88 |
-
]
|
89 |
-
},
|
90 |
-
{
|
91 |
-
"cell_type": "code",
|
92 |
-
"execution_count": 4,
|
93 |
-
"metadata": {},
|
94 |
-
"outputs": [],
|
95 |
-
"source": [
|
96 |
-
"import nest_asyncio\n",
|
97 |
-
"# This is required for asyncio.run(...) to work in Jupyter notebooks.\n",
|
98 |
-
"nest_asyncio.apply()"
|
99 |
-
]
|
100 |
-
},
|
101 |
-
{
|
102 |
-
"cell_type": "code",
|
103 |
-
"execution_count": 5,
|
104 |
-
"metadata": {},
|
105 |
-
"outputs": [],
|
106 |
-
"source": [
|
107 |
-
"from manifest import Manifest\n",
|
108 |
-
"from manifest.connections.client_pool import ClientConnection\n",
|
109 |
-
"\n",
|
110 |
-
"openai_ada = ClientConnection(\n",
|
111 |
-
" client_name=\"openai\",\n",
|
112 |
-
" client_connection=OPENAI_KEY1,\n",
|
113 |
-
" engine=\"text-ada-001\"\n",
|
114 |
-
")\n",
|
115 |
-
"\n",
|
116 |
-
"openai_babbage = ClientConnection(\n",
|
117 |
-
" client_name=\"openai\",\n",
|
118 |
-
" client_connection=OPENAI_KEY2,\n",
|
119 |
-
" engine=\"text-babbage-001\"\n",
|
120 |
-
")\n",
|
121 |
-
"\n",
|
122 |
-
"openai_curie = ClientConnection(\n",
|
123 |
-
" client_name=\"openai\",\n",
|
124 |
-
" client_connection=OPENAI_KEY2,\n",
|
125 |
-
" engine=\"text-curie-001\"\n",
|
126 |
-
")\n",
|
127 |
-
"\n",
|
128 |
-
"manifest = Manifest(client_pool=[openai_ada, openai_babbage, openai_curie], client_pool_schedule=\"round_robin\")\n",
|
129 |
-
"manifest_single_client = Manifest(client_pool=[openai_babbage])"
|
130 |
-
]
|
131 |
-
},
|
132 |
-
{
|
133 |
-
"cell_type": "code",
|
134 |
-
"execution_count": 6,
|
135 |
-
"metadata": {},
|
136 |
-
"outputs": [
|
137 |
-
{
|
138 |
-
"name": "stdout",
|
139 |
-
"output_type": "stream",
|
140 |
-
"text": [
|
141 |
-
"For loop: 128.68\n",
|
142 |
-
"Running with async single client\n",
|
143 |
-
"Running 1 tasks across all clients.\n",
|
144 |
-
"Async loop: 4.02\n",
|
145 |
-
"Running with async two clients but not chunking\n",
|
146 |
-
"Running 1 tasks across all clients.\n",
|
147 |
-
"Async loop: 3.92\n",
|
148 |
-
"Running with async two clients and chunk size\n",
|
149 |
-
"Running 20 tasks across all clients.\n",
|
150 |
-
"Async loop: 1.44\n"
|
151 |
-
]
|
152 |
-
}
|
153 |
-
],
|
154 |
-
"source": [
|
155 |
-
"import time\n",
|
156 |
-
"import asyncio\n",
|
157 |
-
"\n",
|
158 |
-
"prompts = [f\"Tell me something interesting about {i}\" for i in range(400)]\n",
|
159 |
-
"st = time.time()\n",
|
160 |
-
"for pmt in prompts:\n",
|
161 |
-
" _ = manifest_single_client.run(pmt, max_tokens=30)\n",
|
162 |
-
"print(f\"For loop: {time.time() - st :.2f}\")\n",
|
163 |
-
"\n",
|
164 |
-
"print(\"Running with async single client\")\n",
|
165 |
-
"st = time.time()\n",
|
166 |
-
"_ = asyncio.run(manifest_single_client.arun_batch(prompts, max_tokens=30, chunk_size=-1))\n",
|
167 |
-
"print(f\"Async loop: {time.time() - st :.2f}\")\n",
|
168 |
-
"\n",
|
169 |
-
"print(\"Running with async two clients but not chunking\")\n",
|
170 |
-
"st = time.time()\n",
|
171 |
-
"_ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30, chunk_size=-1))\n",
|
172 |
-
"print(f\"Async loop: {time.time() - st :.2f}\")\n",
|
173 |
-
"\n",
|
174 |
-
"print(\"Running with async two clients and chunk size\")\n",
|
175 |
-
"st = time.time()\n",
|
176 |
-
"_ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30, chunk_size=20))\n",
|
177 |
-
"print(f\"Async loop: {time.time() - st :.2f}\")"
|
178 |
-
]
|
179 |
-
}
|
180 |
-
],
|
181 |
-
"metadata": {
|
182 |
-
"kernelspec": {
|
183 |
-
"display_name": "manifest",
|
184 |
-
"language": "python",
|
185 |
-
"name": "python3"
|
186 |
-
},
|
187 |
-
"language_info": {
|
188 |
-
"codemirror_mode": {
|
189 |
-
"name": "ipython",
|
190 |
-
"version": 3
|
191 |
-
},
|
192 |
-
"file_extension": ".py",
|
193 |
-
"mimetype": "text/x-python",
|
194 |
-
"name": "python",
|
195 |
-
"nbconvert_exporter": "python",
|
196 |
-
"pygments_lexer": "ipython3",
|
197 |
-
"version": "3.10.4"
|
198 |
-
},
|
199 |
-
"orig_nbformat": 4,
|
200 |
-
"vscode": {
|
201 |
-
"interpreter": {
|
202 |
-
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
|
203 |
-
}
|
204 |
-
}
|
205 |
-
},
|
206 |
-
"nbformat": 4,
|
207 |
-
"nbformat_minor": 2
|
208 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/examples/manifest_diffusers.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
|
|
duckdb-nsql/manifest/examples/manifest_embedding.ipynb
DELETED
@@ -1,156 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": 1,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"%load_ext autoreload\n",
|
10 |
-
"%autoreload 2"
|
11 |
-
]
|
12 |
-
},
|
13 |
-
{
|
14 |
-
"attachments": {},
|
15 |
-
"cell_type": "markdown",
|
16 |
-
"metadata": {},
|
17 |
-
"source": [
|
18 |
-
"## Use OpenAI\n",
|
19 |
-
"\n",
|
20 |
-
"Set you `OPENAI_API_KEY` environment variable."
|
21 |
-
]
|
22 |
-
},
|
23 |
-
{
|
24 |
-
"cell_type": "code",
|
25 |
-
"execution_count": 2,
|
26 |
-
"metadata": {},
|
27 |
-
"outputs": [
|
28 |
-
{
|
29 |
-
"name": "stdout",
|
30 |
-
"output_type": "stream",
|
31 |
-
"text": [
|
32 |
-
"{'model_name': 'openaiembedding', 'engine': 'text-embedding-ada-002'}\n"
|
33 |
-
]
|
34 |
-
}
|
35 |
-
],
|
36 |
-
"source": [
|
37 |
-
"from manifest import Manifest\n",
|
38 |
-
"\n",
|
39 |
-
"manifest = Manifest(client_name=\"openaiembedding\")\n",
|
40 |
-
"print(manifest.client_pool.get_next_client().get_model_params())"
|
41 |
-
]
|
42 |
-
},
|
43 |
-
{
|
44 |
-
"cell_type": "code",
|
45 |
-
"execution_count": 3,
|
46 |
-
"metadata": {},
|
47 |
-
"outputs": [
|
48 |
-
{
|
49 |
-
"name": "stdout",
|
50 |
-
"output_type": "stream",
|
51 |
-
"text": [
|
52 |
-
"(1536,)\n"
|
53 |
-
]
|
54 |
-
}
|
55 |
-
],
|
56 |
-
"source": [
|
57 |
-
"emb = manifest.run(\"Is this an embedding?\")\n",
|
58 |
-
"print(emb.shape)"
|
59 |
-
]
|
60 |
-
},
|
61 |
-
{
|
62 |
-
"attachments": {},
|
63 |
-
"cell_type": "markdown",
|
64 |
-
"metadata": {},
|
65 |
-
"source": [
|
66 |
-
"### Using Locally Hosted Huggingface LM\n",
|
67 |
-
"\n",
|
68 |
-
"Run\n",
|
69 |
-
"```\n",
|
70 |
-
"python3 manifest/api/app.py --model_type huggingface --model_name_or_path EleutherAI/gpt-neo-125M --device 0\n",
|
71 |
-
"```\n",
|
72 |
-
"or\n",
|
73 |
-
"```\n",
|
74 |
-
"python3 manifest/api/app.py --model_type sentence_transformers --model_name_or_path all-mpnet-base-v2 --device 0\n",
|
75 |
-
"```\n",
|
76 |
-
"\n",
|
77 |
-
"in a separate `screen` or `tmux`. Make sure to note the port. You can change this with `export FLASK_PORT=<port>`."
|
78 |
-
]
|
79 |
-
},
|
80 |
-
{
|
81 |
-
"cell_type": "code",
|
82 |
-
"execution_count": 1,
|
83 |
-
"metadata": {},
|
84 |
-
"outputs": [
|
85 |
-
{
|
86 |
-
"name": "stdout",
|
87 |
-
"output_type": "stream",
|
88 |
-
"text": [
|
89 |
-
"{'model_name': 'all-mpnet-base-v2', 'model_path': 'all-mpnet-base-v2', 'client_name': 'huggingfaceembedding'}\n"
|
90 |
-
]
|
91 |
-
}
|
92 |
-
],
|
93 |
-
"source": [
|
94 |
-
"from manifest import Manifest\n",
|
95 |
-
"\n",
|
96 |
-
"# Local hosted GPT Neo 125M\n",
|
97 |
-
"manifest = Manifest(\n",
|
98 |
-
" client_name=\"huggingfaceembedding\",\n",
|
99 |
-
" client_connection=\"http://127.0.0.1:6000\",\n",
|
100 |
-
" cache_name=\"sqlite\",\n",
|
101 |
-
" cache_connection=\"my_sqlite_manifest.sqlite\"\n",
|
102 |
-
")\n",
|
103 |
-
"print(manifest.client_pool.get_next_client().get_model_params())"
|
104 |
-
]
|
105 |
-
},
|
106 |
-
{
|
107 |
-
"cell_type": "code",
|
108 |
-
"execution_count": 4,
|
109 |
-
"metadata": {},
|
110 |
-
"outputs": [
|
111 |
-
{
|
112 |
-
"name": "stdout",
|
113 |
-
"output_type": "stream",
|
114 |
-
"text": [
|
115 |
-
"(768,)\n",
|
116 |
-
"(768,) (768,)\n"
|
117 |
-
]
|
118 |
-
}
|
119 |
-
],
|
120 |
-
"source": [
|
121 |
-
"emb = manifest.run(\"Is this an embedding?\")\n",
|
122 |
-
"print(emb.shape)\n",
|
123 |
-
"\n",
|
124 |
-
"emb = manifest.run([\"Is this an embedding?\", \"Bananas!!!\"])\n",
|
125 |
-
"print(emb[0].shape, emb[1].shape)"
|
126 |
-
]
|
127 |
-
}
|
128 |
-
],
|
129 |
-
"metadata": {
|
130 |
-
"kernelspec": {
|
131 |
-
"display_name": "manifest",
|
132 |
-
"language": "python",
|
133 |
-
"name": "python3"
|
134 |
-
},
|
135 |
-
"language_info": {
|
136 |
-
"codemirror_mode": {
|
137 |
-
"name": "ipython",
|
138 |
-
"version": 3
|
139 |
-
},
|
140 |
-
"file_extension": ".py",
|
141 |
-
"mimetype": "text/x-python",
|
142 |
-
"name": "python",
|
143 |
-
"nbconvert_exporter": "python",
|
144 |
-
"pygments_lexer": "ipython3",
|
145 |
-
"version": "3.10.4"
|
146 |
-
},
|
147 |
-
"orig_nbformat": 4,
|
148 |
-
"vscode": {
|
149 |
-
"interpreter": {
|
150 |
-
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
|
151 |
-
}
|
152 |
-
}
|
153 |
-
},
|
154 |
-
"nbformat": 4,
|
155 |
-
"nbformat_minor": 2
|
156 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/examples/manifest_google.ipynb
DELETED
@@ -1,117 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"%load_ext autoreload\n",
|
10 |
-
"%autoreload 2"
|
11 |
-
]
|
12 |
-
},
|
13 |
-
{
|
14 |
-
"cell_type": "code",
|
15 |
-
"execution_count": null,
|
16 |
-
"metadata": {},
|
17 |
-
"outputs": [],
|
18 |
-
"source": [
|
19 |
-
"GOOGLE_KEY = \"KEY::PROJECT_ID\""
|
20 |
-
]
|
21 |
-
},
|
22 |
-
{
|
23 |
-
"attachments": {},
|
24 |
-
"cell_type": "markdown",
|
25 |
-
"metadata": {},
|
26 |
-
"source": [
|
27 |
-
"## Use GoogleVertexAPI"
|
28 |
-
]
|
29 |
-
},
|
30 |
-
{
|
31 |
-
"cell_type": "code",
|
32 |
-
"execution_count": null,
|
33 |
-
"metadata": {},
|
34 |
-
"outputs": [],
|
35 |
-
"source": [
|
36 |
-
"from manifest import Manifest\n",
|
37 |
-
"from manifest.connections.client_pool import ClientConnection\n",
|
38 |
-
"\n",
|
39 |
-
"google_bison = ClientConnection(\n",
|
40 |
-
" client_name=\"google\",\n",
|
41 |
-
" client_connection=GOOGLE_KEY\n",
|
42 |
-
")\n",
|
43 |
-
"\n",
|
44 |
-
"manifest = Manifest(client_pool=[google_bison])"
|
45 |
-
]
|
46 |
-
},
|
47 |
-
{
|
48 |
-
"cell_type": "code",
|
49 |
-
"execution_count": null,
|
50 |
-
"metadata": {},
|
51 |
-
"outputs": [],
|
52 |
-
"source": [
|
53 |
-
"# Simple question\n",
|
54 |
-
"print(manifest.run(\"What is your name\", max_tokens=40))"
|
55 |
-
]
|
56 |
-
},
|
57 |
-
{
|
58 |
-
"cell_type": "code",
|
59 |
-
"execution_count": null,
|
60 |
-
"metadata": {},
|
61 |
-
"outputs": [],
|
62 |
-
"source": [
|
63 |
-
"from manifest import Manifest\n",
|
64 |
-
"from manifest.connections.client_pool import ClientConnection\n",
|
65 |
-
"\n",
|
66 |
-
"google_bison = ClientConnection(\n",
|
67 |
-
" client_name=\"googlechat\",\n",
|
68 |
-
" client_connection=GOOGLE_KEY\n",
|
69 |
-
")\n",
|
70 |
-
"\n",
|
71 |
-
"manifest = Manifest(client_pool=[google_bison])"
|
72 |
-
]
|
73 |
-
},
|
74 |
-
{
|
75 |
-
"cell_type": "code",
|
76 |
-
"execution_count": null,
|
77 |
-
"metadata": {},
|
78 |
-
"outputs": [],
|
79 |
-
"source": [
|
80 |
-
"chat_dict = [\n",
|
81 |
-
" # {\"author\": \"bot\", \"content\": \"You are a helpful assistant.\"},\n",
|
82 |
-
" {\"author\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
|
83 |
-
" {\"author\": \"bot\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
|
84 |
-
" {\"author\": \"user\", \"content\": \"Where was it played?\"}\n",
|
85 |
-
"]\n",
|
86 |
-
"print(manifest.run(chat_dict, max_tokens=8))"
|
87 |
-
]
|
88 |
-
}
|
89 |
-
],
|
90 |
-
"metadata": {
|
91 |
-
"kernelspec": {
|
92 |
-
"display_name": "manifest",
|
93 |
-
"language": "python",
|
94 |
-
"name": "python3"
|
95 |
-
},
|
96 |
-
"language_info": {
|
97 |
-
"codemirror_mode": {
|
98 |
-
"name": "ipython",
|
99 |
-
"version": 3
|
100 |
-
},
|
101 |
-
"file_extension": ".py",
|
102 |
-
"mimetype": "text/x-python",
|
103 |
-
"name": "python",
|
104 |
-
"nbconvert_exporter": "python",
|
105 |
-
"pygments_lexer": "ipython3",
|
106 |
-
"version": "3.10.4"
|
107 |
-
},
|
108 |
-
"orig_nbformat": 4,
|
109 |
-
"vscode": {
|
110 |
-
"interpreter": {
|
111 |
-
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
|
112 |
-
}
|
113 |
-
}
|
114 |
-
},
|
115 |
-
"nbformat": 4,
|
116 |
-
"nbformat_minor": 2
|
117 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/examples/manifest_openrouter.ipynb
DELETED
@@ -1,108 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"%load_ext autoreload\n",
|
10 |
-
"%autoreload 2"
|
11 |
-
]
|
12 |
-
},
|
13 |
-
{
|
14 |
-
"cell_type": "code",
|
15 |
-
"execution_count": 4,
|
16 |
-
"metadata": {},
|
17 |
-
"outputs": [],
|
18 |
-
"source": [
|
19 |
-
"OPENROUTER_API_KEY = \"sk-...\""
|
20 |
-
]
|
21 |
-
},
|
22 |
-
{
|
23 |
-
"attachments": {},
|
24 |
-
"cell_type": "markdown",
|
25 |
-
"metadata": {},
|
26 |
-
"source": [
|
27 |
-
"## Use ChatOpenAI\n",
|
28 |
-
"\n",
|
29 |
-
"Set you `OPENROUTER_API_KEY` environment variable."
|
30 |
-
]
|
31 |
-
},
|
32 |
-
{
|
33 |
-
"cell_type": "code",
|
34 |
-
"execution_count": 5,
|
35 |
-
"metadata": {},
|
36 |
-
"outputs": [],
|
37 |
-
"source": [
|
38 |
-
"from manifest import Manifest\n",
|
39 |
-
"from manifest.connections.client_pool import ClientConnection\n",
|
40 |
-
"\n",
|
41 |
-
"openai_chat = ClientConnection(\n",
|
42 |
-
" client_name=\"openrouter\",\n",
|
43 |
-
" client_connection=OPENROUTER_API_KEY,\n",
|
44 |
-
" engine=\"meta-llama/codellama-70b-instruct\"\n",
|
45 |
-
")\n",
|
46 |
-
"\n",
|
47 |
-
"manifest = Manifest(client_pool=[openai_chat])"
|
48 |
-
]
|
49 |
-
},
|
50 |
-
{
|
51 |
-
"cell_type": "code",
|
52 |
-
"execution_count": 6,
|
53 |
-
"metadata": {},
|
54 |
-
"outputs": [
|
55 |
-
{
|
56 |
-
"name": "stdout",
|
57 |
-
"output_type": "stream",
|
58 |
-
"text": [
|
59 |
-
"2020 World Series was played at the Globe Life Field in Arlington, Texas.\n"
|
60 |
-
]
|
61 |
-
}
|
62 |
-
],
|
63 |
-
"source": [
|
64 |
-
"# Simple question\n",
|
65 |
-
"chat_dict = [\n",
|
66 |
-
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
|
67 |
-
" {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
|
68 |
-
" {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
|
69 |
-
" {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
|
70 |
-
"]\n",
|
71 |
-
"print(manifest.run(chat_dict, max_tokens=100))"
|
72 |
-
]
|
73 |
-
},
|
74 |
-
{
|
75 |
-
"cell_type": "code",
|
76 |
-
"execution_count": null,
|
77 |
-
"metadata": {},
|
78 |
-
"outputs": [],
|
79 |
-
"source": []
|
80 |
-
}
|
81 |
-
],
|
82 |
-
"metadata": {
|
83 |
-
"kernelspec": {
|
84 |
-
"display_name": "Python 3 (ipykernel)",
|
85 |
-
"language": "python",
|
86 |
-
"name": "python3"
|
87 |
-
},
|
88 |
-
"language_info": {
|
89 |
-
"codemirror_mode": {
|
90 |
-
"name": "ipython",
|
91 |
-
"version": 3
|
92 |
-
},
|
93 |
-
"file_extension": ".py",
|
94 |
-
"mimetype": "text/x-python",
|
95 |
-
"name": "python",
|
96 |
-
"nbconvert_exporter": "python",
|
97 |
-
"pygments_lexer": "ipython3",
|
98 |
-
"version": "3.11.5"
|
99 |
-
},
|
100 |
-
"vscode": {
|
101 |
-
"interpreter": {
|
102 |
-
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
|
103 |
-
}
|
104 |
-
}
|
105 |
-
},
|
106 |
-
"nbformat": 4,
|
107 |
-
"nbformat_minor": 4
|
108 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/examples/manifest_streaming.ipynb
DELETED
@@ -1,105 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"%load_ext autoreload\n",
|
10 |
-
"%autoreload 2"
|
11 |
-
]
|
12 |
-
},
|
13 |
-
{
|
14 |
-
"cell_type": "code",
|
15 |
-
"execution_count": null,
|
16 |
-
"metadata": {},
|
17 |
-
"outputs": [],
|
18 |
-
"source": [
|
19 |
-
"OPENAI_KEY = \"sk-XXX\""
|
20 |
-
]
|
21 |
-
},
|
22 |
-
{
|
23 |
-
"attachments": {},
|
24 |
-
"cell_type": "markdown",
|
25 |
-
"metadata": {},
|
26 |
-
"source": [
|
27 |
-
"## Use ChatOpenAI\n",
|
28 |
-
"\n",
|
29 |
-
"Set you `OPENAI_API_KEY` environment variable."
|
30 |
-
]
|
31 |
-
},
|
32 |
-
{
|
33 |
-
"cell_type": "code",
|
34 |
-
"execution_count": null,
|
35 |
-
"metadata": {},
|
36 |
-
"outputs": [],
|
37 |
-
"source": [
|
38 |
-
"from manifest import Manifest\n",
|
39 |
-
"from manifest.connections.client_pool import ClientConnection\n",
|
40 |
-
"\n",
|
41 |
-
"openai_chat = ClientConnection(\n",
|
42 |
-
" client_name=\"openaichat\",\n",
|
43 |
-
" client_connection=OPENAI_KEY,\n",
|
44 |
-
" engine=\"gpt-3.5-turbo\"\n",
|
45 |
-
")\n",
|
46 |
-
"\n",
|
47 |
-
"manifest = Manifest(client_pool=[openai_chat])"
|
48 |
-
]
|
49 |
-
},
|
50 |
-
{
|
51 |
-
"cell_type": "code",
|
52 |
-
"execution_count": null,
|
53 |
-
"metadata": {},
|
54 |
-
"outputs": [],
|
55 |
-
"source": [
|
56 |
-
"manifest_iterator = manifest.run(\"Tell me a story about a fat cat.\\n\\nOnce upon a time\", max_tokens=200, stream=True)"
|
57 |
-
]
|
58 |
-
},
|
59 |
-
{
|
60 |
-
"cell_type": "code",
|
61 |
-
"execution_count": null,
|
62 |
-
"metadata": {},
|
63 |
-
"outputs": [],
|
64 |
-
"source": [
|
65 |
-
"import sys\n",
|
66 |
-
"\n",
|
67 |
-
"cur_line_length = 0\n",
|
68 |
-
"# Iterate over stream\n",
|
69 |
-
"for res in manifest_iterator:\n",
|
70 |
-
" sys.stdout.write(res)\n",
|
71 |
-
" cur_line_length += len(res)\n",
|
72 |
-
" if cur_line_length > 80:\n",
|
73 |
-
" sys.stdout.write(\"\\n\")\n",
|
74 |
-
" cur_line_length = 0"
|
75 |
-
]
|
76 |
-
}
|
77 |
-
],
|
78 |
-
"metadata": {
|
79 |
-
"kernelspec": {
|
80 |
-
"display_name": "manifest",
|
81 |
-
"language": "python",
|
82 |
-
"name": "python3"
|
83 |
-
},
|
84 |
-
"language_info": {
|
85 |
-
"codemirror_mode": {
|
86 |
-
"name": "ipython",
|
87 |
-
"version": 3
|
88 |
-
},
|
89 |
-
"file_extension": ".py",
|
90 |
-
"mimetype": "text/x-python",
|
91 |
-
"name": "python",
|
92 |
-
"nbconvert_exporter": "python",
|
93 |
-
"pygments_lexer": "ipython3",
|
94 |
-
"version": "3.10.4"
|
95 |
-
},
|
96 |
-
"orig_nbformat": 4,
|
97 |
-
"vscode": {
|
98 |
-
"interpreter": {
|
99 |
-
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
|
100 |
-
}
|
101 |
-
}
|
102 |
-
},
|
103 |
-
"nbformat": 4,
|
104 |
-
"nbformat_minor": 2
|
105 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/examples/manifest_together.ipynb
DELETED
@@ -1,106 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": 1,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [
|
8 |
-
{
|
9 |
-
"name": "stdout",
|
10 |
-
"output_type": "stream",
|
11 |
-
"text": [
|
12 |
-
"env: TOMA_URL=<TOMA_URL>\n"
|
13 |
-
]
|
14 |
-
}
|
15 |
-
],
|
16 |
-
"source": [
|
17 |
-
"%load_ext autoreload\n",
|
18 |
-
"%autoreload 2\n",
|
19 |
-
"\n",
|
20 |
-
"%env TOMA_URL=<TOMA_URL>"
|
21 |
-
]
|
22 |
-
},
|
23 |
-
{
|
24 |
-
"cell_type": "code",
|
25 |
-
"execution_count": null,
|
26 |
-
"metadata": {},
|
27 |
-
"outputs": [],
|
28 |
-
"source": [
|
29 |
-
"from manifest import Manifest\n",
|
30 |
-
"\n",
|
31 |
-
"# The responses are not fast\n",
|
32 |
-
"manifest = Manifest(\n",
|
33 |
-
" client_name=\"toma\",\n",
|
34 |
-
")\n",
|
35 |
-
"\n",
|
36 |
-
"print(manifest.run(\"What is the color of an apple?\"))"
|
37 |
-
]
|
38 |
-
},
|
39 |
-
{
|
40 |
-
"attachments": {},
|
41 |
-
"cell_type": "markdown",
|
42 |
-
"metadata": {},
|
43 |
-
"source": [
|
44 |
-
"With a cache"
|
45 |
-
]
|
46 |
-
},
|
47 |
-
{
|
48 |
-
"cell_type": "code",
|
49 |
-
"execution_count": null,
|
50 |
-
"metadata": {},
|
51 |
-
"outputs": [],
|
52 |
-
"source": [
|
53 |
-
"from manifest import Manifest\n",
|
54 |
-
"\n",
|
55 |
-
"# The responses are not fast\n",
|
56 |
-
"manifest = Manifest(\n",
|
57 |
-
" client_name=\"toma\",\n",
|
58 |
-
" cache_name=\"sqlite\",\n",
|
59 |
-
" cache_connection=\"my_manifest_cache.sqlite\",\n",
|
60 |
-
")\n",
|
61 |
-
"\n",
|
62 |
-
"res = manifest.run(\"What is the color of an apple?\", return_response=True)\n",
|
63 |
-
"print(res.get_response())\n",
|
64 |
-
"print(\"Is Cached?\", res.is_cached())\n",
|
65 |
-
"\n",
|
66 |
-
"res = manifest.run(\"What is the color of an apple?\", return_response=True)\n",
|
67 |
-
"print(res.get_response())\n",
|
68 |
-
"print(\"Is Cached?\", res.is_cached())"
|
69 |
-
]
|
70 |
-
},
|
71 |
-
{
|
72 |
-
"cell_type": "code",
|
73 |
-
"execution_count": null,
|
74 |
-
"metadata": {},
|
75 |
-
"outputs": [],
|
76 |
-
"source": []
|
77 |
-
}
|
78 |
-
],
|
79 |
-
"metadata": {
|
80 |
-
"kernelspec": {
|
81 |
-
"display_name": "manifest",
|
82 |
-
"language": "python",
|
83 |
-
"name": "python3"
|
84 |
-
},
|
85 |
-
"language_info": {
|
86 |
-
"codemirror_mode": {
|
87 |
-
"name": "ipython",
|
88 |
-
"version": 3
|
89 |
-
},
|
90 |
-
"file_extension": ".py",
|
91 |
-
"mimetype": "text/x-python",
|
92 |
-
"name": "python",
|
93 |
-
"nbconvert_exporter": "python",
|
94 |
-
"pygments_lexer": "ipython3",
|
95 |
-
"version": "3.10.4"
|
96 |
-
},
|
97 |
-
"orig_nbformat": 4,
|
98 |
-
"vscode": {
|
99 |
-
"interpreter": {
|
100 |
-
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
|
101 |
-
}
|
102 |
-
}
|
103 |
-
},
|
104 |
-
"nbformat": 4,
|
105 |
-
"nbformat_minor": 2
|
106 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/__init__.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
"""Manifest init."""
|
2 |
-
from manifest.manifest import Manifest
|
3 |
-
from manifest.request import Request
|
4 |
-
from manifest.response import Response
|
5 |
-
|
6 |
-
__all__ = ["Manifest", "Response", "Request"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/api/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
"""Api init."""
|
|
|
|
duckdb-nsql/manifest/manifest/api/app.py
DELETED
@@ -1,301 +0,0 @@
|
|
1 |
-
"""Flask app."""
|
2 |
-
import argparse
|
3 |
-
import io
|
4 |
-
import json
|
5 |
-
import logging
|
6 |
-
import os
|
7 |
-
import socket
|
8 |
-
from typing import Dict
|
9 |
-
|
10 |
-
import pkg_resources
|
11 |
-
from flask import Flask, Response, request
|
12 |
-
|
13 |
-
from manifest.api.models.diffuser import DiffuserModel
|
14 |
-
from manifest.api.models.huggingface import (
|
15 |
-
MODEL_GENTYPE_REGISTRY,
|
16 |
-
CrossModalEncoderModel,
|
17 |
-
TextGenerationModel,
|
18 |
-
)
|
19 |
-
from manifest.api.models.sentence_transformer import SentenceTransformerModel
|
20 |
-
from manifest.api.response import ModelResponse
|
21 |
-
|
22 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
23 |
-
|
24 |
-
logger = logging.getLogger(__name__)
|
25 |
-
app = Flask(__name__) # define app using Flask
|
26 |
-
# Will be global
|
27 |
-
model = None
|
28 |
-
model_type = None
|
29 |
-
PORT = int(os.environ.get("FLASK_PORT", 5000))
|
30 |
-
MODEL_CONSTRUCTORS = {
|
31 |
-
"huggingface": TextGenerationModel,
|
32 |
-
"sentence_transformers": SentenceTransformerModel,
|
33 |
-
"huggingface_crossmodal": CrossModalEncoderModel,
|
34 |
-
"diffuser": DiffuserModel,
|
35 |
-
}
|
36 |
-
|
37 |
-
|
38 |
-
def parse_args() -> argparse.Namespace:
|
39 |
-
"""Generate args."""
|
40 |
-
parser = argparse.ArgumentParser(description="Model args")
|
41 |
-
parser.add_argument(
|
42 |
-
"--model_type",
|
43 |
-
default=None,
|
44 |
-
type=str,
|
45 |
-
required=True,
|
46 |
-
help="Model type used for finding constructor.",
|
47 |
-
choices=MODEL_CONSTRUCTORS.keys(),
|
48 |
-
)
|
49 |
-
parser.add_argument(
|
50 |
-
"--model_generation_type",
|
51 |
-
default=None,
|
52 |
-
type=str,
|
53 |
-
help="Model generation type.",
|
54 |
-
choices=MODEL_GENTYPE_REGISTRY.keys(),
|
55 |
-
)
|
56 |
-
parser.add_argument(
|
57 |
-
"--model_name_or_path",
|
58 |
-
default=None,
|
59 |
-
type=str,
|
60 |
-
help="Name of model or path to model. Used in initialize of model class.",
|
61 |
-
)
|
62 |
-
parser.add_argument(
|
63 |
-
"--cache_dir", default=None, type=str, help="Cache directory for models."
|
64 |
-
)
|
65 |
-
parser.add_argument(
|
66 |
-
"--device", type=int, default=0, help="Model device. -1 for CPU."
|
67 |
-
)
|
68 |
-
parser.add_argument(
|
69 |
-
"--fp16", action="store_true", help="Force use fp16 for model params."
|
70 |
-
)
|
71 |
-
parser.add_argument(
|
72 |
-
"--percent_max_gpu_mem_reduction",
|
73 |
-
type=float,
|
74 |
-
default=0.85,
|
75 |
-
help="Used with accelerate multigpu. Scales down max memory.",
|
76 |
-
)
|
77 |
-
parser.add_argument(
|
78 |
-
"--use_bitsandbytes",
|
79 |
-
action="store_true",
|
80 |
-
help=("Use bits and bytes. " "This will override --device parameter."),
|
81 |
-
)
|
82 |
-
parser.add_argument(
|
83 |
-
"--use_accelerate_multigpu",
|
84 |
-
action="store_true",
|
85 |
-
help=(
|
86 |
-
"Use accelerate for multi gpu inference. "
|
87 |
-
"This will override --device parameter."
|
88 |
-
),
|
89 |
-
)
|
90 |
-
parser.add_argument(
|
91 |
-
"--use_hf_parallelize",
|
92 |
-
action="store_true",
|
93 |
-
help=(
|
94 |
-
"Use HF parallelize for multi gpu inference. "
|
95 |
-
"This will override --device parameter."
|
96 |
-
),
|
97 |
-
)
|
98 |
-
parser.add_argument(
|
99 |
-
"--use_deepspeed",
|
100 |
-
action="store_true",
|
101 |
-
help=("Use deepspeed. This will override --device parameter."),
|
102 |
-
)
|
103 |
-
args = parser.parse_args()
|
104 |
-
return args
|
105 |
-
|
106 |
-
|
107 |
-
def is_port_in_use(port: int) -> bool:
|
108 |
-
"""Check if port is in use."""
|
109 |
-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
110 |
-
return s.connect_ex(("localhost", port)) == 0
|
111 |
-
|
112 |
-
|
113 |
-
def main() -> None:
|
114 |
-
"""Run main."""
|
115 |
-
kwargs = parse_args()
|
116 |
-
if is_port_in_use(PORT):
|
117 |
-
raise ValueError(f"Port {PORT} is already in use.")
|
118 |
-
global model_type
|
119 |
-
model_type = kwargs.model_type
|
120 |
-
model_gen_type = kwargs.model_generation_type
|
121 |
-
model_name_or_path = kwargs.model_name_or_path
|
122 |
-
if not model_name_or_path:
|
123 |
-
raise ValueError("Must provide model_name_or_path.")
|
124 |
-
if kwargs.use_accelerate_multigpu:
|
125 |
-
logger.info("Using accelerate. Overridding --device argument.")
|
126 |
-
if (
|
127 |
-
kwargs.percent_max_gpu_mem_reduction <= 0
|
128 |
-
or kwargs.percent_max_gpu_mem_reduction > 1
|
129 |
-
):
|
130 |
-
raise ValueError("percent_max_gpu_mem_reduction must be in (0, 1].")
|
131 |
-
if (
|
132 |
-
sum(
|
133 |
-
[
|
134 |
-
kwargs.use_accelerate_multigpu,
|
135 |
-
kwargs.use_hf_parallelize,
|
136 |
-
kwargs.use_bitsandbytes,
|
137 |
-
kwargs.use_deepspeed,
|
138 |
-
]
|
139 |
-
)
|
140 |
-
> 1
|
141 |
-
):
|
142 |
-
raise ValueError(
|
143 |
-
"Only one of use_accelerate_multigpu, use_hf_parallelize, "
|
144 |
-
"use_bitsandbytes, and use_deepspeed can be set."
|
145 |
-
)
|
146 |
-
# Global model
|
147 |
-
global model
|
148 |
-
model = MODEL_CONSTRUCTORS[model_type](
|
149 |
-
model_name_or_path,
|
150 |
-
model_type=model_gen_type,
|
151 |
-
cache_dir=kwargs.cache_dir,
|
152 |
-
device=kwargs.device,
|
153 |
-
use_accelerate=kwargs.use_accelerate_multigpu,
|
154 |
-
use_parallelize=kwargs.use_hf_parallelize,
|
155 |
-
use_bitsandbytes=kwargs.use_bitsandbytes,
|
156 |
-
use_deepspeed=kwargs.use_deepspeed,
|
157 |
-
perc_max_gpu_mem_red=kwargs.percent_max_gpu_mem_reduction,
|
158 |
-
use_fp16=kwargs.fp16,
|
159 |
-
)
|
160 |
-
app.run(host="0.0.0.0", port=PORT)
|
161 |
-
|
162 |
-
|
163 |
-
@app.route("/completions", methods=["POST"])
|
164 |
-
def completions() -> Response:
|
165 |
-
"""Get completions for generation."""
|
166 |
-
prompt = request.json["prompt"]
|
167 |
-
del request.json["prompt"]
|
168 |
-
generation_args = request.json
|
169 |
-
|
170 |
-
if not isinstance(prompt, (str, list)):
|
171 |
-
raise ValueError("Prompt must be a str or list of str")
|
172 |
-
try:
|
173 |
-
result_gens = []
|
174 |
-
for generations in model.generate(prompt, **generation_args):
|
175 |
-
result_gens.append(generations)
|
176 |
-
if model_type == "diffuser":
|
177 |
-
# Assign None logprob as it's not supported in diffusers
|
178 |
-
results = [
|
179 |
-
{"array": r[0], "logprob": None, "tokens": None, "token_logprobs": None}
|
180 |
-
for r in result_gens
|
181 |
-
]
|
182 |
-
res_type = "image_generation"
|
183 |
-
else:
|
184 |
-
results = [
|
185 |
-
{"text": r[0], "logprob": r[1], "tokens": r[2], "token_logprobs": r[3]}
|
186 |
-
for r in result_gens
|
187 |
-
]
|
188 |
-
res_type = "text_completion"
|
189 |
-
# transform the result into the openai format
|
190 |
-
return Response(
|
191 |
-
json.dumps(ModelResponse(results, response_type=res_type).__dict__()),
|
192 |
-
status=200,
|
193 |
-
)
|
194 |
-
except Exception as e:
|
195 |
-
logger.error(e)
|
196 |
-
return Response(
|
197 |
-
json.dumps({"message": str(e)}),
|
198 |
-
status=400,
|
199 |
-
)
|
200 |
-
|
201 |
-
|
202 |
-
@app.route("/embed", methods=["POST"])
|
203 |
-
def embed() -> Response:
|
204 |
-
"""Get embed for generation."""
|
205 |
-
if "modality" in request.json:
|
206 |
-
modality = request.json["modality"]
|
207 |
-
else:
|
208 |
-
modality = "text"
|
209 |
-
if modality == "text":
|
210 |
-
prompts = request.json["prompt"]
|
211 |
-
elif modality == "image":
|
212 |
-
import base64
|
213 |
-
|
214 |
-
from PIL import Image
|
215 |
-
|
216 |
-
prompts = [
|
217 |
-
Image.open(io.BytesIO(base64.b64decode(data)))
|
218 |
-
for data in request.json["prompt"]
|
219 |
-
]
|
220 |
-
else:
|
221 |
-
raise ValueError("modality must be text or image")
|
222 |
-
|
223 |
-
try:
|
224 |
-
results = []
|
225 |
-
embeddings = model.embed(prompts)
|
226 |
-
for embedding in embeddings:
|
227 |
-
results.append(
|
228 |
-
{
|
229 |
-
"array": embedding,
|
230 |
-
"logprob": None,
|
231 |
-
"tokens": None,
|
232 |
-
"token_logprobs": None,
|
233 |
-
}
|
234 |
-
)
|
235 |
-
|
236 |
-
return Response(
|
237 |
-
json.dumps(
|
238 |
-
ModelResponse(results, response_type="embedding_generation").__dict__()
|
239 |
-
),
|
240 |
-
status=200,
|
241 |
-
)
|
242 |
-
except Exception as e:
|
243 |
-
logger.error(e)
|
244 |
-
return Response(
|
245 |
-
json.dumps({"message": str(e)}),
|
246 |
-
status=400,
|
247 |
-
)
|
248 |
-
|
249 |
-
|
250 |
-
@app.route("/score_sequence", methods=["POST"])
|
251 |
-
def score_sequence() -> Response:
|
252 |
-
"""Get logprob of prompt."""
|
253 |
-
prompt = request.json["prompt"]
|
254 |
-
del request.json["prompt"]
|
255 |
-
generation_args = request.json
|
256 |
-
|
257 |
-
if not isinstance(prompt, (str, list)):
|
258 |
-
raise ValueError("Prompt must be a str or list of str")
|
259 |
-
|
260 |
-
try:
|
261 |
-
score_list = model.score_sequence(prompt, **generation_args)
|
262 |
-
results = [
|
263 |
-
{
|
264 |
-
"text": prompt if isinstance(prompt, str) else prompt[i],
|
265 |
-
"logprob": r[0],
|
266 |
-
"tokens": r[1],
|
267 |
-
"token_logprobs": r[2],
|
268 |
-
}
|
269 |
-
for i, r in enumerate(score_list)
|
270 |
-
]
|
271 |
-
# transform the result into the openai format
|
272 |
-
return Response(
|
273 |
-
json.dumps(
|
274 |
-
ModelResponse(results, response_type="prompt_logit_score").__dict__()
|
275 |
-
),
|
276 |
-
status=200,
|
277 |
-
)
|
278 |
-
except Exception as e:
|
279 |
-
logger.error(e)
|
280 |
-
return Response(
|
281 |
-
json.dumps({"message": str(e)}),
|
282 |
-
status=400,
|
283 |
-
)
|
284 |
-
|
285 |
-
|
286 |
-
@app.route("/params", methods=["POST"])
|
287 |
-
def params() -> Dict:
|
288 |
-
"""Get model params."""
|
289 |
-
return model.get_init_params()
|
290 |
-
|
291 |
-
|
292 |
-
@app.route("/")
|
293 |
-
def index() -> str:
|
294 |
-
"""Get index completion."""
|
295 |
-
fn = pkg_resources.resource_filename("metaseq", "service/index.html")
|
296 |
-
with open(fn) as f:
|
297 |
-
return f.read()
|
298 |
-
|
299 |
-
|
300 |
-
if __name__ == "__main__":
|
301 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/api/models/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
"""Models init."""
|
|
|
|
duckdb-nsql/manifest/manifest/api/models/diffuser.py
DELETED
@@ -1,123 +0,0 @@
|
|
1 |
-
"""Diffuser model."""
|
2 |
-
from pathlib import Path
|
3 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import torch
|
7 |
-
from diffusers import StableDiffusionPipeline
|
8 |
-
|
9 |
-
from manifest.api.models.model import Model
|
10 |
-
|
11 |
-
|
12 |
-
class DiffuserModel(Model):
|
13 |
-
"""Diffuser model."""
|
14 |
-
|
15 |
-
def __init__(
|
16 |
-
self,
|
17 |
-
model_name_or_path: str,
|
18 |
-
model_type: Optional[str] = None,
|
19 |
-
model_config: Optional[str] = None,
|
20 |
-
cache_dir: Optional[str] = None,
|
21 |
-
device: int = 0,
|
22 |
-
use_accelerate: bool = False,
|
23 |
-
use_parallelize: bool = False,
|
24 |
-
use_bitsandbytes: bool = False,
|
25 |
-
use_deepspeed: bool = False,
|
26 |
-
perc_max_gpu_mem_red: float = 1.0,
|
27 |
-
use_fp16: bool = False,
|
28 |
-
):
|
29 |
-
"""
|
30 |
-
Initialize model.
|
31 |
-
|
32 |
-
All arguments will be passed in the request from Manifest.
|
33 |
-
|
34 |
-
Args:
|
35 |
-
model_name_or_path: model name string.
|
36 |
-
model_config: model config string.
|
37 |
-
cache_dir: cache directory for model.
|
38 |
-
device: device to use for model.
|
39 |
-
use_accelerate: whether to use accelerate for multi-gpu inference.
|
40 |
-
use_parallelize: use HF default parallelize
|
41 |
-
use_bitsandbytes: use HF bits and bytes
|
42 |
-
use_deepspeed: use deepspeed
|
43 |
-
perc_max_gpu_mem_red: percent max memory reduction in accelerate
|
44 |
-
use_fp16: use fp16 for model weights.
|
45 |
-
"""
|
46 |
-
if use_accelerate or use_parallelize or use_bitsandbytes or use_deepspeed:
|
47 |
-
raise ValueError(
|
48 |
-
"Cannot use accelerate or parallelize or "
|
49 |
-
"bitsandbytes or deepspeeed with diffusers"
|
50 |
-
)
|
51 |
-
# Check if providing path
|
52 |
-
self.model_path = model_name_or_path
|
53 |
-
if Path(self.model_path).exists() and Path(self.model_path).is_dir():
|
54 |
-
model_name_or_path = Path(self.model_path).name
|
55 |
-
self.model_name = model_name_or_path
|
56 |
-
print("Model Name:", self.model_name, "Model Path:", self.model_path)
|
57 |
-
dtype = torch.float16 if use_fp16 else None
|
58 |
-
torch_device = (
|
59 |
-
torch.device("cpu")
|
60 |
-
if (device == -1 or not torch.cuda.is_available())
|
61 |
-
else torch.device(f"cuda:{device}")
|
62 |
-
)
|
63 |
-
self.pipeline = StableDiffusionPipeline.from_pretrained(
|
64 |
-
self.model_path,
|
65 |
-
torch_dtype=dtype,
|
66 |
-
revision="fp16" if str(dtype) == "float16" else None,
|
67 |
-
)
|
68 |
-
self.pipeline.safety_checker = None
|
69 |
-
self.pipeline.to(torch_device)
|
70 |
-
|
71 |
-
def get_init_params(self) -> Dict:
|
72 |
-
"""Return init params to determine what model is being used."""
|
73 |
-
return {"model_name": self.model_name, "model_path": self.model_path}
|
74 |
-
|
75 |
-
@torch.no_grad()
|
76 |
-
def generate(
|
77 |
-
self, prompt: Union[str, List[str]], **kwargs: Any
|
78 |
-
) -> List[Tuple[Any, float, List[str], List[float]]]:
|
79 |
-
"""
|
80 |
-
Generate the prompt from model.
|
81 |
-
|
82 |
-
Outputs must be generated text and score, not including prompt.
|
83 |
-
|
84 |
-
Args:
|
85 |
-
prompt: promt to generate from.
|
86 |
-
|
87 |
-
Returns:
|
88 |
-
list of generated text (list of length 1 for 1 generation).
|
89 |
-
"""
|
90 |
-
# TODO: Is this correct for getting arguments in?
|
91 |
-
if isinstance(prompt, str):
|
92 |
-
prompt = [prompt]
|
93 |
-
result = self.pipeline(prompt, output_type="np.array", **kwargs)
|
94 |
-
# Return None for logprobs and token logprobs
|
95 |
-
return [(im, None, None, None) for im in result["images"]]
|
96 |
-
|
97 |
-
@torch.no_grad()
|
98 |
-
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
|
99 |
-
"""
|
100 |
-
Embed the prompt from model.
|
101 |
-
|
102 |
-
Args:
|
103 |
-
prompt: promt to embed from.
|
104 |
-
|
105 |
-
Returns:
|
106 |
-
list of embeddings (list of length 1 for 1 embedding).
|
107 |
-
"""
|
108 |
-
raise NotImplementedError("Embed not supported for diffusers")
|
109 |
-
|
110 |
-
@torch.no_grad()
|
111 |
-
def score_sequence(
|
112 |
-
self, prompt: Union[str, List[str]], **kwargs: Any
|
113 |
-
) -> List[Tuple[float, List[int], List[float]]]:
|
114 |
-
"""
|
115 |
-
Score a sequence of choices.
|
116 |
-
|
117 |
-
Args:
|
118 |
-
prompt (:obj:`str` or :obj:`List[str]`):
|
119 |
-
The prompt to score the choices against.
|
120 |
-
**kwargs:
|
121 |
-
Additional keyword arguments passed along to the :obj:`__call__` method.
|
122 |
-
"""
|
123 |
-
raise NotImplementedError("Score sequence not supported for diffusers")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/api/models/huggingface.py
DELETED
@@ -1,671 +0,0 @@
|
|
1 |
-
"""Huggingface model."""
|
2 |
-
import json
|
3 |
-
from pathlib import Path
|
4 |
-
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
5 |
-
|
6 |
-
import deepspeed
|
7 |
-
import numpy as np
|
8 |
-
import PIL
|
9 |
-
import torch
|
10 |
-
from accelerate import dispatch_model, infer_auto_device_map
|
11 |
-
from accelerate.utils.modeling import get_max_memory as acc_get_max_memory
|
12 |
-
from transformers import (
|
13 |
-
AutoModelForCausalLM,
|
14 |
-
AutoModelForSeq2SeqLM,
|
15 |
-
AutoTokenizer,
|
16 |
-
BloomForCausalLM,
|
17 |
-
CLIPModel,
|
18 |
-
CLIPProcessor,
|
19 |
-
GPT2LMHeadModel,
|
20 |
-
GPTJForCausalLM,
|
21 |
-
GPTNeoForCausalLM,
|
22 |
-
GPTNeoXForCausalLM,
|
23 |
-
LlamaForCausalLM,
|
24 |
-
LlamaTokenizer,
|
25 |
-
OPTForCausalLM,
|
26 |
-
PreTrainedModel,
|
27 |
-
PreTrainedTokenizer,
|
28 |
-
)
|
29 |
-
|
30 |
-
from manifest.api.models.model import Model
|
31 |
-
|
32 |
-
MODEL_REGISTRY = {
|
33 |
-
"EleutherAI/gpt-neo-125M": GPTNeoForCausalLM,
|
34 |
-
"EleutherAI/gpt-neo-1.3B": GPTNeoForCausalLM,
|
35 |
-
"EleutherAI/gpt-neo-2.7B": GPTNeoForCausalLM,
|
36 |
-
"EleutherAI/gpt-j-6B": GPTJForCausalLM,
|
37 |
-
"EleutherAI/gpt-neox-20b": GPTNeoXForCausalLM,
|
38 |
-
"facebook/opt-125m": OPTForCausalLM,
|
39 |
-
"facebook/opt-350m": OPTForCausalLM,
|
40 |
-
"Salesforce/codegen-2B-mono": AutoModelForCausalLM,
|
41 |
-
"Salesforce/codegen-6B-mono": AutoModelForCausalLM,
|
42 |
-
"facebook/opt-1.3b": OPTForCausalLM,
|
43 |
-
"facebook/opt-2.7b": OPTForCausalLM,
|
44 |
-
"facebook/opt-6.7b": OPTForCausalLM,
|
45 |
-
"facebook/opt-13b": OPTForCausalLM,
|
46 |
-
"facebook/opt-30b": OPTForCausalLM,
|
47 |
-
"gpt2": GPT2LMHeadModel,
|
48 |
-
"openai/clip-vit-base-patch32": CLIPModel,
|
49 |
-
"bigscience/bloom-560m": BloomForCausalLM,
|
50 |
-
"bigscience/bloom-1b7": BloomForCausalLM,
|
51 |
-
"bigscience/bloom-3b": BloomForCausalLM,
|
52 |
-
"bigscience/bloom-7b1": BloomForCausalLM,
|
53 |
-
"chainyo/alpaca-lora-7b": LlamaForCausalLM,
|
54 |
-
"bigscience/bloom": AutoModelForCausalLM,
|
55 |
-
"bigscience/T0pp": AutoModelForSeq2SeqLM,
|
56 |
-
"bigscience/T0_3B": AutoModelForSeq2SeqLM,
|
57 |
-
"google/t5-small-lm-adapt": AutoModelForSeq2SeqLM, # 220M
|
58 |
-
"google/t5-l-lm-adapt": AutoModelForSeq2SeqLM, # 800M
|
59 |
-
"google/t5-xl-lm-adapt": AutoModelForSeq2SeqLM, # 3B
|
60 |
-
"google/t5-xxl-lm-adapt": AutoModelForSeq2SeqLM, # 11B
|
61 |
-
"google/t5-v1_1-l": AutoModelForSeq2SeqLM, # 800M
|
62 |
-
"google/t5-v1_1-xl": AutoModelForSeq2SeqLM, # 3B
|
63 |
-
"google/t5-v1_1-xxl": AutoModelForSeq2SeqLM, # 11B
|
64 |
-
"google/flan-t5-l": AutoModelForSeq2SeqLM, # 800M
|
65 |
-
"google/flan-t5-xl": AutoModelForSeq2SeqLM, # 3B
|
66 |
-
"google/flan-t5-xxl": AutoModelForSeq2SeqLM, # 11B
|
67 |
-
}
|
68 |
-
|
69 |
-
MODEL_GENTYPE_REGISTRY = {
|
70 |
-
"text-generation": AutoModelForCausalLM,
|
71 |
-
"llama-text-generation": LlamaForCausalLM,
|
72 |
-
"text2text-generation": AutoModelForSeq2SeqLM,
|
73 |
-
}
|
74 |
-
|
75 |
-
|
76 |
-
def get_max_memory(gpu_reduction: float) -> Dict[int, str]:
|
77 |
-
"""Get max memory in GB times reduction."""
|
78 |
-
free_in_gb = int(torch.cuda.mem_get_info()[0] / 1024**3) # type: ignore
|
79 |
-
max_mem = f"{int(gpu_reduction*free_in_gb)}GB"
|
80 |
-
|
81 |
-
n_gpus = torch.cuda.device_count()
|
82 |
-
max_mem_dict = {i: max_mem for i in range(n_gpus)}
|
83 |
-
return max_mem_dict
|
84 |
-
|
85 |
-
|
86 |
-
class GenerationPipeline:
|
87 |
-
"""
|
88 |
-
Custom Pipeline.
|
89 |
-
|
90 |
-
HF pipelines do not handle devices well in multi-gpu setting.
|
91 |
-
Create our own generation pipeline.
|
92 |
-
"""
|
93 |
-
|
94 |
-
def __init__(
|
95 |
-
self,
|
96 |
-
model: Union[PreTrainedModel, deepspeed.InferenceEngine],
|
97 |
-
tokenizer: PreTrainedTokenizer,
|
98 |
-
device: int = None,
|
99 |
-
bitsandbytes: bool = False,
|
100 |
-
is_encdec: bool = False,
|
101 |
-
):
|
102 |
-
"""Initialize."""
|
103 |
-
# Use to turn off sampling
|
104 |
-
# https://github.com/TimDettmers/bitsandbytes/issues/42
|
105 |
-
self.bitsandbytes = bitsandbytes
|
106 |
-
self.model = model
|
107 |
-
self.is_encdec = is_encdec
|
108 |
-
config = model.config # type: ignore
|
109 |
-
# Used for GPT
|
110 |
-
self.max_length = getattr(config, "max_position_embeddings", None)
|
111 |
-
if self.max_length is None:
|
112 |
-
# Used for Bloom
|
113 |
-
self.max_length = getattr(config, "seq_length", None)
|
114 |
-
if self.max_length is None:
|
115 |
-
# Used for T0
|
116 |
-
self.max_length = getattr(config, "d_model", None)
|
117 |
-
if self.max_length is None:
|
118 |
-
# Default
|
119 |
-
self.max_length = 2048
|
120 |
-
|
121 |
-
print(f"Usings max_length: {self.max_length}")
|
122 |
-
|
123 |
-
self.tokenizer = tokenizer
|
124 |
-
# self.device = device
|
125 |
-
# With bits and bytes, do not want to place inputs on any device
|
126 |
-
# if self.device:
|
127 |
-
self.device = (
|
128 |
-
torch.device("cpu")
|
129 |
-
if (device == -1 or not torch.cuda.is_available())
|
130 |
-
else torch.device(f"cuda:{device}")
|
131 |
-
)
|
132 |
-
|
133 |
-
def __call__(
|
134 |
-
self, text: Union[str, List[str]], **kwargs: Any
|
135 |
-
) -> List[Dict[str, Union[str, List[float], List[str]]]]:
|
136 |
-
"""Generate from text.
|
137 |
-
|
138 |
-
Args:
|
139 |
-
text: text to generate.
|
140 |
-
|
141 |
-
Returns:
|
142 |
-
generated text.
|
143 |
-
"""
|
144 |
-
# If text is longer than max model length, we reduce max input length to ensure
|
145 |
-
# the user indicated generation tokens is preserved.
|
146 |
-
max_input_len = (
|
147 |
-
self.max_length - kwargs.get("max_new_tokens")
|
148 |
-
if not self.is_encdec
|
149 |
-
else self.max_length
|
150 |
-
)
|
151 |
-
encoded_prompt = self.tokenizer(
|
152 |
-
text,
|
153 |
-
max_length=max_input_len,
|
154 |
-
truncation=True,
|
155 |
-
padding=True,
|
156 |
-
return_tensors="pt",
|
157 |
-
)
|
158 |
-
encoded_prompt = encoded_prompt.to(self.device)
|
159 |
-
kwargs_to_pass = dict(
|
160 |
-
temperature=kwargs.get("temperature"),
|
161 |
-
top_k=kwargs.get("top_k"),
|
162 |
-
top_p=kwargs.get("top_p"),
|
163 |
-
repetition_penalty=kwargs.get("repetition_penalty"),
|
164 |
-
num_return_sequences=kwargs.get("num_return_sequences"),
|
165 |
-
do_sample=kwargs.get("do_sample"),
|
166 |
-
)
|
167 |
-
kwargs_to_pass = {k: v for k, v in kwargs_to_pass.items() if v is not None}
|
168 |
-
output_dict = self.model.generate( # type: ignore
|
169 |
-
**encoded_prompt,
|
170 |
-
**kwargs_to_pass,
|
171 |
-
max_new_tokens=kwargs.get("max_new_tokens"),
|
172 |
-
eos_token_id=self.tokenizer.eos_token_id,
|
173 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
174 |
-
output_scores=True,
|
175 |
-
return_dict_in_generate=True,
|
176 |
-
)
|
177 |
-
# logits/scores from the output always correspond to the generated tokens.
|
178 |
-
# shape (num_tokens, num_return_sequences, vocab_size)
|
179 |
-
logits = torch.stack(output_dict.scores)
|
180 |
-
logits = torch.nn.functional.log_softmax(logits, dim=-1)
|
181 |
-
num_generated_tokens = logits.shape[0]
|
182 |
-
generated_sequences = [
|
183 |
-
{
|
184 |
-
"generated_text": self.tokenizer.decode(
|
185 |
-
output_seq[-num_generated_tokens:], skip_special_tokens=True
|
186 |
-
),
|
187 |
-
"logprobs": logits[
|
188 |
-
range(num_generated_tokens), i, output_seq[-num_generated_tokens:]
|
189 |
-
].tolist(),
|
190 |
-
"tokens": self.tokenizer.convert_ids_to_tokens(
|
191 |
-
output_seq[-num_generated_tokens:].tolist()
|
192 |
-
),
|
193 |
-
}
|
194 |
-
for i, output_seq in enumerate(output_dict.sequences)
|
195 |
-
]
|
196 |
-
return generated_sequences
|
197 |
-
|
198 |
-
|
199 |
-
class HuggingFaceModel(Model):
|
200 |
-
"""HuggingFace Model."""
|
201 |
-
|
202 |
-
def __init__(
|
203 |
-
self,
|
204 |
-
model_name_or_path: str,
|
205 |
-
model_type: Optional[str] = None,
|
206 |
-
model_config: Optional[str] = None,
|
207 |
-
cache_dir: Optional[str] = None,
|
208 |
-
device: int = 0,
|
209 |
-
use_accelerate: bool = False,
|
210 |
-
use_parallelize: bool = False,
|
211 |
-
use_bitsandbytes: bool = False,
|
212 |
-
use_deepspeed: bool = False,
|
213 |
-
perc_max_gpu_mem_red: float = 1.0,
|
214 |
-
use_fp16: bool = False,
|
215 |
-
):
|
216 |
-
"""
|
217 |
-
Initialize model.
|
218 |
-
|
219 |
-
All arguments will be passed in the request from Manifest.
|
220 |
-
|
221 |
-
Args:
|
222 |
-
model_name_or_path: model name string.
|
223 |
-
model_config: model config string.
|
224 |
-
cache_dir: cache directory for model.
|
225 |
-
device: device to use for model.
|
226 |
-
use_accelerate: whether to use accelerate for multi-gpu inference.
|
227 |
-
use_parallelize: use HF default parallelize
|
228 |
-
use_bitsandbytes: use HF bits and bytes
|
229 |
-
use_deepspeed: use deepspeed
|
230 |
-
perc_max_gpu_mem_red: percent max memory reduction in accelerate
|
231 |
-
use_fp16: use fp16 for model weights.
|
232 |
-
"""
|
233 |
-
if sum([use_accelerate, use_parallelize, use_bitsandbytes, use_deepspeed]) > 1:
|
234 |
-
raise ValueError(
|
235 |
-
"Only one of use_accelerate, use_parallelize, "
|
236 |
-
"use_bitsandbytes, use_deepspeed can be set to True"
|
237 |
-
)
|
238 |
-
# Check if providing path
|
239 |
-
self.model_path = model_name_or_path
|
240 |
-
if Path(self.model_path).exists() and Path(self.model_path).is_dir():
|
241 |
-
# Try to find config
|
242 |
-
if (Path(self.model_path) / "config.json").exists():
|
243 |
-
config = json.load(open(Path(self.model_path) / "config.json"))
|
244 |
-
model_name_or_path = config["_name_or_path"]
|
245 |
-
self.model_name = model_name_or_path
|
246 |
-
self.model_type = model_type
|
247 |
-
if self.model_name not in MODEL_REGISTRY and self.model_type is None:
|
248 |
-
raise ValueError(
|
249 |
-
f"{self.model_name} is not in our registry. Please specify "
|
250 |
-
"--model_generation_type as either text-generation (for Causal)"
|
251 |
-
" or text2text-generation (for Seq2Seq)"
|
252 |
-
)
|
253 |
-
print("Model Name:", self.model_name, "Model Path:", self.model_path)
|
254 |
-
|
255 |
-
def get_init_params(self) -> Dict:
|
256 |
-
"""Return init params to determine what model is being used."""
|
257 |
-
return {"model_name": self.model_name, "model_path": self.model_path}
|
258 |
-
|
259 |
-
def _dispatch_deepspeed_model(
|
260 |
-
self, model: PreTrainedModel
|
261 |
-
) -> deepspeed.InferenceEngine:
|
262 |
-
"""
|
263 |
-
Load model with deepspeed.
|
264 |
-
|
265 |
-
Adapted from https://www.deepspeed.ai/tutorials/inference-tutorial/
|
266 |
-
|
267 |
-
Args:
|
268 |
-
model: loaded hugging face model
|
269 |
-
"""
|
270 |
-
model = deepspeed.init_inference(
|
271 |
-
model=model,
|
272 |
-
mp_size=1,
|
273 |
-
dtype=model.dtype,
|
274 |
-
replace_method="auto",
|
275 |
-
replace_with_kernel_inject=True,
|
276 |
-
)
|
277 |
-
return model
|
278 |
-
|
279 |
-
def _dispatch_accelerate_model(
|
280 |
-
self, model: PreTrainedModel, perc_max_gpu_mem_red: float
|
281 |
-
) -> None:
|
282 |
-
"""
|
283 |
-
Load model with accelerate.
|
284 |
-
|
285 |
-
Adapted from https://colab.research.google.com/drive/14wnxMvD9zsiBQo2FtT
|
286 |
-
pxn6w2cpXCcb-7#scrollTo=y8Ne7jJdaF9F&uniqifier=1
|
287 |
-
|
288 |
-
Args:
|
289 |
-
model: loaded hugging face model
|
290 |
-
perc_max_gpu_mem_red: percent memory reduction
|
291 |
-
"""
|
292 |
-
model.tie_weights() # type: ignore
|
293 |
-
# Get the model where we can infer devices from
|
294 |
-
if hasattr(model, "model"):
|
295 |
-
# OPT
|
296 |
-
main_model = model.model # type: ignore
|
297 |
-
model_getter = "model."
|
298 |
-
else:
|
299 |
-
# Eleuther Neo and J
|
300 |
-
main_model = model
|
301 |
-
model_getter = ""
|
302 |
-
# Decrease max mem
|
303 |
-
max_memory = {
|
304 |
-
k: int(perc_max_gpu_mem_red * v) for k, v in acc_get_max_memory().items()
|
305 |
-
}
|
306 |
-
raw_device_map = infer_auto_device_map(
|
307 |
-
main_model,
|
308 |
-
max_memory=max_memory,
|
309 |
-
no_split_module_classes=[
|
310 |
-
"OPTDecoderLayer",
|
311 |
-
"GPTNeoBlock",
|
312 |
-
"GPTJBlock",
|
313 |
-
"GPTNeoXLayer",
|
314 |
-
"T5Block",
|
315 |
-
],
|
316 |
-
dtype=model.dtype, # type: ignore
|
317 |
-
)
|
318 |
-
# Hacky fix for Eleuther getting the "weight" of embeddings
|
319 |
-
device_map = {}
|
320 |
-
for k, v in raw_device_map.items():
|
321 |
-
if k in {"wte", "wpe"}:
|
322 |
-
device_map[f"{model_getter}{k}.weight"] = v
|
323 |
-
else:
|
324 |
-
device_map[f"{model_getter}{k}"] = v
|
325 |
-
# For OPT models
|
326 |
-
if "lm_head" not in device_map:
|
327 |
-
try:
|
328 |
-
device_map["lm_head"] = max(device_map.values())
|
329 |
-
except TypeError:
|
330 |
-
device_map["lm_head"] = "cpu"
|
331 |
-
print("Device Map", device_map)
|
332 |
-
dispatch_model(model, device_map=device_map)
|
333 |
-
return
|
334 |
-
|
335 |
-
|
336 |
-
class CrossModalEncoderModel(HuggingFaceModel):
|
337 |
-
"""CrossModalEncoderModel."""
|
338 |
-
|
339 |
-
def __init__(
|
340 |
-
self,
|
341 |
-
model_name_or_path: str,
|
342 |
-
model_type: Optional[str] = None,
|
343 |
-
model_config: Optional[str] = None,
|
344 |
-
cache_dir: Optional[str] = None,
|
345 |
-
device: int = 0,
|
346 |
-
use_accelerate: bool = False,
|
347 |
-
use_parallelize: bool = False,
|
348 |
-
use_bitsandbytes: bool = False,
|
349 |
-
use_deepspeed: bool = False,
|
350 |
-
perc_max_gpu_mem_red: float = 1.0,
|
351 |
-
use_fp16: bool = False,
|
352 |
-
):
|
353 |
-
"""
|
354 |
-
Initialize model.
|
355 |
-
|
356 |
-
All arguments will be passed in the request from Manifest.
|
357 |
-
|
358 |
-
Args:
|
359 |
-
model_name_or_path: model name string.
|
360 |
-
model_config: model config string.
|
361 |
-
cache_dir: cache directory for model.
|
362 |
-
device: device to use for model.
|
363 |
-
use_accelerate: whether to use accelerate for multi-gpu inference.
|
364 |
-
use_parallelize: use HF default parallelize
|
365 |
-
use_bitsandbytes: use HF bits and bytes
|
366 |
-
use_deepspeed: use deepspeed
|
367 |
-
perc_max_gpu_mem_red: percent max memory reduction in accelerate
|
368 |
-
use_fp16: use fp16 for model weights.
|
369 |
-
"""
|
370 |
-
super().__init__(
|
371 |
-
model_name_or_path,
|
372 |
-
model_type,
|
373 |
-
model_config,
|
374 |
-
cache_dir,
|
375 |
-
device,
|
376 |
-
use_accelerate,
|
377 |
-
use_parallelize,
|
378 |
-
use_bitsandbytes,
|
379 |
-
use_deepspeed,
|
380 |
-
perc_max_gpu_mem_red,
|
381 |
-
use_fp16,
|
382 |
-
)
|
383 |
-
|
384 |
-
# TODO: make this generalizable
|
385 |
-
self.processor = CLIPProcessor.from_pretrained(self.model_path)
|
386 |
-
|
387 |
-
model = MODEL_REGISTRY.get(
|
388 |
-
self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
|
389 |
-
).from_pretrained(
|
390 |
-
self.model_path,
|
391 |
-
cache_dir=cache_dir,
|
392 |
-
trust_remote_code=True,
|
393 |
-
)
|
394 |
-
model.eval()
|
395 |
-
|
396 |
-
torch_device = (
|
397 |
-
torch.device("cpu")
|
398 |
-
if (device == -1 or not torch.cuda.is_available())
|
399 |
-
else torch.device(f"cuda:{device}")
|
400 |
-
)
|
401 |
-
self.model = model.to(torch_device) # type: ignore
|
402 |
-
|
403 |
-
@torch.no_grad()
|
404 |
-
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
|
405 |
-
"""
|
406 |
-
Compute embedding for prompts.
|
407 |
-
|
408 |
-
Args:
|
409 |
-
prompt: promt to generate from.
|
410 |
-
|
411 |
-
Returns:
|
412 |
-
embedding
|
413 |
-
"""
|
414 |
-
if isinstance(prompt, str):
|
415 |
-
inputs = self.processor(text=prompt, return_tensors="pt", padding=True)
|
416 |
-
elif isinstance(prompt, PIL.Image.Image):
|
417 |
-
inputs = self.processor(images=prompt, return_tensors="pt", padding=True)
|
418 |
-
else:
|
419 |
-
raise ValueError("Prompt must be a string or an image")
|
420 |
-
|
421 |
-
outputs = self.model(**inputs)
|
422 |
-
return outputs
|
423 |
-
|
424 |
-
|
425 |
-
class TextGenerationModel(HuggingFaceModel):
|
426 |
-
"""Huggingface model."""
|
427 |
-
|
428 |
-
def __init__(
|
429 |
-
self,
|
430 |
-
model_name_or_path: str,
|
431 |
-
model_type: Optional[str] = None,
|
432 |
-
model_config: Optional[str] = None,
|
433 |
-
cache_dir: Optional[str] = None,
|
434 |
-
device: int = 0,
|
435 |
-
use_accelerate: bool = False,
|
436 |
-
use_parallelize: bool = False,
|
437 |
-
use_bitsandbytes: bool = False,
|
438 |
-
use_deepspeed: bool = False,
|
439 |
-
perc_max_gpu_mem_red: float = 1.0,
|
440 |
-
use_fp16: bool = False,
|
441 |
-
):
|
442 |
-
"""
|
443 |
-
Initialize model.
|
444 |
-
|
445 |
-
All arguments will be passed in the request from Manifest.
|
446 |
-
|
447 |
-
Args:
|
448 |
-
model_name_or_path: model name string.
|
449 |
-
model_config: model config string.
|
450 |
-
cache_dir: cache directory for model.
|
451 |
-
device: device to use for model.
|
452 |
-
use_accelerate: whether to use accelerate for multi-gpu inference.
|
453 |
-
use_parallelize: use HF default parallelize
|
454 |
-
use_bitsandbytes: use HF bits and bytes
|
455 |
-
use_deepspeed: use deepspeed
|
456 |
-
perc_max_gpu_mem_red: percent max memory reduction in accelerate
|
457 |
-
use_fp16: use fp16 for model weights.
|
458 |
-
"""
|
459 |
-
super().__init__(
|
460 |
-
model_name_or_path,
|
461 |
-
model_type,
|
462 |
-
model_config,
|
463 |
-
cache_dir,
|
464 |
-
device,
|
465 |
-
use_accelerate,
|
466 |
-
use_parallelize,
|
467 |
-
use_bitsandbytes,
|
468 |
-
use_deepspeed,
|
469 |
-
perc_max_gpu_mem_red,
|
470 |
-
use_fp16,
|
471 |
-
)
|
472 |
-
if (
|
473 |
-
MODEL_REGISTRY.get(
|
474 |
-
self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
|
475 |
-
)
|
476 |
-
== LlamaForCausalLM
|
477 |
-
):
|
478 |
-
tokenizer = LlamaTokenizer.from_pretrained(self.model_name)
|
479 |
-
else:
|
480 |
-
try:
|
481 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
482 |
-
self.model_name, truncation_side="left", padding_side="left"
|
483 |
-
)
|
484 |
-
except ValueError:
|
485 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
486 |
-
self.model_name,
|
487 |
-
truncation_side="left",
|
488 |
-
padding_side="left",
|
489 |
-
use_fast=False,
|
490 |
-
)
|
491 |
-
dtype = torch.float16 if use_fp16 else "auto"
|
492 |
-
if use_bitsandbytes:
|
493 |
-
print("WARNING!!! Cannot use sampling with bitsandbytes.")
|
494 |
-
max_memory = get_max_memory(perc_max_gpu_mem_red)
|
495 |
-
model = MODEL_REGISTRY.get(
|
496 |
-
self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
|
497 |
-
).from_pretrained( # type: ignore
|
498 |
-
self.model_path,
|
499 |
-
cache_dir=cache_dir,
|
500 |
-
load_in_8bit=True,
|
501 |
-
device_map="auto",
|
502 |
-
max_memory=max_memory,
|
503 |
-
trust_remote_code=True,
|
504 |
-
)
|
505 |
-
else:
|
506 |
-
try:
|
507 |
-
# Try to explicitely find a fp16 copy (gpt-j-6B for example)
|
508 |
-
model = MODEL_REGISTRY.get(
|
509 |
-
self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
|
510 |
-
).from_pretrained( # type: ignore
|
511 |
-
self.model_path,
|
512 |
-
cache_dir=cache_dir,
|
513 |
-
revision="float16",
|
514 |
-
torch_dtype=torch.float16,
|
515 |
-
trust_remote_code=True,
|
516 |
-
)
|
517 |
-
except Exception:
|
518 |
-
model = MODEL_REGISTRY.get(
|
519 |
-
self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
|
520 |
-
).from_pretrained( # type: ignore
|
521 |
-
self.model_path,
|
522 |
-
cache_dir=cache_dir,
|
523 |
-
torch_dtype=dtype,
|
524 |
-
trust_remote_code=True,
|
525 |
-
)
|
526 |
-
model.eval()
|
527 |
-
print(f"Loaded Model DType {model.dtype}")
|
528 |
-
self.is_encdec = model.config.is_encoder_decoder
|
529 |
-
if not self.is_encdec:
|
530 |
-
tokenizer.pad_token = tokenizer.eos_token
|
531 |
-
tokenizer.pad_token_id = tokenizer.eos_token_id
|
532 |
-
if not use_bitsandbytes:
|
533 |
-
if use_accelerate:
|
534 |
-
self._dispatch_accelerate_model(model, perc_max_gpu_mem_red)
|
535 |
-
device = 0
|
536 |
-
elif use_parallelize:
|
537 |
-
model.parallelize()
|
538 |
-
device = 0
|
539 |
-
elif use_deepspeed:
|
540 |
-
self._dispatch_deepspeed_model(model)
|
541 |
-
device = 0
|
542 |
-
else:
|
543 |
-
if device > -1:
|
544 |
-
torch_device = (
|
545 |
-
torch.device("cpu")
|
546 |
-
if (device == -1 or not torch.cuda.is_available())
|
547 |
-
else torch.device(f"cuda:{device}")
|
548 |
-
)
|
549 |
-
model = model.to(torch_device) # type: ignore
|
550 |
-
self.pipeline = GenerationPipeline( # type: ignore
|
551 |
-
model=model,
|
552 |
-
tokenizer=tokenizer,
|
553 |
-
device=device,
|
554 |
-
bitsandbytes=use_bitsandbytes,
|
555 |
-
is_encdec=self.is_encdec,
|
556 |
-
)
|
557 |
-
|
558 |
-
@torch.no_grad()
|
559 |
-
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
|
560 |
-
"""
|
561 |
-
Embed the prompt from model.
|
562 |
-
|
563 |
-
Args:
|
564 |
-
prompt: promt to embed from.
|
565 |
-
|
566 |
-
Returns:
|
567 |
-
list of embeddings (list of length 1 for 1 embedding).
|
568 |
-
"""
|
569 |
-
if isinstance(prompt, str):
|
570 |
-
prompt = [prompt]
|
571 |
-
encoded_prompt = self.pipeline.tokenizer(
|
572 |
-
prompt,
|
573 |
-
max_length=self.pipeline.max_length,
|
574 |
-
truncation=True,
|
575 |
-
padding=True,
|
576 |
-
return_tensors="pt",
|
577 |
-
)
|
578 |
-
encoded_prompt = encoded_prompt.to(self.pipeline.device)
|
579 |
-
# Get last hidden state
|
580 |
-
output = self.pipeline.model( # type: ignore
|
581 |
-
**encoded_prompt,
|
582 |
-
output_hidden_states=True,
|
583 |
-
return_dict=True,
|
584 |
-
)
|
585 |
-
last_hidden_state = output["hidden_states"][-1][:, -1, :]
|
586 |
-
return last_hidden_state.cpu().numpy()
|
587 |
-
|
588 |
-
@torch.no_grad()
|
589 |
-
def generate(
|
590 |
-
self, prompt: Union[str, List[str]], **kwargs: Any
|
591 |
-
) -> List[Tuple[Any, float, List[str], List[float]]]:
|
592 |
-
"""
|
593 |
-
Generate the prompt from model.
|
594 |
-
|
595 |
-
Outputs must be generated text and score, not including prompt.
|
596 |
-
|
597 |
-
Args:
|
598 |
-
prompt: promt to generate from.
|
599 |
-
|
600 |
-
Returns:
|
601 |
-
list of generated text (list of length 1 for 1 generation).
|
602 |
-
"""
|
603 |
-
num_return = kwargs.get("n", 1)
|
604 |
-
if isinstance(prompt, list) and num_return > 1:
|
605 |
-
raise ValueError("In batch generate, n must be 1.")
|
606 |
-
result = self.pipeline(
|
607 |
-
prompt,
|
608 |
-
max_new_tokens=kwargs.get("max_tokens"),
|
609 |
-
temperature=kwargs.get("temperature"),
|
610 |
-
repetition_penalty=kwargs.get("repetition_penalty"),
|
611 |
-
top_k=kwargs.get("top_k"),
|
612 |
-
top_p=kwargs.get("top_p"),
|
613 |
-
do_sample=kwargs.get("do_sample"),
|
614 |
-
num_return_sequences=num_return,
|
615 |
-
)
|
616 |
-
final_results = [
|
617 |
-
(
|
618 |
-
cast(str, r["generated_text"]),
|
619 |
-
sum(cast(List[float], r["logprobs"])),
|
620 |
-
cast(List[str], r["tokens"]),
|
621 |
-
cast(List[float], r["logprobs"]),
|
622 |
-
)
|
623 |
-
for r in result
|
624 |
-
]
|
625 |
-
return final_results
|
626 |
-
|
627 |
-
@torch.no_grad()
|
628 |
-
def score_sequence(
|
629 |
-
self, prompt: Union[str, List[str]], **kwargs: Any
|
630 |
-
) -> List[Tuple[float, List[int], List[float]]]:
|
631 |
-
"""
|
632 |
-
Score a sequence of choices.
|
633 |
-
|
634 |
-
Args:
|
635 |
-
prompt (:obj:`str` or :obj:`List[str]`):
|
636 |
-
The prompt to score the choices against.
|
637 |
-
**kwargs:
|
638 |
-
Additional keyword arguments passed along to the :obj:`__call__` method.
|
639 |
-
"""
|
640 |
-
if isinstance(prompt, str):
|
641 |
-
prompt = [prompt]
|
642 |
-
encoded_prompt = self.pipeline.tokenizer(
|
643 |
-
prompt,
|
644 |
-
max_length=self.pipeline.max_length,
|
645 |
-
truncation=True,
|
646 |
-
padding=True,
|
647 |
-
return_tensors="pt",
|
648 |
-
)
|
649 |
-
encoded_prompt["labels"] = encoded_prompt["input_ids"].clone()
|
650 |
-
encoded_prompt = encoded_prompt.to(self.pipeline.device)
|
651 |
-
logits = self.pipeline.model( # type: ignore
|
652 |
-
**encoded_prompt,
|
653 |
-
).logits
|
654 |
-
# For causal decoders, shift logts and labels
|
655 |
-
labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)
|
656 |
-
masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
|
657 |
-
logits.float(), dim=-1
|
658 |
-
)
|
659 |
-
seq_token_log_probs = torch.gather(
|
660 |
-
masked_log_probs, -1, encoded_prompt["labels"].unsqueeze(-1)
|
661 |
-
)
|
662 |
-
seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
|
663 |
-
seq_log_prob = seq_token_log_probs.sum(dim=-1)
|
664 |
-
return [
|
665 |
-
(seq, tokens, seq_token)
|
666 |
-
for seq, tokens, seq_token in zip(
|
667 |
-
seq_log_prob.tolist(),
|
668 |
-
encoded_prompt["input_ids"].tolist(),
|
669 |
-
seq_token_log_probs.tolist(),
|
670 |
-
)
|
671 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/api/models/model.py
DELETED
@@ -1,91 +0,0 @@
|
|
1 |
-
"""Model class."""
|
2 |
-
from typing import Any, Dict, List, Tuple, Union
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
|
6 |
-
|
7 |
-
class Model:
|
8 |
-
"""Model class."""
|
9 |
-
|
10 |
-
def __init__(
|
11 |
-
self,
|
12 |
-
model_name_or_path: str,
|
13 |
-
model_type: str,
|
14 |
-
cache_dir: str,
|
15 |
-
device: int,
|
16 |
-
use_accelerate: bool,
|
17 |
-
use_parallelize: bool,
|
18 |
-
use_bitsandbytes: bool,
|
19 |
-
use_deepspeed: bool,
|
20 |
-
perc_max_gpu_mem_red: float,
|
21 |
-
use_fp16: bool,
|
22 |
-
):
|
23 |
-
"""
|
24 |
-
Initialize model.
|
25 |
-
|
26 |
-
All arguments will be passed in the request from Manifest.
|
27 |
-
|
28 |
-
Args:
|
29 |
-
model_name_or_path: model name string.
|
30 |
-
model_type: model type string for when model_name not in registry.
|
31 |
-
cache_dir: cache directory for model.
|
32 |
-
device: device to use for model.
|
33 |
-
use_accelerate: whether to use accelerate for multi-gpu inference.
|
34 |
-
use_parallelize: use HF default parallelize
|
35 |
-
use_bitsandbytes: use HF bits and bytes
|
36 |
-
use_deepspeed: use deepspeed
|
37 |
-
perc_max_gpu_mem_red: percent max memory reduction in accelerate
|
38 |
-
use_fp16: use fp16 for model weights.
|
39 |
-
"""
|
40 |
-
raise NotImplementedError()
|
41 |
-
|
42 |
-
def get_init_params(self) -> Dict:
|
43 |
-
"""Return init params to determine what model is being used."""
|
44 |
-
raise NotImplementedError()
|
45 |
-
|
46 |
-
def generate(
|
47 |
-
self, prompt: Union[str, List[str]], **kwargs: Any
|
48 |
-
) -> List[Tuple[Any, float, List[str], List[float]]]:
|
49 |
-
"""
|
50 |
-
Generate the prompt from model.
|
51 |
-
|
52 |
-
Outputs must be generated text and score, not including prompt.
|
53 |
-
|
54 |
-
Args:
|
55 |
-
prompt: promt to generate from.
|
56 |
-
|
57 |
-
Returns:
|
58 |
-
list of generated text (list of length 1 for 1 generation).
|
59 |
-
Each item is the response, answer logprob, list of tokens,
|
60 |
-
and list of logprobs for each token.
|
61 |
-
"""
|
62 |
-
raise NotImplementedError()
|
63 |
-
|
64 |
-
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
|
65 |
-
"""
|
66 |
-
Embed the prompt from model.
|
67 |
-
|
68 |
-
Args:
|
69 |
-
prompt: promt to embed from.
|
70 |
-
|
71 |
-
Returns:
|
72 |
-
list of embeddings (list of length 1 for 1 embedding).
|
73 |
-
"""
|
74 |
-
raise NotImplementedError()
|
75 |
-
|
76 |
-
def score_sequence(
|
77 |
-
self, prompt: Union[str, List[str]], **kwargs: Any
|
78 |
-
) -> List[Tuple[float, List[int], List[float]]]:
|
79 |
-
"""
|
80 |
-
Score a sequence of choices.
|
81 |
-
|
82 |
-
Args:
|
83 |
-
prompt (:obj:`str` or :obj:`List[str]`):
|
84 |
-
The prompt to score the choices against.
|
85 |
-
**kwargs:
|
86 |
-
Additional keyword arguments passed along to the :obj:`__call__` method.
|
87 |
-
|
88 |
-
Returns:
|
89 |
-
Tuple of total score, tokens, and probs per token.
|
90 |
-
"""
|
91 |
-
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/api/models/sentence_transformer.py
DELETED
@@ -1,113 +0,0 @@
|
|
1 |
-
"""Sentence transformer model."""
|
2 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
from sentence_transformers import SentenceTransformer
|
7 |
-
|
8 |
-
from manifest.api.models.model import Model
|
9 |
-
|
10 |
-
|
11 |
-
class SentenceTransformerModel(Model):
|
12 |
-
"""SentenceTransformer model."""
|
13 |
-
|
14 |
-
def __init__(
|
15 |
-
self,
|
16 |
-
model_name_or_path: str,
|
17 |
-
model_type: Optional[str] = None,
|
18 |
-
model_config: Optional[str] = None,
|
19 |
-
cache_dir: Optional[str] = None,
|
20 |
-
device: int = 0,
|
21 |
-
use_accelerate: bool = False,
|
22 |
-
use_parallelize: bool = False,
|
23 |
-
use_bitsandbytes: bool = False,
|
24 |
-
use_deepspeed: bool = False,
|
25 |
-
perc_max_gpu_mem_red: float = 1.0,
|
26 |
-
use_fp16: bool = False,
|
27 |
-
):
|
28 |
-
"""
|
29 |
-
Initialize model.
|
30 |
-
|
31 |
-
All arguments will be passed in the request from Manifest.
|
32 |
-
|
33 |
-
Args:
|
34 |
-
model_name_or_path: model name string.
|
35 |
-
model_config: model config string.
|
36 |
-
cache_dir: cache directory for model.
|
37 |
-
device: device to use for model.
|
38 |
-
use_accelerate: whether to use accelerate for multi-gpu inference.
|
39 |
-
use_parallelize: use HF default parallelize
|
40 |
-
use_bitsandbytes: use HF bits and bytes
|
41 |
-
use_deepspeed: use deepspeed
|
42 |
-
perc_max_gpu_mem_red: percent max memory reduction in accelerate
|
43 |
-
use_fp16: use fp16 for model weights.
|
44 |
-
"""
|
45 |
-
if use_accelerate or use_parallelize or use_bitsandbytes or use_deepspeed:
|
46 |
-
raise ValueError(
|
47 |
-
"Cannot use accelerate or parallelize or "
|
48 |
-
"bitsandbytes or deepspeeed with sentence transformers"
|
49 |
-
)
|
50 |
-
# Check if providing path
|
51 |
-
self.model_name = model_name_or_path
|
52 |
-
print("Model Name:", self.model_name)
|
53 |
-
torch_device = (
|
54 |
-
torch.device("cpu")
|
55 |
-
if (device == -1 or not torch.cuda.is_available())
|
56 |
-
else torch.device(f"cuda:{device}")
|
57 |
-
)
|
58 |
-
self.embedding_model = SentenceTransformer(self.model_name, device=torch_device)
|
59 |
-
self.embedding_model.to(torch_device)
|
60 |
-
self.embedding_model.eval()
|
61 |
-
|
62 |
-
def get_init_params(self) -> Dict:
|
63 |
-
"""Return init params to determine what model is being used."""
|
64 |
-
return {"model_name": self.model_name, "model_path": self.model_name}
|
65 |
-
|
66 |
-
@torch.no_grad()
|
67 |
-
def generate(
|
68 |
-
self, prompt: Union[str, List[str]], **kwargs: Any
|
69 |
-
) -> List[Tuple[Any, float, List[str], List[float]]]:
|
70 |
-
"""
|
71 |
-
Generate the prompt from model.
|
72 |
-
|
73 |
-
Outputs must be generated text and score, not including prompt.
|
74 |
-
|
75 |
-
Args:
|
76 |
-
prompt: promt to generate from.
|
77 |
-
|
78 |
-
Returns:
|
79 |
-
list of generated text (list of length 1 for 1 generation).
|
80 |
-
"""
|
81 |
-
raise NotImplementedError("Generate not supported for sentence transformers")
|
82 |
-
|
83 |
-
@torch.no_grad()
|
84 |
-
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
|
85 |
-
"""
|
86 |
-
Embed the prompt from model.
|
87 |
-
|
88 |
-
Args:
|
89 |
-
prompt: promt to embed from.
|
90 |
-
|
91 |
-
Returns:
|
92 |
-
list of embeddings (list of length 1 for 1 embedding).
|
93 |
-
"""
|
94 |
-
if isinstance(prompt, str):
|
95 |
-
prompt = [prompt]
|
96 |
-
return self.embedding_model.encode(prompt)
|
97 |
-
|
98 |
-
@torch.no_grad()
|
99 |
-
def score_sequence(
|
100 |
-
self, prompt: Union[str, List[str]], **kwargs: Any
|
101 |
-
) -> List[Tuple[float, List[int], List[float]]]:
|
102 |
-
"""
|
103 |
-
Score a sequence of choices.
|
104 |
-
|
105 |
-
Args:
|
106 |
-
prompt (:obj:`str` or :obj:`List[str]`):
|
107 |
-
The prompt to score the choices against.
|
108 |
-
**kwargs:
|
109 |
-
Additional keyword arguments passed along to the :obj:`__call__` method.
|
110 |
-
"""
|
111 |
-
raise NotImplementedError(
|
112 |
-
"Score sequence not supported for sentence transformers"
|
113 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/api/response.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
"""Response."""
|
2 |
-
|
3 |
-
import time
|
4 |
-
import uuid
|
5 |
-
from typing import Any, Dict, List
|
6 |
-
|
7 |
-
|
8 |
-
class ModelResponse:
|
9 |
-
"""ModelResponse."""
|
10 |
-
|
11 |
-
def __init__(self, results: List[Dict[str, Any]], response_type: str) -> None:
|
12 |
-
"""Initialize response."""
|
13 |
-
self.results = results
|
14 |
-
self.response_type = response_type
|
15 |
-
if self.response_type not in {
|
16 |
-
"text_completion",
|
17 |
-
"prompt_logit_score",
|
18 |
-
"image_generation",
|
19 |
-
"embedding_generation",
|
20 |
-
}:
|
21 |
-
raise ValueError(
|
22 |
-
f"Invalid response type: {self.response_type}. "
|
23 |
-
"Must be one of: text_completion, prompt_logit_score, "
|
24 |
-
"image_generation, embedding_generation."
|
25 |
-
)
|
26 |
-
self.response_id = str(uuid.uuid4())
|
27 |
-
self.created = int(time.time())
|
28 |
-
|
29 |
-
def __dict__(self) -> Dict[str, Any]: # type: ignore
|
30 |
-
"""Return dictionary representation of response."""
|
31 |
-
key = (
|
32 |
-
"text"
|
33 |
-
if self.response_type not in {"image_generation", "embedding_generation"}
|
34 |
-
else "array"
|
35 |
-
)
|
36 |
-
return {
|
37 |
-
"id": self.response_id,
|
38 |
-
"object": self.response_type,
|
39 |
-
"created": self.created,
|
40 |
-
"model": "flask_model",
|
41 |
-
"choices": [
|
42 |
-
{
|
43 |
-
key: result[key],
|
44 |
-
"logprob": result["logprob"],
|
45 |
-
"tokens": result["tokens"],
|
46 |
-
"token_logprobs": result["token_logprobs"],
|
47 |
-
}
|
48 |
-
if key == "text"
|
49 |
-
else {
|
50 |
-
key: result[key].tolist(),
|
51 |
-
"logprob": result["logprob"],
|
52 |
-
}
|
53 |
-
for result in self.results
|
54 |
-
],
|
55 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/caches/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
"""Cache init."""
|
|
|
|
duckdb-nsql/manifest/manifest/caches/array_cache.py
DELETED
@@ -1,116 +0,0 @@
|
|
1 |
-
"""Array cache."""
|
2 |
-
from pathlib import Path
|
3 |
-
from typing import Union
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
from sqlitedict import SqliteDict
|
7 |
-
|
8 |
-
|
9 |
-
def open_mmap_arr(file: Union[Path, str], size: float) -> np.memmap:
|
10 |
-
"""Open memmap."""
|
11 |
-
if not Path(file).exists():
|
12 |
-
mode = "w+"
|
13 |
-
else:
|
14 |
-
mode = "r+"
|
15 |
-
arr = np.memmap( # type: ignore
|
16 |
-
str(file),
|
17 |
-
dtype=np.float32, # This means we only support float 32
|
18 |
-
mode=mode,
|
19 |
-
shape=size,
|
20 |
-
)
|
21 |
-
return arr
|
22 |
-
|
23 |
-
|
24 |
-
class ArrayCache:
|
25 |
-
"""Array cache."""
|
26 |
-
|
27 |
-
def __init__(self, folder: Union[str, Path]) -> None:
|
28 |
-
"""
|
29 |
-
Initialize the array writer.
|
30 |
-
|
31 |
-
Args:
|
32 |
-
folder: folder to write to.
|
33 |
-
"""
|
34 |
-
self.folder = Path(folder)
|
35 |
-
self.folder.mkdir(exist_ok=True, parents=True)
|
36 |
-
self.hash2arrloc = SqliteDict(
|
37 |
-
self.folder / "hash2arrloc.sqlite", autocommit=True
|
38 |
-
)
|
39 |
-
# Approx 1GB (I think)
|
40 |
-
self.max_memmap_size = 20480000
|
41 |
-
self.cur_file_idx = 0
|
42 |
-
# Get the last file idx used
|
43 |
-
for key in self.hash2arrloc:
|
44 |
-
file_data = self.hash2arrloc[key]
|
45 |
-
if file_data["file_idx"] > self.cur_file_idx:
|
46 |
-
self.cur_file_idx = file_data["file_idx"]
|
47 |
-
self.cur_memmap = open_mmap_arr(
|
48 |
-
self.folder / f"{self.cur_file_idx}.npy",
|
49 |
-
self.max_memmap_size,
|
50 |
-
)
|
51 |
-
# Make sure there is space left in the memmap
|
52 |
-
non_zero = np.nonzero(self.cur_memmap)[0]
|
53 |
-
if len(non_zero) > 0:
|
54 |
-
self.cur_offset = int(np.max(non_zero) + 1)
|
55 |
-
else:
|
56 |
-
self.cur_offset = 0
|
57 |
-
# If no space, make a new memmap
|
58 |
-
if self.cur_offset == self.max_memmap_size:
|
59 |
-
self.cur_file_idx += 1
|
60 |
-
self.cur_memmap = open_mmap_arr(
|
61 |
-
self.folder / f"{self.cur_file_idx}.npy",
|
62 |
-
self.max_memmap_size,
|
63 |
-
)
|
64 |
-
self.cur_offset = 0
|
65 |
-
|
66 |
-
def contains_key(self, key: str) -> bool:
|
67 |
-
"""
|
68 |
-
Check if the key is in the cache.
|
69 |
-
|
70 |
-
Args:
|
71 |
-
key: key to check.
|
72 |
-
|
73 |
-
Returns:
|
74 |
-
True if the key is in the cache.
|
75 |
-
"""
|
76 |
-
return key in self.hash2arrloc
|
77 |
-
|
78 |
-
def put(self, key: str, arr: np.ndarray) -> None:
|
79 |
-
"""Save array in store and associate location with key."""
|
80 |
-
# Check if there is space in the memmap
|
81 |
-
arr_shape = arr.shape
|
82 |
-
arr = arr.flatten()
|
83 |
-
if len(arr) > self.max_memmap_size:
|
84 |
-
raise ValueError(
|
85 |
-
f"Array is too large to be cached. Max is {self.max_memmap_size}"
|
86 |
-
)
|
87 |
-
if self.cur_offset + len(arr) > self.max_memmap_size:
|
88 |
-
self.cur_file_idx += 1
|
89 |
-
self.cur_memmap = open_mmap_arr(
|
90 |
-
self.folder / f"{self.cur_file_idx}.npy",
|
91 |
-
self.max_memmap_size,
|
92 |
-
)
|
93 |
-
self.cur_offset = 0
|
94 |
-
self.cur_memmap[self.cur_offset : self.cur_offset + len(arr)] = arr
|
95 |
-
self.cur_memmap.flush()
|
96 |
-
self.hash2arrloc[key] = {
|
97 |
-
"file_idx": self.cur_file_idx,
|
98 |
-
"offset": self.cur_offset,
|
99 |
-
"flatten_size": len(arr),
|
100 |
-
"shape": arr_shape,
|
101 |
-
"dtype": arr.dtype,
|
102 |
-
}
|
103 |
-
self.cur_offset += len(arr)
|
104 |
-
return
|
105 |
-
|
106 |
-
def get(self, key: str) -> np.ndarray:
|
107 |
-
"""Get array associated with location from key."""
|
108 |
-
file_data = self.hash2arrloc[key]
|
109 |
-
memmap = open_mmap_arr(
|
110 |
-
self.folder / f"{file_data['file_idx']}.npy",
|
111 |
-
self.max_memmap_size,
|
112 |
-
)
|
113 |
-
arr = memmap[
|
114 |
-
file_data["offset"] : file_data["offset"] + file_data["flatten_size"]
|
115 |
-
]
|
116 |
-
return arr.reshape(file_data["shape"]).astype(file_data["dtype"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/caches/cache.py
DELETED
@@ -1,135 +0,0 @@
|
|
1 |
-
"""Cache for queries and responses."""
|
2 |
-
from abc import ABC, abstractmethod
|
3 |
-
from typing import Any, Dict, Type, Union
|
4 |
-
|
5 |
-
from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer, Serializer
|
6 |
-
from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest, Request
|
7 |
-
from manifest.response import Response
|
8 |
-
|
9 |
-
# Non-text return type caches
|
10 |
-
ARRAY_CACHE_TYPES = {EmbeddingRequest, DiffusionRequest}
|
11 |
-
|
12 |
-
|
13 |
-
class Cache(ABC):
|
14 |
-
"""A cache for request/response pairs."""
|
15 |
-
|
16 |
-
def __init__(
|
17 |
-
self,
|
18 |
-
connection_str: str,
|
19 |
-
request_type: Type[Request] = LMRequest,
|
20 |
-
cache_args: Dict[str, Any] = {},
|
21 |
-
):
|
22 |
-
"""
|
23 |
-
Initialize cache.
|
24 |
-
|
25 |
-
Args:
|
26 |
-
connection_str: connection string.
|
27 |
-
request_type: request type.
|
28 |
-
cache_args: arguments for cache.
|
29 |
-
|
30 |
-
cache_args are any arguments needed to initialize the cache.
|
31 |
-
|
32 |
-
Further, cache_args can contain `array_serializer` as a string
|
33 |
-
for embedding or image return types (e.g. diffusers) with values
|
34 |
-
as `local_file` or `byte_string`. `local_file` will save the
|
35 |
-
array in a local file and cache a pointer to the file.
|
36 |
-
`byte_string` will convert the array to a byte string and cache
|
37 |
-
the entire byte string. `byte_string` is default.
|
38 |
-
|
39 |
-
Args:
|
40 |
-
connection_str: connection string for cache.
|
41 |
-
cache_args: cache arguments.
|
42 |
-
"""
|
43 |
-
self.request_type = request_type
|
44 |
-
self.connect(connection_str, cache_args)
|
45 |
-
if self.request_type in ARRAY_CACHE_TYPES:
|
46 |
-
array_serializer = cache_args.pop("array_serializer", "byte_string")
|
47 |
-
if array_serializer not in ["local_file", "byte_string"]:
|
48 |
-
raise ValueError(
|
49 |
-
"array_serializer must be local_file or byte_string,"
|
50 |
-
f" not {array_serializer}"
|
51 |
-
)
|
52 |
-
self.serializer = (
|
53 |
-
ArraySerializer()
|
54 |
-
if array_serializer == "local_file"
|
55 |
-
else NumpyByteSerializer()
|
56 |
-
)
|
57 |
-
else:
|
58 |
-
# If user has array_serializer type, it will throw an error as
|
59 |
-
# it is not recognized for non-array return types.
|
60 |
-
self.serializer = Serializer()
|
61 |
-
|
62 |
-
@abstractmethod
|
63 |
-
def close(self) -> None:
|
64 |
-
"""Close the cache."""
|
65 |
-
raise NotImplementedError()
|
66 |
-
|
67 |
-
@abstractmethod
|
68 |
-
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
|
69 |
-
"""
|
70 |
-
Connect to cache.
|
71 |
-
|
72 |
-
Args:
|
73 |
-
connection_str: connection string.
|
74 |
-
"""
|
75 |
-
raise NotImplementedError()
|
76 |
-
|
77 |
-
@abstractmethod
|
78 |
-
def get_key(self, key: str, table: str = "default") -> Union[str, None]:
|
79 |
-
"""
|
80 |
-
Get the key for a request.
|
81 |
-
|
82 |
-
With return None if key is not in cache.
|
83 |
-
|
84 |
-
Args:
|
85 |
-
key: key for cache.
|
86 |
-
table: table to get key in.
|
87 |
-
"""
|
88 |
-
raise NotImplementedError()
|
89 |
-
|
90 |
-
@abstractmethod
|
91 |
-
def set_key(self, key: str, value: str, table: str = "default") -> None:
|
92 |
-
"""
|
93 |
-
Set the value for the key.
|
94 |
-
|
95 |
-
Will override old value.
|
96 |
-
|
97 |
-
Args:
|
98 |
-
key: key for cache.
|
99 |
-
value: new value for key.
|
100 |
-
table: table to set key in.
|
101 |
-
"""
|
102 |
-
raise NotImplementedError()
|
103 |
-
|
104 |
-
@abstractmethod
|
105 |
-
def commit(self) -> None:
|
106 |
-
"""Commit any results."""
|
107 |
-
raise NotImplementedError()
|
108 |
-
|
109 |
-
def get(self, request: Dict) -> Union[Response, None]:
|
110 |
-
"""Get the result of request (by calling compute as needed).
|
111 |
-
|
112 |
-
Args:
|
113 |
-
request: request to get.
|
114 |
-
response: response to get.
|
115 |
-
|
116 |
-
Returns:
|
117 |
-
Response object or None if not in cache.
|
118 |
-
"""
|
119 |
-
key = self.serializer.request_to_key(request)
|
120 |
-
cached_response = self.get_key(key)
|
121 |
-
if cached_response:
|
122 |
-
response = self.serializer.key_to_response(cached_response)
|
123 |
-
response["cached"] = True
|
124 |
-
return Response.from_dict(response, request_dict=request)
|
125 |
-
return None
|
126 |
-
|
127 |
-
def set(self, request: Dict, response: Dict) -> None:
|
128 |
-
"""Set the value for the key.
|
129 |
-
|
130 |
-
Args:
|
131 |
-
request: request to set.
|
132 |
-
response: response to set.
|
133 |
-
"""
|
134 |
-
key = self.serializer.request_to_key(request)
|
135 |
-
self.set_key(key, self.serializer.response_to_key(response))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/caches/noop.py
DELETED
@@ -1,47 +0,0 @@
|
|
1 |
-
"""Noop cache."""
|
2 |
-
from typing import Any, Dict, Union
|
3 |
-
|
4 |
-
from manifest.caches.cache import Cache
|
5 |
-
|
6 |
-
|
7 |
-
class NoopCache(Cache):
|
8 |
-
"""A Noop cache that caches nothing for request/response pairs."""
|
9 |
-
|
10 |
-
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
|
11 |
-
"""
|
12 |
-
Connect to client.
|
13 |
-
|
14 |
-
Args:
|
15 |
-
connection_str: connection string.
|
16 |
-
cache_args: arguments for cache.
|
17 |
-
"""
|
18 |
-
pass
|
19 |
-
|
20 |
-
def close(self) -> None:
|
21 |
-
"""Close the client."""
|
22 |
-
pass
|
23 |
-
|
24 |
-
def get_key(self, key: str, table: str = "default") -> Union[str, None]:
|
25 |
-
"""
|
26 |
-
Return None key for never in cache.
|
27 |
-
|
28 |
-
Args:
|
29 |
-
key: key for cache.
|
30 |
-
table: table to get key in.
|
31 |
-
"""
|
32 |
-
return None
|
33 |
-
|
34 |
-
def set_key(self, key: str, value: str, table: str = "default") -> None:
|
35 |
-
"""
|
36 |
-
Do not set anything as no cache.
|
37 |
-
|
38 |
-
Args:
|
39 |
-
key: key for cache.
|
40 |
-
value: new value for key.
|
41 |
-
table: table to set key in.
|
42 |
-
"""
|
43 |
-
pass
|
44 |
-
|
45 |
-
def commit(self) -> None:
|
46 |
-
"""Commit any results."""
|
47 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/caches/postgres.py
DELETED
@@ -1,131 +0,0 @@
|
|
1 |
-
"""Postgres cache."""
|
2 |
-
import hashlib
|
3 |
-
import logging
|
4 |
-
from typing import Any, Dict, Union
|
5 |
-
|
6 |
-
logger = logging.getLogger("postgresql")
|
7 |
-
logger.setLevel(logging.WARNING)
|
8 |
-
|
9 |
-
from ..caches.cache import Cache
|
10 |
-
|
11 |
-
try:
|
12 |
-
import sqlalchemy # type: ignore
|
13 |
-
from google.cloud.sql.connector import Connector # type: ignore
|
14 |
-
from sqlalchemy import Column, String # type: ignore
|
15 |
-
from sqlalchemy.ext.declarative import declarative_base # type: ignore
|
16 |
-
from sqlalchemy.orm import sessionmaker # type: ignore
|
17 |
-
|
18 |
-
Base = declarative_base()
|
19 |
-
|
20 |
-
class Request(Base): # type: ignore
|
21 |
-
"""The request table."""
|
22 |
-
|
23 |
-
__tablename__ = "requests"
|
24 |
-
key = Column(String, primary_key=True)
|
25 |
-
response = Column(
|
26 |
-
String
|
27 |
-
) # FIXME: ideally should be an hstore, but I don't want to set it up on GCP
|
28 |
-
|
29 |
-
missing_dependencies = None
|
30 |
-
|
31 |
-
except ImportError as e:
|
32 |
-
missing_dependencies = e
|
33 |
-
|
34 |
-
|
35 |
-
class PostgresCache(Cache):
|
36 |
-
"""A PostgreSQL cache for request/response pairs."""
|
37 |
-
|
38 |
-
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
|
39 |
-
"""
|
40 |
-
Connect to client.
|
41 |
-
|
42 |
-
Args:
|
43 |
-
connection_str: connection string.
|
44 |
-
cache_args: arguments for cache should include the following fields:
|
45 |
-
{
|
46 |
-
"cache_user": "",
|
47 |
-
"cache_password": "",
|
48 |
-
"cache_db": ""
|
49 |
-
}
|
50 |
-
"""
|
51 |
-
if missing_dependencies:
|
52 |
-
raise ValueError(
|
53 |
-
"Missing dependencies for GCP PostgreSQL cache. "
|
54 |
-
"Install with `pip install manifest[gcp]`",
|
55 |
-
missing_dependencies,
|
56 |
-
)
|
57 |
-
|
58 |
-
connector = Connector()
|
59 |
-
|
60 |
-
def getconn() -> Any:
|
61 |
-
conn = connector.connect(
|
62 |
-
connection_str,
|
63 |
-
"pg8000",
|
64 |
-
user=cache_args.pop("cache_user"),
|
65 |
-
password=cache_args.pop("cache_password"),
|
66 |
-
db=cache_args.pop("cache_db"),
|
67 |
-
)
|
68 |
-
return conn
|
69 |
-
|
70 |
-
engine = sqlalchemy.create_engine(
|
71 |
-
"postgresql+pg8000://",
|
72 |
-
creator=getconn,
|
73 |
-
)
|
74 |
-
engine.dialect.description_encoding = None # type: ignore
|
75 |
-
|
76 |
-
db_exists = len(sqlalchemy.inspect(engine).get_table_names()) > 0
|
77 |
-
if not db_exists:
|
78 |
-
logger.info("Creating database...")
|
79 |
-
Base.metadata.create_all(engine)
|
80 |
-
|
81 |
-
self.session = sessionmaker(bind=engine)()
|
82 |
-
|
83 |
-
def close(self) -> None:
|
84 |
-
"""Close the client."""
|
85 |
-
self.session.close()
|
86 |
-
|
87 |
-
@staticmethod
|
88 |
-
def _hash_key(key: str, table: str) -> str:
|
89 |
-
"""Compute MD5 hash of the key."""
|
90 |
-
return hashlib.md5(f"{key}:{table}".encode("utf-8")).hexdigest()
|
91 |
-
|
92 |
-
def get_key(self, key: str, table: str = "default") -> Union[str, None]:
|
93 |
-
"""
|
94 |
-
Get the key for a request.
|
95 |
-
|
96 |
-
With return None if key is not in cache.
|
97 |
-
|
98 |
-
Args:
|
99 |
-
key: key for cache.
|
100 |
-
table: table to get key in.
|
101 |
-
"""
|
102 |
-
request = (
|
103 |
-
self.session.query(Request) # type: ignore
|
104 |
-
.filter_by(key=self._hash_key(key, table))
|
105 |
-
.first()
|
106 |
-
)
|
107 |
-
out = request.response if request else None
|
108 |
-
return out # type: ignore
|
109 |
-
|
110 |
-
def set_key(self, key: str, value: str, table: str = "default") -> None:
|
111 |
-
"""
|
112 |
-
Set the value for the key.
|
113 |
-
|
114 |
-
Will override old value.
|
115 |
-
|
116 |
-
Args:
|
117 |
-
key: key for cache.
|
118 |
-
value: new value for key.
|
119 |
-
table: table to set key in.
|
120 |
-
"""
|
121 |
-
key = self._hash_key(key, table)
|
122 |
-
request = self.session.query(Request).filter_by(key=key).first() # type: ignore
|
123 |
-
if request:
|
124 |
-
request.response = value # type: ignore
|
125 |
-
else:
|
126 |
-
self.session.add(Request(key=key, response=value))
|
127 |
-
self.commit()
|
128 |
-
|
129 |
-
def commit(self) -> None:
|
130 |
-
"""Commit any results."""
|
131 |
-
self.session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/caches/redis.py
DELETED
@@ -1,64 +0,0 @@
|
|
1 |
-
"""Redis cache."""
|
2 |
-
from typing import Any, Dict, Union
|
3 |
-
|
4 |
-
import redis
|
5 |
-
|
6 |
-
from manifest.caches.cache import Cache
|
7 |
-
|
8 |
-
|
9 |
-
class RedisCache(Cache):
|
10 |
-
"""A Redis cache for request/response pairs."""
|
11 |
-
|
12 |
-
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
|
13 |
-
"""
|
14 |
-
Connect to client.
|
15 |
-
|
16 |
-
Args:
|
17 |
-
connection_str: connection string.
|
18 |
-
cache_args: arguments for cache.
|
19 |
-
"""
|
20 |
-
host, port = connection_str.split(":")
|
21 |
-
self.redis = redis.Redis(host=host, port=int(port), db=0)
|
22 |
-
return
|
23 |
-
|
24 |
-
def close(self) -> None:
|
25 |
-
"""Close the client."""
|
26 |
-
self.redis.close()
|
27 |
-
|
28 |
-
def _normalize_table_key(self, key: str, table: str) -> str:
|
29 |
-
"""Cast key for prompt key."""
|
30 |
-
return f"{table}:{key}"
|
31 |
-
|
32 |
-
def get_key(self, key: str, table: str = "default") -> Union[str, None]:
|
33 |
-
"""
|
34 |
-
Get the key for a request.
|
35 |
-
|
36 |
-
With return None if key is not in cache.
|
37 |
-
|
38 |
-
Args:
|
39 |
-
key: key for cache.
|
40 |
-
table: table to get key in.
|
41 |
-
"""
|
42 |
-
norm_key = self._normalize_table_key(key, table)
|
43 |
-
if self.redis.exists(norm_key):
|
44 |
-
return self.redis.get(norm_key).decode("utf-8")
|
45 |
-
else:
|
46 |
-
return None
|
47 |
-
|
48 |
-
def set_key(self, key: str, value: str, table: str = "default") -> None:
|
49 |
-
"""
|
50 |
-
Set the value for the key.
|
51 |
-
|
52 |
-
Will override old value.
|
53 |
-
|
54 |
-
Args:
|
55 |
-
key: key for cache.
|
56 |
-
value: new value for key.
|
57 |
-
table: table to set key in.
|
58 |
-
"""
|
59 |
-
self.redis.set(self._normalize_table_key(key, table), value)
|
60 |
-
self.commit()
|
61 |
-
|
62 |
-
def commit(self) -> None:
|
63 |
-
"""Commit any results."""
|
64 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/caches/serializers.py
DELETED
@@ -1,204 +0,0 @@
|
|
1 |
-
"""Serializer."""
|
2 |
-
|
3 |
-
import io
|
4 |
-
import json
|
5 |
-
import os
|
6 |
-
from pathlib import Path
|
7 |
-
from typing import Dict
|
8 |
-
|
9 |
-
import numpy as np
|
10 |
-
import xxhash
|
11 |
-
|
12 |
-
from manifest.caches.array_cache import ArrayCache
|
13 |
-
|
14 |
-
|
15 |
-
class Serializer:
|
16 |
-
"""Serializer."""
|
17 |
-
|
18 |
-
def request_to_key(self, request: Dict) -> str:
|
19 |
-
"""
|
20 |
-
Normalize a request into a key.
|
21 |
-
|
22 |
-
Args:
|
23 |
-
request: request to normalize.
|
24 |
-
|
25 |
-
Returns:
|
26 |
-
normalized key.
|
27 |
-
"""
|
28 |
-
return json.dumps(request, sort_keys=True)
|
29 |
-
|
30 |
-
def key_to_request(self, key: str) -> Dict:
|
31 |
-
"""
|
32 |
-
Convert the normalized version to the request.
|
33 |
-
|
34 |
-
Args:
|
35 |
-
key: normalized key to convert.
|
36 |
-
|
37 |
-
Returns:
|
38 |
-
unnormalized request dict.
|
39 |
-
"""
|
40 |
-
return json.loads(key)
|
41 |
-
|
42 |
-
def response_to_key(self, response: Dict) -> str:
|
43 |
-
"""
|
44 |
-
Normalize a response into a key.
|
45 |
-
|
46 |
-
Args:
|
47 |
-
response: response to normalize.
|
48 |
-
|
49 |
-
Returns:
|
50 |
-
normalized key.
|
51 |
-
"""
|
52 |
-
return json.dumps(response, sort_keys=True)
|
53 |
-
|
54 |
-
def key_to_response(self, key: str) -> Dict:
|
55 |
-
"""
|
56 |
-
Convert the normalized version to the response.
|
57 |
-
|
58 |
-
Args:
|
59 |
-
key: normalized key to convert.
|
60 |
-
|
61 |
-
Returns:
|
62 |
-
unnormalized response dict.
|
63 |
-
"""
|
64 |
-
return json.loads(key)
|
65 |
-
|
66 |
-
|
67 |
-
class NumpyByteSerializer(Serializer):
|
68 |
-
"""Serializer by casting array to byte string."""
|
69 |
-
|
70 |
-
def response_to_key(self, response: Dict) -> str:
|
71 |
-
"""
|
72 |
-
Normalize a response into a key.
|
73 |
-
|
74 |
-
Args:
|
75 |
-
response: response to normalize.
|
76 |
-
|
77 |
-
Returns:
|
78 |
-
normalized key.
|
79 |
-
"""
|
80 |
-
sub_response = response["response"]
|
81 |
-
# Assume response is a dict with keys "choices" -> List dicts
|
82 |
-
# with keys "array".
|
83 |
-
choices = sub_response["choices"]
|
84 |
-
# We don't want to modify the response in place
|
85 |
-
# but we want to avoid calling deepcopy on an array
|
86 |
-
del sub_response["choices"]
|
87 |
-
response_copy = sub_response.copy()
|
88 |
-
sub_response["choices"] = choices
|
89 |
-
response_copy["choices"] = []
|
90 |
-
for choice in choices:
|
91 |
-
if "array" not in choice:
|
92 |
-
raise ValueError(
|
93 |
-
f"Choice with keys {choice.keys()} does not have array key."
|
94 |
-
)
|
95 |
-
arr = choice["array"]
|
96 |
-
# Avoid copying an array
|
97 |
-
del choice["array"]
|
98 |
-
new_choice = choice.copy()
|
99 |
-
choice["array"] = arr
|
100 |
-
with io.BytesIO() as f:
|
101 |
-
np.savez_compressed(f, data=arr)
|
102 |
-
hash_str = f.getvalue().hex()
|
103 |
-
new_choice["array"] = hash_str
|
104 |
-
response_copy["choices"].append(new_choice)
|
105 |
-
response["response"] = response_copy
|
106 |
-
return json.dumps(response, sort_keys=True)
|
107 |
-
|
108 |
-
def key_to_response(self, key: str) -> Dict:
|
109 |
-
"""
|
110 |
-
Convert the normalized version to the response.
|
111 |
-
|
112 |
-
Args:
|
113 |
-
key: normalized key to convert.
|
114 |
-
|
115 |
-
Returns:
|
116 |
-
unnormalized response dict.
|
117 |
-
"""
|
118 |
-
response = json.loads(key)
|
119 |
-
for choice in response["response"]["choices"]:
|
120 |
-
hash_str = choice["array"]
|
121 |
-
byte_str = bytes.fromhex(hash_str)
|
122 |
-
with io.BytesIO(byte_str) as f:
|
123 |
-
choice["array"] = np.load(f)["data"]
|
124 |
-
return response
|
125 |
-
|
126 |
-
|
127 |
-
class ArraySerializer(Serializer):
|
128 |
-
"""Serializer for array."""
|
129 |
-
|
130 |
-
def __init__(self) -> None:
|
131 |
-
"""
|
132 |
-
Initialize array serializer.
|
133 |
-
|
134 |
-
We don't want to cache the array. We hash the value and
|
135 |
-
store the array in a memmap file. Store filename/offsets
|
136 |
-
in sqlitedict to keep track of hash -> array.
|
137 |
-
"""
|
138 |
-
super().__init__()
|
139 |
-
|
140 |
-
self.hash = xxhash.xxh64()
|
141 |
-
manifest_home = Path(os.environ.get("MANIFEST_HOME", Path.home()))
|
142 |
-
cache_folder = manifest_home / ".manifest" / "array_cache"
|
143 |
-
self.writer = ArrayCache(cache_folder)
|
144 |
-
|
145 |
-
def response_to_key(self, response: Dict) -> str:
|
146 |
-
"""
|
147 |
-
Normalize a response into a key.
|
148 |
-
|
149 |
-
Convert arrays to hash string for cache key.
|
150 |
-
|
151 |
-
Args:
|
152 |
-
response: response to normalize.
|
153 |
-
|
154 |
-
Returns:
|
155 |
-
normalized key.
|
156 |
-
"""
|
157 |
-
sub_response = response["response"]
|
158 |
-
# Assume response is a dict with keys "choices" -> List dicts
|
159 |
-
# with keys "array".
|
160 |
-
choices = sub_response["choices"]
|
161 |
-
# We don't want to modify the response in place
|
162 |
-
# but we want to avoid calling deepcopy on an array
|
163 |
-
del sub_response["choices"]
|
164 |
-
response_copy = sub_response.copy()
|
165 |
-
sub_response["choices"] = choices
|
166 |
-
response_copy["choices"] = []
|
167 |
-
for choice in choices:
|
168 |
-
if "array" not in choice:
|
169 |
-
raise ValueError(
|
170 |
-
f"Choice with keys {choice.keys()} does not have array key."
|
171 |
-
)
|
172 |
-
arr = choice["array"]
|
173 |
-
# Avoid copying an array
|
174 |
-
del choice["array"]
|
175 |
-
new_choice = choice.copy()
|
176 |
-
choice["array"] = arr
|
177 |
-
|
178 |
-
self.hash.update(arr)
|
179 |
-
hash_str = self.hash.hexdigest()
|
180 |
-
self.hash.reset()
|
181 |
-
new_choice["array"] = hash_str
|
182 |
-
response_copy["choices"].append(new_choice)
|
183 |
-
if not self.writer.contains_key(hash_str):
|
184 |
-
self.writer.put(hash_str, arr)
|
185 |
-
response["response"] = response_copy
|
186 |
-
return json.dumps(response, sort_keys=True)
|
187 |
-
|
188 |
-
def key_to_response(self, key: str) -> Dict:
|
189 |
-
"""
|
190 |
-
Convert the normalized version to the response.
|
191 |
-
|
192 |
-
Convert the hash string keys to the arrays.
|
193 |
-
|
194 |
-
Args:
|
195 |
-
key: normalized key to convert.
|
196 |
-
|
197 |
-
Returns:
|
198 |
-
unnormalized response dict.
|
199 |
-
"""
|
200 |
-
response = json.loads(key)
|
201 |
-
for choice in response["response"]["choices"]:
|
202 |
-
hash_str = choice["array"]
|
203 |
-
choice["array"] = self.writer.get(hash_str)
|
204 |
-
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/caches/sqlite.py
DELETED
@@ -1,65 +0,0 @@
|
|
1 |
-
"""SQLite cache."""
|
2 |
-
import logging
|
3 |
-
from typing import Any, Dict, Union
|
4 |
-
|
5 |
-
from sqlitedict import SqliteDict
|
6 |
-
|
7 |
-
from manifest.caches.cache import Cache
|
8 |
-
|
9 |
-
logging.getLogger("sqlitedict").setLevel(logging.WARNING)
|
10 |
-
|
11 |
-
|
12 |
-
class SQLiteCache(Cache):
|
13 |
-
"""A SQLite cache for request/response pairs."""
|
14 |
-
|
15 |
-
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
|
16 |
-
"""
|
17 |
-
Connect to client.
|
18 |
-
|
19 |
-
Args:
|
20 |
-
connection_str: connection string.
|
21 |
-
cache_args: arguments for cache.
|
22 |
-
"""
|
23 |
-
self.cache_file = connection_str
|
24 |
-
if not self.cache_file:
|
25 |
-
self.cache_file = ".sqlite.cache"
|
26 |
-
self.cache = SqliteDict(self.cache_file, autocommit=True)
|
27 |
-
return
|
28 |
-
|
29 |
-
def close(self) -> None:
|
30 |
-
"""Close the client."""
|
31 |
-
self.cache.close()
|
32 |
-
|
33 |
-
def _normalize_table_key(self, key: str, table: str) -> str:
|
34 |
-
"""Cast key for prompt key."""
|
35 |
-
return f"{table}:{key}"
|
36 |
-
|
37 |
-
def get_key(self, key: str, table: str = "default") -> Union[str, None]:
|
38 |
-
"""
|
39 |
-
Get the key for a request.
|
40 |
-
|
41 |
-
With return None if key is not in cache.
|
42 |
-
|
43 |
-
Args:
|
44 |
-
key: key for cache.
|
45 |
-
table: table to get key in.
|
46 |
-
"""
|
47 |
-
return self.cache.get(self._normalize_table_key(key, table))
|
48 |
-
|
49 |
-
def set_key(self, key: str, value: str, table: str = "default") -> None:
|
50 |
-
"""
|
51 |
-
Set the value for the key.
|
52 |
-
|
53 |
-
Will override old value.
|
54 |
-
|
55 |
-
Args:
|
56 |
-
key: key for cache.
|
57 |
-
value: new value for key.
|
58 |
-
table: table to set key in.
|
59 |
-
"""
|
60 |
-
self.cache[self._normalize_table_key(key, table)] = value
|
61 |
-
self.commit()
|
62 |
-
|
63 |
-
def commit(self) -> None:
|
64 |
-
"""Commit any results."""
|
65 |
-
self.cache.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
"""Client init."""
|
|
|
|
duckdb-nsql/manifest/manifest/clients/ai21.py
DELETED
@@ -1,125 +0,0 @@
|
|
1 |
-
"""AI21 client."""
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
from typing import Any, Dict, Optional
|
5 |
-
|
6 |
-
from manifest.clients.client import Client
|
7 |
-
from manifest.request import LMRequest
|
8 |
-
|
9 |
-
logger = logging.getLogger(__name__)
|
10 |
-
|
11 |
-
AI21_ENGINES = {
|
12 |
-
"j2-ultra",
|
13 |
-
"j2-mid",
|
14 |
-
"j2-light",
|
15 |
-
}
|
16 |
-
|
17 |
-
|
18 |
-
class AI21Client(Client):
|
19 |
-
"""AI21Client client."""
|
20 |
-
|
21 |
-
# User param -> (client param, default value)
|
22 |
-
PARAMS = {
|
23 |
-
"engine": ("engine", "j2-ultra"),
|
24 |
-
"temperature": ("temperature", 0.7),
|
25 |
-
"max_tokens": ("maxTokens", 40),
|
26 |
-
"top_k": ("topKReturn", 0),
|
27 |
-
"n": ("numResults", 1),
|
28 |
-
"top_p": ("topP", 1.0),
|
29 |
-
"stop_sequences": ("stopSequences", []),
|
30 |
-
}
|
31 |
-
REQUEST_CLS = LMRequest
|
32 |
-
NAME = "ai21"
|
33 |
-
|
34 |
-
def connect(
|
35 |
-
self,
|
36 |
-
connection_str: Optional[str] = None,
|
37 |
-
client_args: Dict[str, Any] = {},
|
38 |
-
) -> None:
|
39 |
-
"""
|
40 |
-
Connect to the AI21 server.
|
41 |
-
|
42 |
-
connection_str is passed as default AI21_API_KEY if variable not set.
|
43 |
-
|
44 |
-
Args:
|
45 |
-
connection_str: connection string.
|
46 |
-
client_args: client arguments.
|
47 |
-
"""
|
48 |
-
# Taken from https://docs.ai21.com/
|
49 |
-
self.host = "https://api.ai21.com/studio/v1"
|
50 |
-
self.api_key = connection_str or os.environ.get("AI21_API_KEY")
|
51 |
-
if self.api_key is None:
|
52 |
-
raise ValueError(
|
53 |
-
"AI21 API key not set. Set AI21_API_KEY environment "
|
54 |
-
"variable or pass through `client_connection`."
|
55 |
-
)
|
56 |
-
|
57 |
-
for key in self.PARAMS:
|
58 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
59 |
-
if getattr(self, "engine") not in AI21_ENGINES:
|
60 |
-
raise ValueError(
|
61 |
-
f"Invalid engine {getattr(self, 'engine')}. Must be {AI21_ENGINES}."
|
62 |
-
)
|
63 |
-
|
64 |
-
def close(self) -> None:
|
65 |
-
"""Close the client."""
|
66 |
-
pass
|
67 |
-
|
68 |
-
def get_generation_url(self) -> str:
|
69 |
-
"""Get generation URL."""
|
70 |
-
return self.host + "/" + getattr(self, "engine") + "/complete"
|
71 |
-
|
72 |
-
def get_generation_header(self) -> Dict[str, str]:
|
73 |
-
"""
|
74 |
-
Get generation header.
|
75 |
-
|
76 |
-
Returns:
|
77 |
-
header.
|
78 |
-
"""
|
79 |
-
return {"Authorization": f"Bearer {self.api_key}"}
|
80 |
-
|
81 |
-
def supports_batch_inference(self) -> bool:
|
82 |
-
"""Return whether the client supports batch inference."""
|
83 |
-
return False
|
84 |
-
|
85 |
-
def supports_streaming_inference(self) -> bool:
|
86 |
-
"""Return whether the client supports streaming inference.
|
87 |
-
|
88 |
-
Override in child client class.
|
89 |
-
"""
|
90 |
-
return False
|
91 |
-
|
92 |
-
def get_model_params(self) -> Dict:
|
93 |
-
"""
|
94 |
-
Get model params.
|
95 |
-
|
96 |
-
By getting model params from the server, we can add to request
|
97 |
-
and make sure cache keys are unique to model.
|
98 |
-
|
99 |
-
Returns:
|
100 |
-
model params.
|
101 |
-
"""
|
102 |
-
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
|
103 |
-
|
104 |
-
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
|
105 |
-
"""
|
106 |
-
Format response to dict.
|
107 |
-
|
108 |
-
Args:
|
109 |
-
response: response
|
110 |
-
request: request
|
111 |
-
|
112 |
-
Return:
|
113 |
-
response as dict
|
114 |
-
"""
|
115 |
-
return {
|
116 |
-
"object": "text_completion",
|
117 |
-
"model": getattr(self, "engine"),
|
118 |
-
"choices": [
|
119 |
-
{
|
120 |
-
"text": item["data"]["text"],
|
121 |
-
"token_logprobs": item["data"]["tokens"],
|
122 |
-
}
|
123 |
-
for item in response["completions"]
|
124 |
-
],
|
125 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/azureendpoint.py
DELETED
@@ -1,139 +0,0 @@
|
|
1 |
-
"""OpenRouter client."""
|
2 |
-
|
3 |
-
import copy
|
4 |
-
import logging
|
5 |
-
import os
|
6 |
-
from typing import Any, Dict, Optional
|
7 |
-
import time
|
8 |
-
from manifest.clients.client import Client
|
9 |
-
from manifest.request import LMRequest
|
10 |
-
import urllib.request
|
11 |
-
import json
|
12 |
-
import os
|
13 |
-
import ssl
|
14 |
-
|
15 |
-
logger = logging.getLogger(__name__)
|
16 |
-
def allowSelfSignedHttps(allowed):
|
17 |
-
# bypass the server certificate verification on client side
|
18 |
-
if allowed and not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(ssl, '_create_unverified_context', None):
|
19 |
-
ssl._create_default_https_context = ssl._create_unverified_context
|
20 |
-
|
21 |
-
allowSelfSignedHttps(True) # this line is needed if you use self-signed certificate in your scoring service.
|
22 |
-
|
23 |
-
|
24 |
-
class AzureEndpointClient(Client):
|
25 |
-
"""OpenRouter client."""
|
26 |
-
|
27 |
-
# Params are defined in https://openrouter.ai/docs/parameters
|
28 |
-
PARAMS = {
|
29 |
-
"engine": ("model", "meta-llama/codellama-70b-instruct"),
|
30 |
-
"max_tokens": ("max_tokens", 1000),
|
31 |
-
"temperature": ("temperature", 0.1),
|
32 |
-
"top_k": ("k", 0),
|
33 |
-
"frequency_penalty": ("frequency_penalty", 0.0),
|
34 |
-
"presence_penalty": ("presence_penalty", 0.0),
|
35 |
-
"stop_sequences": ("stop", None),
|
36 |
-
}
|
37 |
-
REQUEST_CLS = LMRequest
|
38 |
-
NAME = "azureendpoint"
|
39 |
-
IS_CHAT = True
|
40 |
-
|
41 |
-
def connect(
|
42 |
-
self,
|
43 |
-
connection_str: Optional[str] = None,
|
44 |
-
client_args: Dict[str, Any] = {},
|
45 |
-
) -> None:
|
46 |
-
"""
|
47 |
-
Connect to the OpenRouter server.
|
48 |
-
|
49 |
-
connection_str is passed as default OPENROUTER_API_KEY if variable not set.
|
50 |
-
|
51 |
-
Args:
|
52 |
-
connection_str: connection string.
|
53 |
-
client_args: client arguments.
|
54 |
-
"""
|
55 |
-
|
56 |
-
self.host = os.environ.get("AZURE_HOST")
|
57 |
-
# Replace this with the primary/secondary key, AMLToken, or Microsoft Entra ID token for the endpoint
|
58 |
-
self.api_key = os.environ.get("AZURE_API_KEY")
|
59 |
-
if not self.api_key:
|
60 |
-
raise Exception("A key should be provided to invoke the endpoint")
|
61 |
-
for key in self.PARAMS:
|
62 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
63 |
-
|
64 |
-
def close(self) -> None:
|
65 |
-
"""Close the client."""
|
66 |
-
|
67 |
-
def get_generation_header(self) -> Dict[str, str]:
|
68 |
-
"""
|
69 |
-
Get generation header.
|
70 |
-
|
71 |
-
Returns:
|
72 |
-
header.
|
73 |
-
"""
|
74 |
-
return {'Content-Type':'application/json', 'Authorization':('Bearer '+ self.api_key), 'azureml-model-deployment': 'duckdb-nsql-v2-phi-medium-1' }
|
75 |
-
|
76 |
-
def get_generation_url(self) -> str:
|
77 |
-
"""Get generation URL."""
|
78 |
-
return self.host + "/score"
|
79 |
-
|
80 |
-
def supports_batch_inference(self) -> bool:
|
81 |
-
"""Return whether the client supports batch inference."""
|
82 |
-
return False
|
83 |
-
|
84 |
-
def supports_streaming_inference(self) -> bool:
|
85 |
-
"""Return whether the client supports streaming inference.
|
86 |
-
|
87 |
-
Override in child client class.
|
88 |
-
"""
|
89 |
-
return True
|
90 |
-
|
91 |
-
def get_model_params(self) -> Dict:
|
92 |
-
"""
|
93 |
-
Get model params.
|
94 |
-
|
95 |
-
By getting model params from the server, we can add to request
|
96 |
-
and make sure cache keys are unique to model.
|
97 |
-
|
98 |
-
Returns:
|
99 |
-
model params.
|
100 |
-
"""
|
101 |
-
return {"model_name": AzureEndpointClient.NAME, "engine": getattr(self, 'engine')}
|
102 |
-
|
103 |
-
def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
104 |
-
"""
|
105 |
-
Preprocess request params.
|
106 |
-
|
107 |
-
Args:
|
108 |
-
request: request params.
|
109 |
-
|
110 |
-
Returns:
|
111 |
-
request params.
|
112 |
-
"""
|
113 |
-
# Format for chat model
|
114 |
-
request = copy.deepcopy(request)
|
115 |
-
prompt = request.pop("prompt")
|
116 |
-
data = {"input_data": {"input_string": [{"role": "user", "content": prompt}], "parameters": {"stop":"\n```", "max_tokens": 500}}}
|
117 |
-
|
118 |
-
#body = str(str.encode(json.dumps(data)))
|
119 |
-
return super().preprocess_request_params(data)
|
120 |
-
|
121 |
-
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
|
122 |
-
"""
|
123 |
-
Format response to dict.
|
124 |
-
|
125 |
-
Args:
|
126 |
-
response: response
|
127 |
-
request: request
|
128 |
-
|
129 |
-
Return:
|
130 |
-
response as dict
|
131 |
-
"""
|
132 |
-
new_choices = []
|
133 |
-
response = copy.deepcopy(response)
|
134 |
-
if "output" in response:
|
135 |
-
new_choices.append({"text": response["output"]})
|
136 |
-
else:
|
137 |
-
new_choices.append({"text": ""})
|
138 |
-
response["choices"] = new_choices
|
139 |
-
return super().postprocess_response(response, request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/azureopenai.py
DELETED
@@ -1,113 +0,0 @@
|
|
1 |
-
"""Azure client."""
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
from typing import Any, Dict, Optional, Type
|
5 |
-
|
6 |
-
from manifest.clients.openai import OPENAI_ENGINES, OpenAIClient
|
7 |
-
from manifest.request import LMRequest, Request
|
8 |
-
|
9 |
-
logger = logging.getLogger(__name__)
|
10 |
-
|
11 |
-
# Azure deployment name can only use letters and numbers, no spaces. Hyphens ("-") and
|
12 |
-
# underscores ("_") may be used, except as ending characters. We create this mapping to
|
13 |
-
# handle difference between Azure and OpenAI
|
14 |
-
AZURE_DEPLOYMENT_NAME_MAPPING = {
|
15 |
-
"gpt-3.5-turbo": "gpt-35-turbo",
|
16 |
-
"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
|
17 |
-
}
|
18 |
-
OPENAI_DEPLOYMENT_NAME_MAPPING = {
|
19 |
-
"gpt-35-turbo": "gpt-3.5-turbo",
|
20 |
-
"gpt-35-turbo-0301": "gpt-3.5-turbo-0301",
|
21 |
-
}
|
22 |
-
|
23 |
-
|
24 |
-
class AzureClient(OpenAIClient):
|
25 |
-
"""Azure client."""
|
26 |
-
|
27 |
-
PARAMS = OpenAIClient.PARAMS
|
28 |
-
REQUEST_CLS: Type[Request] = LMRequest
|
29 |
-
NAME = "azureopenai"
|
30 |
-
|
31 |
-
def connect(
|
32 |
-
self,
|
33 |
-
connection_str: Optional[str] = None,
|
34 |
-
client_args: Dict[str, Any] = {},
|
35 |
-
) -> None:
|
36 |
-
"""
|
37 |
-
Connect to the AzureOpenAI server.
|
38 |
-
|
39 |
-
connection_str is passed as default AZURE_OPENAI_KEY if variable not set.
|
40 |
-
|
41 |
-
Args:
|
42 |
-
connection_str: connection string.
|
43 |
-
client_args: client arguments.
|
44 |
-
"""
|
45 |
-
self.api_key, self.host = None, None
|
46 |
-
if connection_str:
|
47 |
-
connection_parts = connection_str.split("::")
|
48 |
-
if len(connection_parts) == 1:
|
49 |
-
self.api_key = connection_parts[0]
|
50 |
-
elif len(connection_parts) == 2:
|
51 |
-
self.api_key, self.host = connection_parts
|
52 |
-
else:
|
53 |
-
raise ValueError(
|
54 |
-
"Invalid connection string. "
|
55 |
-
"Must be either AZURE_OPENAI_KEY or "
|
56 |
-
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
|
57 |
-
)
|
58 |
-
self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
|
59 |
-
if self.api_key is None:
|
60 |
-
raise ValueError(
|
61 |
-
"AzureOpenAI API key not set. Set AZURE_OPENAI_KEY environment "
|
62 |
-
"variable or pass through `client_connection`."
|
63 |
-
)
|
64 |
-
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
|
65 |
-
if self.host is None:
|
66 |
-
raise ValueError(
|
67 |
-
"Azure Service URL not set "
|
68 |
-
"(e.g. https://openai-azure-service.openai.azure.com/)."
|
69 |
-
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
|
70 |
-
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
|
71 |
-
)
|
72 |
-
self.host = self.host.rstrip("/")
|
73 |
-
for key in self.PARAMS:
|
74 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
75 |
-
if getattr(self, "engine") not in OPENAI_ENGINES:
|
76 |
-
raise ValueError(
|
77 |
-
f"Invalid engine {getattr(self, 'engine')}. Must be {OPENAI_ENGINES}."
|
78 |
-
)
|
79 |
-
|
80 |
-
def get_generation_url(self) -> str:
|
81 |
-
"""Get generation URL."""
|
82 |
-
engine = getattr(self, "engine")
|
83 |
-
deployment_name = AZURE_DEPLOYMENT_NAME_MAPPING.get(engine, engine)
|
84 |
-
return (
|
85 |
-
self.host
|
86 |
-
+ "/openai/deployments/"
|
87 |
-
+ deployment_name
|
88 |
-
+ "/completions?api-version=2023-05-15"
|
89 |
-
)
|
90 |
-
|
91 |
-
def get_generation_header(self) -> Dict[str, str]:
|
92 |
-
"""
|
93 |
-
Get generation header.
|
94 |
-
|
95 |
-
Returns:
|
96 |
-
header.
|
97 |
-
"""
|
98 |
-
return {"api-key": f"{self.api_key}"}
|
99 |
-
|
100 |
-
def get_model_params(self) -> Dict:
|
101 |
-
"""
|
102 |
-
Get model params.
|
103 |
-
|
104 |
-
By getting model params from the server, we can add to request
|
105 |
-
and make sure cache keys are unique to model.
|
106 |
-
|
107 |
-
Returns:
|
108 |
-
model params.
|
109 |
-
"""
|
110 |
-
# IMPORTANT!!!
|
111 |
-
# Azure models are the same as openai models. So we want to unify their
|
112 |
-
# cached. Make sure we retrun the OpenAI name here.
|
113 |
-
return {"model_name": OpenAIClient.NAME, "engine": getattr(self, "engine")}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/azureopenai_chat.py
DELETED
@@ -1,116 +0,0 @@
|
|
1 |
-
"""Azure client."""
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
from typing import Any, Dict, Optional
|
5 |
-
|
6 |
-
from manifest.clients.openai_chat import OPENAICHAT_ENGINES, OpenAIChatClient
|
7 |
-
from manifest.request import LMRequest
|
8 |
-
|
9 |
-
logger = logging.getLogger(__name__)
|
10 |
-
|
11 |
-
# Azure deployment name can only use letters and numbers, no spaces. Hyphens ("-") and
|
12 |
-
# underscores ("_") may be used, except as ending characters. We create this mapping to
|
13 |
-
# handle difference between Azure and OpenAI
|
14 |
-
AZURE_DEPLOYMENT_NAME_MAPPING = {
|
15 |
-
"gpt-3.5-turbo": "gpt-35-turbo",
|
16 |
-
"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
|
17 |
-
}
|
18 |
-
OPENAI_DEPLOYMENT_NAME_MAPPING = {
|
19 |
-
"gpt-35-turbo": "gpt-3.5-turbo",
|
20 |
-
"gpt-35-turbo-0301": "gpt-3.5-turbo-0301",
|
21 |
-
}
|
22 |
-
|
23 |
-
|
24 |
-
class AzureChatClient(OpenAIChatClient):
|
25 |
-
"""Azure chat client."""
|
26 |
-
|
27 |
-
# User param -> (client param, default value)
|
28 |
-
PARAMS = OpenAIChatClient.PARAMS
|
29 |
-
REQUEST_CLS = LMRequest
|
30 |
-
NAME = "azureopenaichat"
|
31 |
-
IS_CHAT = True
|
32 |
-
|
33 |
-
def connect(
|
34 |
-
self,
|
35 |
-
connection_str: Optional[str] = None,
|
36 |
-
client_args: Dict[str, Any] = {},
|
37 |
-
) -> None:
|
38 |
-
"""
|
39 |
-
Connect to the AzureOpenAI server.
|
40 |
-
|
41 |
-
connection_str is passed as default AZURE_OPENAI_KEY if variable not set.
|
42 |
-
|
43 |
-
Args:
|
44 |
-
connection_str: connection string.
|
45 |
-
client_args: client arguments.
|
46 |
-
"""
|
47 |
-
self.api_key, self.host = None, None
|
48 |
-
if connection_str:
|
49 |
-
connection_parts = connection_str.split("::")
|
50 |
-
if len(connection_parts) == 1:
|
51 |
-
self.api_key = connection_parts[0]
|
52 |
-
elif len(connection_parts) == 2:
|
53 |
-
self.api_key, self.host = connection_parts
|
54 |
-
else:
|
55 |
-
raise ValueError(
|
56 |
-
"Invalid connection string. "
|
57 |
-
"Must be either AZURE_OPENAI_KEY or "
|
58 |
-
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
|
59 |
-
)
|
60 |
-
self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
|
61 |
-
if self.api_key is None:
|
62 |
-
raise ValueError(
|
63 |
-
"AzureOpenAI API key not set. Set AZURE_OPENAI_KEY environment "
|
64 |
-
"variable or pass through `client_connection`."
|
65 |
-
)
|
66 |
-
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
|
67 |
-
if self.host is None:
|
68 |
-
raise ValueError(
|
69 |
-
"Azure Service URL not set "
|
70 |
-
"(e.g. https://openai-azure-service.openai.azure.com/)."
|
71 |
-
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
|
72 |
-
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
|
73 |
-
)
|
74 |
-
self.host = self.host.rstrip("/")
|
75 |
-
for key in self.PARAMS:
|
76 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
77 |
-
if getattr(self, "engine") not in OPENAICHAT_ENGINES:
|
78 |
-
raise ValueError(
|
79 |
-
f"Invalid engine {getattr(self, 'engine')}. "
|
80 |
-
f"Must be {OPENAICHAT_ENGINES}."
|
81 |
-
)
|
82 |
-
|
83 |
-
def get_generation_url(self) -> str:
|
84 |
-
"""Get generation URL."""
|
85 |
-
engine = getattr(self, "engine")
|
86 |
-
deployment_name = AZURE_DEPLOYMENT_NAME_MAPPING.get(engine, engine)
|
87 |
-
return (
|
88 |
-
self.host
|
89 |
-
+ "/openai/deployments/"
|
90 |
-
+ deployment_name
|
91 |
-
+ "/chat/completions?api-version=2023-05-15"
|
92 |
-
)
|
93 |
-
|
94 |
-
def get_generation_header(self) -> Dict[str, str]:
|
95 |
-
"""
|
96 |
-
Get generation header.
|
97 |
-
|
98 |
-
Returns:
|
99 |
-
header.
|
100 |
-
"""
|
101 |
-
return {"api-key": f"{self.api_key}"}
|
102 |
-
|
103 |
-
def get_model_params(self) -> Dict:
|
104 |
-
"""
|
105 |
-
Get model params.
|
106 |
-
|
107 |
-
By getting model params from the server, we can add to request
|
108 |
-
and make sure cache keys are unique to model.
|
109 |
-
|
110 |
-
Returns:
|
111 |
-
model params.
|
112 |
-
"""
|
113 |
-
# IMPORTANT!!!
|
114 |
-
# Azure models are the same as openai models. So we want to unify their
|
115 |
-
# cached. Make sure we retrun the OpenAI name here.
|
116 |
-
return {"model_name": OpenAIChatClient.NAME, "engine": getattr(self, "engine")}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/client.py
DELETED
@@ -1,699 +0,0 @@
|
|
1 |
-
"""Client class."""
|
2 |
-
import asyncio
|
3 |
-
import copy
|
4 |
-
import json
|
5 |
-
import logging
|
6 |
-
import math
|
7 |
-
from abc import ABC, abstractmethod
|
8 |
-
from typing import Any, Dict, Generator, List, Optional, Tuple, Union, cast
|
9 |
-
|
10 |
-
import aiohttp
|
11 |
-
import requests
|
12 |
-
import tqdm.asyncio
|
13 |
-
from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
|
14 |
-
|
15 |
-
from manifest.request import (
|
16 |
-
DEFAULT_REQUEST_KEYS,
|
17 |
-
NOT_CACHE_KEYS,
|
18 |
-
LMChatRequest,
|
19 |
-
LMRequest,
|
20 |
-
LMScoreRequest,
|
21 |
-
Request,
|
22 |
-
)
|
23 |
-
from manifest.response import (
|
24 |
-
RESPONSE_CONSTRUCTORS,
|
25 |
-
ArrayModelChoice,
|
26 |
-
LMModelChoice,
|
27 |
-
ModelChoices,
|
28 |
-
Response,
|
29 |
-
Usage,
|
30 |
-
Usages,
|
31 |
-
)
|
32 |
-
|
33 |
-
logger = logging.getLogger(__name__)
|
34 |
-
|
35 |
-
ATTEMPTS_BEFORE_STOP = 4
|
36 |
-
ATTEMPTS_TIMEOUT = 30
|
37 |
-
# http_status mainly for azure and e.code mainly for openai usage
|
38 |
-
# e.http_status == 408 occurs when Azure times out
|
39 |
-
# e.code == 429 rate lime
|
40 |
-
# e.code == 500 or 502 occurs when server error
|
41 |
-
API_ERROR_CODE = {408, 429, 500, 502, 520, 524}
|
42 |
-
|
43 |
-
|
44 |
-
def retry_if_ratelimit(retry_base: RetryCallState) -> bool:
|
45 |
-
"""Return whether to retry if ratelimited."""
|
46 |
-
try:
|
47 |
-
if isinstance(retry_base.outcome.exception(), requests.exceptions.HTTPError):
|
48 |
-
exception = cast(
|
49 |
-
requests.exceptions.HTTPError, retry_base.outcome.exception()
|
50 |
-
)
|
51 |
-
# 500 is a server error, 429 is a rate limit error
|
52 |
-
if exception.response.status_code in API_ERROR_CODE: # type: ignore
|
53 |
-
return True
|
54 |
-
except Exception:
|
55 |
-
pass
|
56 |
-
return True
|
57 |
-
|
58 |
-
|
59 |
-
def return_error_response(retry_state: RetryCallState) -> dict:
|
60 |
-
"""Return error response if all retries failed."""
|
61 |
-
request_params = retry_state.args[1]
|
62 |
-
number_of_prompts = (
|
63 |
-
len(request_params["prompt"])
|
64 |
-
if "prompt" in request_params
|
65 |
-
else len(request_params["messages"])
|
66 |
-
)
|
67 |
-
return {
|
68 |
-
"choices": [],
|
69 |
-
"usage": {
|
70 |
-
"total_tokens": 0,
|
71 |
-
"prompt_tokens": 0,
|
72 |
-
"completion_tokens": 0,
|
73 |
-
},
|
74 |
-
"errors": [str(retry_state.outcome.exception())] * number_of_prompts,
|
75 |
-
}
|
76 |
-
|
77 |
-
|
78 |
-
class Client(ABC):
|
79 |
-
"""Client class."""
|
80 |
-
|
81 |
-
# Must be overridden by child class
|
82 |
-
PARAMS: Dict[str, Tuple[str, Any]] = {}
|
83 |
-
REQUEST_CLS = Request
|
84 |
-
NAME: str = None
|
85 |
-
IS_CHAT: bool = False
|
86 |
-
|
87 |
-
def __init__(
|
88 |
-
self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}
|
89 |
-
):
|
90 |
-
"""
|
91 |
-
Initialize client.
|
92 |
-
|
93 |
-
kwargs are passed to client as default parameters.
|
94 |
-
|
95 |
-
For clients like OpenAI that do not require a connection,
|
96 |
-
the connection_str can be None.
|
97 |
-
|
98 |
-
Args:
|
99 |
-
connection_str: connection string for client.
|
100 |
-
client_args: client arguments.
|
101 |
-
"""
|
102 |
-
self.connect(connection_str, client_args)
|
103 |
-
|
104 |
-
@abstractmethod
|
105 |
-
def connect(
|
106 |
-
self, connection_str: Optional[str], client_args: Dict[str, Any]
|
107 |
-
) -> None:
|
108 |
-
"""
|
109 |
-
Connect to client.
|
110 |
-
|
111 |
-
Override in child client class.
|
112 |
-
Args:
|
113 |
-
connection_str: connection string.
|
114 |
-
"""
|
115 |
-
raise NotImplementedError()
|
116 |
-
|
117 |
-
@abstractmethod
|
118 |
-
def close(self) -> None:
|
119 |
-
"""Close the client.
|
120 |
-
|
121 |
-
Override in child client class.
|
122 |
-
"""
|
123 |
-
raise NotImplementedError()
|
124 |
-
|
125 |
-
@abstractmethod
|
126 |
-
def get_generation_url(self) -> str:
|
127 |
-
"""Get generation URL.
|
128 |
-
|
129 |
-
Override in child client class.
|
130 |
-
"""
|
131 |
-
raise NotImplementedError()
|
132 |
-
|
133 |
-
@abstractmethod
|
134 |
-
def get_generation_header(self) -> Dict[str, str]:
|
135 |
-
"""
|
136 |
-
Get generation header.
|
137 |
-
|
138 |
-
Override in child client class.
|
139 |
-
Returns:
|
140 |
-
header.
|
141 |
-
"""
|
142 |
-
raise NotImplementedError()
|
143 |
-
|
144 |
-
@abstractmethod
|
145 |
-
def supports_batch_inference(self) -> bool:
|
146 |
-
"""Return whether the client supports batch inference.
|
147 |
-
|
148 |
-
Override in child client class.
|
149 |
-
"""
|
150 |
-
raise NotImplementedError()
|
151 |
-
|
152 |
-
@abstractmethod
|
153 |
-
def supports_streaming_inference(self) -> bool:
|
154 |
-
"""Return whether the client supports streaming inference.
|
155 |
-
|
156 |
-
Override in child client class.
|
157 |
-
"""
|
158 |
-
raise NotImplementedError()
|
159 |
-
|
160 |
-
@abstractmethod
|
161 |
-
def get_model_params(self) -> Dict:
|
162 |
-
"""
|
163 |
-
Get model params.
|
164 |
-
|
165 |
-
By getting model params from the server, we can add to request
|
166 |
-
and make sure cache keys are unique to model.
|
167 |
-
|
168 |
-
Override in child client class.
|
169 |
-
Returns:
|
170 |
-
model params.
|
171 |
-
"""
|
172 |
-
raise NotImplementedError()
|
173 |
-
|
174 |
-
def get_tokenizer(self, model: str) -> Tuple[Any, int]:
|
175 |
-
"""Get tokenizer for model.
|
176 |
-
|
177 |
-
Override in child client class. Return None, -1 if not supported
|
178 |
-
or no prompt truncation required.
|
179 |
-
Returns:
|
180 |
-
tokenizer: tokenizer with encoder and decode
|
181 |
-
max_length: max length of model
|
182 |
-
"""
|
183 |
-
return None, -1
|
184 |
-
|
185 |
-
def get_model_inputs(self) -> List:
|
186 |
-
"""
|
187 |
-
Get allowable model inputs.
|
188 |
-
|
189 |
-
Returns:
|
190 |
-
model inputs.
|
191 |
-
"""
|
192 |
-
return list(self.PARAMS.keys())
|
193 |
-
|
194 |
-
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
|
195 |
-
"""Split usage into list of usages for each prompt."""
|
196 |
-
# TODO: add this in using default tokenizer
|
197 |
-
return []
|
198 |
-
|
199 |
-
def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
200 |
-
"""
|
201 |
-
Preprocess request params.
|
202 |
-
|
203 |
-
Override in child client class to reformat requests to model.
|
204 |
-
|
205 |
-
Args:
|
206 |
-
request: request params.
|
207 |
-
|
208 |
-
Returns:
|
209 |
-
request params.
|
210 |
-
"""
|
211 |
-
return request
|
212 |
-
|
213 |
-
def postprocess_response(
|
214 |
-
self, response: Dict[str, Any], request: Dict[str, Any]
|
215 |
-
) -> Dict[str, Any]:
|
216 |
-
"""
|
217 |
-
Postprocess and validate response as dict.
|
218 |
-
|
219 |
-
Override in child client class to reform model responses.
|
220 |
-
|
221 |
-
Args:
|
222 |
-
response: response
|
223 |
-
request: request
|
224 |
-
|
225 |
-
Return:
|
226 |
-
response as dict
|
227 |
-
"""
|
228 |
-
if "choices" not in response:
|
229 |
-
raise ValueError(f"Invalid response: {response}")
|
230 |
-
if "usage" in response:
|
231 |
-
# Handle splitting the usages for batch requests
|
232 |
-
if len(response["choices"]) == 1:
|
233 |
-
if isinstance(response["usage"], list):
|
234 |
-
response["usage"] = response["usage"][0]
|
235 |
-
response["usage"] = [response["usage"]]
|
236 |
-
else:
|
237 |
-
# Try to split usage
|
238 |
-
split_usage = self.split_usage(request, response["choices"])
|
239 |
-
if split_usage:
|
240 |
-
response["usage"] = split_usage
|
241 |
-
return response
|
242 |
-
|
243 |
-
def get_request(
|
244 |
-
self, prompt: Union[str, List[str]], request_args: Dict[str, Any]
|
245 |
-
) -> Request:
|
246 |
-
"""
|
247 |
-
Parse model kwargs to request.
|
248 |
-
|
249 |
-
Args:
|
250 |
-
prompt: prompt.
|
251 |
-
request_args: request arguments.
|
252 |
-
|
253 |
-
Returns:
|
254 |
-
request.
|
255 |
-
"""
|
256 |
-
params = {"prompt": prompt}
|
257 |
-
# Adds default values from self.PARAMS if not in request_args
|
258 |
-
for key in self.PARAMS:
|
259 |
-
params[key] = request_args.pop(key, getattr(self, key))
|
260 |
-
# Allows for overriding DEFAULT_REQUEST_KEYS even if they are not
|
261 |
-
# in self.PARAMS. Note that DEFAULT_REQUEST_KEYS match the default
|
262 |
-
# values in Request.
|
263 |
-
for key in DEFAULT_REQUEST_KEYS:
|
264 |
-
if key not in params and key in request_args:
|
265 |
-
params[key] = request_args.pop(key)
|
266 |
-
return self.REQUEST_CLS(**params) # type: ignore
|
267 |
-
|
268 |
-
def _get_request_params(self, request: Request) -> Dict[str, Any]:
|
269 |
-
"""Get request params.
|
270 |
-
|
271 |
-
Add default keys that we need for requests such as batch_size.
|
272 |
-
We drop these before sending to the model.
|
273 |
-
"""
|
274 |
-
params_to_add = DEFAULT_REQUEST_KEYS.copy()
|
275 |
-
# This will override DEFAULT_REQUEST_KEYS with those in PARAMS
|
276 |
-
params_to_add.update(self.PARAMS)
|
277 |
-
# to_dict will handle parameter renaming but not any
|
278 |
-
# default value handling - that is done in get_request()
|
279 |
-
request_params = request.to_dict(params_to_add)
|
280 |
-
return request_params
|
281 |
-
|
282 |
-
def get_cache_key(self, request: Request) -> Dict[str, Any]:
|
283 |
-
"""Get cache key for request.
|
284 |
-
|
285 |
-
Skip keys that are not cache keys such as batch_size.
|
286 |
-
"""
|
287 |
-
request_params = self._get_request_params(request)
|
288 |
-
for key in NOT_CACHE_KEYS:
|
289 |
-
request_params.pop(key, None)
|
290 |
-
# Make sure to add model params and request class
|
291 |
-
request_params.update(self.get_model_params())
|
292 |
-
request_params["request_cls"] = request.__class__.__name__
|
293 |
-
return request_params
|
294 |
-
|
295 |
-
def _split_requests(
|
296 |
-
self, request_params: Dict[str, Any], batch_size: int, key: str = "prompt"
|
297 |
-
) -> List[Dict[str, Any]]:
|
298 |
-
"""Split request into batch_sized request.
|
299 |
-
|
300 |
-
Args:
|
301 |
-
request_params: request params.
|
302 |
-
batch_size: batch size for requests.
|
303 |
-
key: key to batch over
|
304 |
-
|
305 |
-
Returns:
|
306 |
-
list of request params.
|
307 |
-
"""
|
308 |
-
data = copy.deepcopy(request_params[key])
|
309 |
-
data_size = len(request_params[key])
|
310 |
-
request_params_list = []
|
311 |
-
for i in range(0, data_size, batch_size):
|
312 |
-
params = copy.deepcopy(request_params)
|
313 |
-
params[key] = data[i] if batch_size == 1 else data[i : i + batch_size]
|
314 |
-
request_params_list.append(params)
|
315 |
-
return request_params_list
|
316 |
-
|
317 |
-
def _get_model_choices(self, response: Dict) -> ModelChoices:
|
318 |
-
"""Format response to ModelChoices."""
|
319 |
-
# Array or text response
|
320 |
-
response_type = RESPONSE_CONSTRUCTORS[self.REQUEST_CLS]["response_type"]
|
321 |
-
if response_type == "array":
|
322 |
-
choices: List[Union[LMModelChoice, ArrayModelChoice]] = [
|
323 |
-
ArrayModelChoice(**choice) for choice in response["choices"]
|
324 |
-
]
|
325 |
-
else:
|
326 |
-
choices = [LMModelChoice(**choice) for choice in response["choices"]]
|
327 |
-
return ModelChoices(choices=choices)
|
328 |
-
|
329 |
-
def _stitch_responses(self, request: Request, responses: List[Dict]) -> Response:
|
330 |
-
"""Stitch responses together.
|
331 |
-
|
332 |
-
Useful for batch requests.
|
333 |
-
"""
|
334 |
-
choices = []
|
335 |
-
usages = []
|
336 |
-
for res_dict in responses:
|
337 |
-
choices.extend(res_dict["choices"])
|
338 |
-
if "usage" in res_dict:
|
339 |
-
usages.extend(res_dict["usage"])
|
340 |
-
final_response_dict = {"choices": choices}
|
341 |
-
final_usages = None
|
342 |
-
if usages:
|
343 |
-
final_usages = Usages(usages=[Usage(**usage) for usage in usages])
|
344 |
-
# TODO: Add usage based on tokenizer
|
345 |
-
return Response(
|
346 |
-
self._get_model_choices(final_response_dict),
|
347 |
-
cached=False,
|
348 |
-
request=request,
|
349 |
-
usages=final_usages,
|
350 |
-
**RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
|
351 |
-
)
|
352 |
-
|
353 |
-
def _verify_request_lengths(
|
354 |
-
self, request: Dict[str, Any], model: str, max_tokens: int
|
355 |
-
) -> None:
|
356 |
-
"""Verify that the request length is not too long."""
|
357 |
-
encoder, max_length = self.get_tokenizer(model)
|
358 |
-
if not encoder or max_length < 0:
|
359 |
-
return
|
360 |
-
if isinstance(request["prompt"], str):
|
361 |
-
prompts = [request["prompt"]]
|
362 |
-
else:
|
363 |
-
prompts = request["prompt"]
|
364 |
-
for i in range(len(prompts)):
|
365 |
-
prompt = prompts[i]
|
366 |
-
encoded_prompt = encoder.encode(prompt)
|
367 |
-
if len(encoded_prompt) + max_tokens > max_length:
|
368 |
-
logger.warning(
|
369 |
-
f"Prompt {prompt} is too long for model {model}. "
|
370 |
-
"Truncating prompt from left."
|
371 |
-
)
|
372 |
-
# -20 to be safe
|
373 |
-
prompt = encoder.decode(
|
374 |
-
encoded_prompt[-int(max_length - max_tokens - 20) :]
|
375 |
-
)
|
376 |
-
prompts[i] = prompt
|
377 |
-
if isinstance(request["prompt"], str):
|
378 |
-
request["prompt"] = prompts[0]
|
379 |
-
else:
|
380 |
-
request["prompt"] = prompts
|
381 |
-
|
382 |
-
@retry(
|
383 |
-
reraise=True,
|
384 |
-
wait=wait_random_exponential(min=1, max=ATTEMPTS_TIMEOUT),
|
385 |
-
stop=stop_after_attempt(ATTEMPTS_BEFORE_STOP),
|
386 |
-
)
|
387 |
-
def _run_completion(
|
388 |
-
self, request_params: Dict[str, Any], retry_timeout: int
|
389 |
-
) -> Dict:
|
390 |
-
"""Execute completion request.
|
391 |
-
|
392 |
-
Args:
|
393 |
-
request_params: request params.
|
394 |
-
retry_timeout: retry timeout.
|
395 |
-
|
396 |
-
Returns:
|
397 |
-
response as dict.
|
398 |
-
"""
|
399 |
-
request_params = self.preprocess_request_params(request_params)
|
400 |
-
print(request_params)
|
401 |
-
post_str = self.get_generation_url()
|
402 |
-
res = requests.post(
|
403 |
-
post_str,
|
404 |
-
headers=self.get_generation_header(),
|
405 |
-
json=request_params,
|
406 |
-
timeout=retry_timeout,
|
407 |
-
)
|
408 |
-
try:
|
409 |
-
res.raise_for_status()
|
410 |
-
except requests.exceptions.HTTPError as e:
|
411 |
-
logger.warning(
|
412 |
-
str(e)
|
413 |
-
)
|
414 |
-
raise Exception()
|
415 |
-
return self.postprocess_response(res.json(), request_params)
|
416 |
-
|
417 |
-
@retry(
|
418 |
-
reraise=True,
|
419 |
-
retry=retry_if_ratelimit,
|
420 |
-
wait=wait_random_exponential(min=1, max=ATTEMPTS_TIMEOUT),
|
421 |
-
stop=stop_after_attempt(ATTEMPTS_BEFORE_STOP),
|
422 |
-
)
|
423 |
-
async def _arun_completion(
|
424 |
-
self, request_params: Dict[str, Any], retry_timeout: int
|
425 |
-
) -> Dict:
|
426 |
-
"""Async execute completion request.
|
427 |
-
|
428 |
-
Args:
|
429 |
-
request_params: request params.
|
430 |
-
retry_timeout: retry timeout.
|
431 |
-
|
432 |
-
Returns:
|
433 |
-
response as dict.
|
434 |
-
"""
|
435 |
-
request_params = self.preprocess_request_params(request_params)
|
436 |
-
post_str = self.get_generation_url()
|
437 |
-
async with aiohttp.ClientSession(timeout=retry_timeout) as session:
|
438 |
-
async with session.post(
|
439 |
-
post_str,
|
440 |
-
headers=self.get_generation_header(),
|
441 |
-
json=request_params,
|
442 |
-
timeout=retry_timeout,
|
443 |
-
) as res:
|
444 |
-
res.raise_for_status()
|
445 |
-
res_json = await res.json(content_type=None)
|
446 |
-
return self.postprocess_response(res_json, request_params)
|
447 |
-
|
448 |
-
@retry(
|
449 |
-
reraise=True,
|
450 |
-
retry=retry_if_ratelimit,
|
451 |
-
wait=wait_random_exponential(min=1, max=ATTEMPTS_TIMEOUT),
|
452 |
-
stop=stop_after_attempt(ATTEMPTS_BEFORE_STOP),
|
453 |
-
)
|
454 |
-
def _run_streaming_completion(
|
455 |
-
self, request_params: Dict[str, Any], retry_timeout: int
|
456 |
-
) -> Generator[Dict, None, None]:
|
457 |
-
"""Execute completion request streaming.
|
458 |
-
|
459 |
-
Args:
|
460 |
-
request_params: request params.
|
461 |
-
retry_timeout: retry timeout.
|
462 |
-
|
463 |
-
Returns:
|
464 |
-
response as dict.
|
465 |
-
"""
|
466 |
-
request_params = self.preprocess_request_params(request_params)
|
467 |
-
request_params["stream"] = True
|
468 |
-
post_str = self.get_generation_url()
|
469 |
-
res_iter = requests.post(
|
470 |
-
post_str,
|
471 |
-
headers=self.get_generation_header(),
|
472 |
-
json=request_params,
|
473 |
-
timeout=retry_timeout,
|
474 |
-
stream=True,
|
475 |
-
)
|
476 |
-
for res_token in res_iter.iter_lines():
|
477 |
-
if res_token:
|
478 |
-
decoded_res_token = res_token.decode("utf-8")
|
479 |
-
decoded_res_token = decoded_res_token.replace("data: ", "")
|
480 |
-
if decoded_res_token == "[DONE]":
|
481 |
-
break
|
482 |
-
try:
|
483 |
-
decoded_res_token_dct = json.loads(decoded_res_token)
|
484 |
-
postprocess_res_token_dct = self.postprocess_response(
|
485 |
-
decoded_res_token_dct, request_params
|
486 |
-
)
|
487 |
-
# If nothing is returned, skip
|
488 |
-
if (
|
489 |
-
not postprocess_res_token_dct
|
490 |
-
or not postprocess_res_token_dct["choices"]
|
491 |
-
):
|
492 |
-
continue
|
493 |
-
yield postprocess_res_token_dct
|
494 |
-
except Exception as e:
|
495 |
-
raise e
|
496 |
-
|
497 |
-
def run_request(self, request: Request) -> Response:
|
498 |
-
"""
|
499 |
-
Run request.
|
500 |
-
|
501 |
-
Args:
|
502 |
-
request: request.
|
503 |
-
|
504 |
-
Returns:
|
505 |
-
response.
|
506 |
-
"""
|
507 |
-
# Make everything list for consistency
|
508 |
-
if isinstance(request.prompt, list):
|
509 |
-
prompt_list = request.prompt
|
510 |
-
else:
|
511 |
-
prompt_list = [request.prompt]
|
512 |
-
|
513 |
-
request_params = self._get_request_params(request)
|
514 |
-
# Set the params as a list. Do not set the request
|
515 |
-
# object itself as the cache will then store it as a
|
516 |
-
# list which is inconsistent with the request input.
|
517 |
-
request_params["prompt"] = prompt_list
|
518 |
-
|
519 |
-
# If batch_size is not set, set it to 1
|
520 |
-
batch_size = request_params.pop("batch_size") or 1
|
521 |
-
if not self.supports_batch_inference() and batch_size != 1:
|
522 |
-
logger.warning(
|
523 |
-
f"{self.__class__.__name__} does not support batch inference."
|
524 |
-
" Setting batch size to 1"
|
525 |
-
)
|
526 |
-
batch_size = 1
|
527 |
-
|
528 |
-
# Take the default keys we need and drop the rest as they
|
529 |
-
# are not part of the model request.
|
530 |
-
retry_timeout = request_params.pop("client_timeout")
|
531 |
-
for key in DEFAULT_REQUEST_KEYS:
|
532 |
-
request_params.pop(key, None)
|
533 |
-
|
534 |
-
# Make sure requests are in the request length
|
535 |
-
# If no tokenizer is set or not LM request, this
|
536 |
-
# will do nothing
|
537 |
-
if isinstance(request, LMRequest):
|
538 |
-
self._verify_request_lengths(
|
539 |
-
request_params, model=request.engine, max_tokens=request.max_tokens
|
540 |
-
)
|
541 |
-
|
542 |
-
# Batch requests
|
543 |
-
num_batches = len(prompt_list) // batch_size
|
544 |
-
if len(prompt_list) % batch_size != 0:
|
545 |
-
batch_size = int(math.ceil(len(prompt_list) / (num_batches + 1)))
|
546 |
-
request_batches = self._split_requests(request_params, batch_size)
|
547 |
-
|
548 |
-
response_dicts = [
|
549 |
-
self._run_completion(batch, retry_timeout) for batch in request_batches
|
550 |
-
]
|
551 |
-
# Flatten responses
|
552 |
-
return self._stitch_responses(request, response_dicts)
|
553 |
-
|
554 |
-
async def arun_batch_request(
|
555 |
-
self, request: Request, verbose: bool = False
|
556 |
-
) -> Response:
|
557 |
-
"""
|
558 |
-
Run async request.
|
559 |
-
|
560 |
-
Args:
|
561 |
-
request: request.s
|
562 |
-
|
563 |
-
Returns:
|
564 |
-
response.
|
565 |
-
"""
|
566 |
-
required_batch_size = None
|
567 |
-
if not self.supports_batch_inference():
|
568 |
-
required_batch_size = 1
|
569 |
-
if not isinstance(request.prompt, list):
|
570 |
-
raise AssertionError(
|
571 |
-
"request.prompt must be a list for async batch inference."
|
572 |
-
)
|
573 |
-
|
574 |
-
request_params = self._get_request_params(request)
|
575 |
-
# Take the default keys we need and drop the rest as they
|
576 |
-
# are not part of the model request.
|
577 |
-
retry_timeout = request_params.pop("client_timeout")
|
578 |
-
batch_size = request_params.pop("batch_size")
|
579 |
-
batch_size = required_batch_size or batch_size
|
580 |
-
for key in DEFAULT_REQUEST_KEYS:
|
581 |
-
request_params.pop(key, None)
|
582 |
-
|
583 |
-
# Make sure requests are in the request length
|
584 |
-
# If no tokenizer is set or not LM request, this
|
585 |
-
# will do nothing
|
586 |
-
if isinstance(request, LMRequest):
|
587 |
-
self._verify_request_lengths(
|
588 |
-
request_params, model=request.engine, max_tokens=request.max_tokens
|
589 |
-
)
|
590 |
-
|
591 |
-
# Batch requests
|
592 |
-
num_batches = len(request.prompt) // batch_size
|
593 |
-
if len(request.prompt) % batch_size != 0:
|
594 |
-
batch_size = int(math.ceil(len(request.prompt) / (num_batches + 1)))
|
595 |
-
|
596 |
-
request_batches = self._split_requests(request_params, batch_size)
|
597 |
-
all_tasks = [
|
598 |
-
asyncio.create_task(self._arun_completion(batch, retry_timeout))
|
599 |
-
for batch in request_batches
|
600 |
-
]
|
601 |
-
responses = await tqdm.asyncio.tqdm.gather(*all_tasks, disable=not verbose)
|
602 |
-
# Flatten responses
|
603 |
-
return self._stitch_responses(request, responses)
|
604 |
-
|
605 |
-
def run_chat_request(
|
606 |
-
self,
|
607 |
-
request: LMChatRequest,
|
608 |
-
) -> Response:
|
609 |
-
"""
|
610 |
-
Get the response from chat model.
|
611 |
-
|
612 |
-
Args:
|
613 |
-
request: request.
|
614 |
-
|
615 |
-
Returns:
|
616 |
-
response.
|
617 |
-
"""
|
618 |
-
request_params = self._get_request_params(request)
|
619 |
-
# Take the default keys we need and drop the rest as they
|
620 |
-
# are not part of the model request.
|
621 |
-
retry_timeout = request_params.pop("client_timeout")
|
622 |
-
for key in DEFAULT_REQUEST_KEYS:
|
623 |
-
request_params.pop(key, None)
|
624 |
-
|
625 |
-
# Make sure requests are in the request length
|
626 |
-
# If no tokenizer is set or not LM request, this
|
627 |
-
# will do nothing
|
628 |
-
self._verify_request_lengths(
|
629 |
-
request_params, model=request.engine, max_tokens=request.max_tokens
|
630 |
-
)
|
631 |
-
|
632 |
-
response_dict = self._run_completion(request_params, retry_timeout)
|
633 |
-
usages = None
|
634 |
-
if "usage" in response_dict:
|
635 |
-
usages = [Usage(**usage) for usage in response_dict["usage"]]
|
636 |
-
|
637 |
-
return Response(
|
638 |
-
response=self._get_model_choices(response_dict),
|
639 |
-
cached=False,
|
640 |
-
request=request,
|
641 |
-
usages=Usages(usages=usages) if usages else None,
|
642 |
-
**RESPONSE_CONSTRUCTORS[LMChatRequest], # type: ignore
|
643 |
-
)
|
644 |
-
|
645 |
-
def run_streaming_request(
|
646 |
-
self, request: Request
|
647 |
-
) -> Generator[Response, None, None]:
|
648 |
-
"""
|
649 |
-
Run streaming request.
|
650 |
-
|
651 |
-
Args:
|
652 |
-
request: request.
|
653 |
-
|
654 |
-
Returns:
|
655 |
-
response.
|
656 |
-
"""
|
657 |
-
if not isinstance(request.prompt, str):
|
658 |
-
raise ValueError("Streaming requests must have a single prompt.")
|
659 |
-
if not self.supports_streaming_inference():
|
660 |
-
raise ValueError(
|
661 |
-
f"{self.__class__.__name__} does not support streaming inference."
|
662 |
-
)
|
663 |
-
request_params = self._get_request_params(request)
|
664 |
-
|
665 |
-
# Take the default keys we need and drop the rest as they
|
666 |
-
# are not part of the model request.
|
667 |
-
retry_timeout = request_params.pop("client_timeout")
|
668 |
-
for key in DEFAULT_REQUEST_KEYS:
|
669 |
-
request_params.pop(key, None)
|
670 |
-
|
671 |
-
# Make sure requests are in the request length
|
672 |
-
# If no tokenizer is set or not LM request, this
|
673 |
-
# will do nothing
|
674 |
-
if isinstance(request, LMRequest):
|
675 |
-
self._verify_request_lengths(
|
676 |
-
request_params, model=request.engine, max_tokens=request.max_tokens
|
677 |
-
)
|
678 |
-
|
679 |
-
for token_response in self._run_streaming_completion(
|
680 |
-
request_params, retry_timeout
|
681 |
-
):
|
682 |
-
yield self._stitch_responses(request, [token_response])
|
683 |
-
|
684 |
-
def run_score_prompt_request(
|
685 |
-
self,
|
686 |
-
request: LMScoreRequest,
|
687 |
-
) -> Response:
|
688 |
-
"""
|
689 |
-
Get the logit score of the prompt via a forward pass of the model.
|
690 |
-
|
691 |
-
Args:
|
692 |
-
request: request.
|
693 |
-
|
694 |
-
Returns:
|
695 |
-
response.
|
696 |
-
"""
|
697 |
-
raise NotImplementedError(
|
698 |
-
f"{self.__class__.__name__} does not support prompt scoring request."
|
699 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/cohere.py
DELETED
@@ -1,125 +0,0 @@
|
|
1 |
-
"""Cohere client."""
|
2 |
-
|
3 |
-
import logging
|
4 |
-
import os
|
5 |
-
from typing import Any, Dict, Optional
|
6 |
-
|
7 |
-
from manifest.clients.client import Client
|
8 |
-
from manifest.request import LMRequest
|
9 |
-
|
10 |
-
logger = logging.getLogger(__name__)
|
11 |
-
|
12 |
-
COHERE_MODELS = {"small", "medium", "large", "xlarge"}
|
13 |
-
|
14 |
-
|
15 |
-
class CohereClient(Client):
|
16 |
-
"""Cohere client."""
|
17 |
-
|
18 |
-
# Params are defined in https://docs.cohere.ai/generate-reference
|
19 |
-
PARAMS = {
|
20 |
-
"engine": ("model", "xlarge"),
|
21 |
-
"max_tokens": ("max_tokens", 20),
|
22 |
-
"temperature": ("temperature", 0.75),
|
23 |
-
"n": ("num_generations", 1),
|
24 |
-
"top_k": ("k", 0),
|
25 |
-
"top_p": ("p", 0.75),
|
26 |
-
"frequency_penalty": ("frequency_penalty", 0.0),
|
27 |
-
"presence_penalty": ("presence_penalty", 0.0),
|
28 |
-
"stop_sequences": ("stop_sequences", None),
|
29 |
-
}
|
30 |
-
REQUEST_CLS = LMRequest
|
31 |
-
NAME = "cohere"
|
32 |
-
|
33 |
-
def connect(
|
34 |
-
self,
|
35 |
-
connection_str: Optional[str] = None,
|
36 |
-
client_args: Dict[str, Any] = {},
|
37 |
-
) -> None:
|
38 |
-
"""
|
39 |
-
Connect to the Cohere server.
|
40 |
-
|
41 |
-
connection_str is passed as default COHERE_API_KEY if variable not set.
|
42 |
-
|
43 |
-
Args:
|
44 |
-
connection_str: connection string.
|
45 |
-
client_args: client arguments.
|
46 |
-
"""
|
47 |
-
self.api_key = connection_str or os.environ.get("COHERE_API_KEY")
|
48 |
-
if self.api_key is None:
|
49 |
-
raise ValueError(
|
50 |
-
"Cohere API key not set. Set COHERE_API_KEY environment "
|
51 |
-
"variable or pass through `client_connection`."
|
52 |
-
)
|
53 |
-
self.host = "https://api.cohere.ai"
|
54 |
-
for key in self.PARAMS:
|
55 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
56 |
-
if getattr(self, "engine") not in COHERE_MODELS:
|
57 |
-
raise ValueError(
|
58 |
-
f"Invalid engine {getattr(self, 'engine')}. Must be {COHERE_MODELS}."
|
59 |
-
)
|
60 |
-
|
61 |
-
def close(self) -> None:
|
62 |
-
"""Close the client."""
|
63 |
-
|
64 |
-
def get_generation_url(self) -> str:
|
65 |
-
"""Get generation URL."""
|
66 |
-
return self.host + "/generate"
|
67 |
-
|
68 |
-
def get_generation_header(self) -> Dict[str, str]:
|
69 |
-
"""
|
70 |
-
Get generation header.
|
71 |
-
|
72 |
-
Returns:
|
73 |
-
header.
|
74 |
-
"""
|
75 |
-
return {
|
76 |
-
"Cohere-Version": "2021-11-08",
|
77 |
-
"Authorization": f"Bearer {self.api_key}",
|
78 |
-
}
|
79 |
-
|
80 |
-
def supports_batch_inference(self) -> bool:
|
81 |
-
"""Return whether the client supports batch inference."""
|
82 |
-
return False
|
83 |
-
|
84 |
-
def supports_streaming_inference(self) -> bool:
|
85 |
-
"""Return whether the client supports streaming inference.
|
86 |
-
|
87 |
-
Override in child client class.
|
88 |
-
"""
|
89 |
-
return False
|
90 |
-
|
91 |
-
def get_model_params(self) -> Dict:
|
92 |
-
"""
|
93 |
-
Get model params.
|
94 |
-
|
95 |
-
By getting model params from the server, we can add to request
|
96 |
-
and make sure cache keys are unique to model.
|
97 |
-
|
98 |
-
Returns:
|
99 |
-
model params.
|
100 |
-
"""
|
101 |
-
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
|
102 |
-
|
103 |
-
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
|
104 |
-
"""
|
105 |
-
Format response to dict.
|
106 |
-
|
107 |
-
Args:
|
108 |
-
response: response
|
109 |
-
request: request
|
110 |
-
|
111 |
-
Return:
|
112 |
-
response as dict
|
113 |
-
"""
|
114 |
-
return {
|
115 |
-
"object": "text_completion",
|
116 |
-
"model": getattr(self, "engine"),
|
117 |
-
"choices": [
|
118 |
-
{
|
119 |
-
"text": item["text"],
|
120 |
-
"text_logprob": item.get("likelihood", None),
|
121 |
-
"token_logprobs": item.get("token_likelihoods", None),
|
122 |
-
}
|
123 |
-
for item in response["generations"]
|
124 |
-
],
|
125 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/diffuser.py
DELETED
@@ -1,112 +0,0 @@
|
|
1 |
-
"""Diffuser client."""
|
2 |
-
import logging
|
3 |
-
from functools import lru_cache
|
4 |
-
from typing import Any, Dict, Optional
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import requests
|
8 |
-
|
9 |
-
from manifest.clients.client import Client
|
10 |
-
from manifest.request import DiffusionRequest
|
11 |
-
|
12 |
-
logger = logging.getLogger(__name__)
|
13 |
-
|
14 |
-
|
15 |
-
class DiffuserClient(Client):
|
16 |
-
"""Diffuser client."""
|
17 |
-
|
18 |
-
# User param -> (client param, default value)
|
19 |
-
PARAMS = {
|
20 |
-
"num_inference_steps": ("num_inference_steps", 50),
|
21 |
-
"height": ("height", 512),
|
22 |
-
"width": ("width", 512),
|
23 |
-
"n": ("num_images_per_prompt", 1),
|
24 |
-
"guidance_scale": ("guidance_scale", 7.5),
|
25 |
-
"eta": ("eta", 0.0),
|
26 |
-
}
|
27 |
-
REQUEST_CLS = DiffusionRequest
|
28 |
-
NAME = "diffuser"
|
29 |
-
|
30 |
-
def connect(
|
31 |
-
self,
|
32 |
-
connection_str: Optional[str] = None,
|
33 |
-
client_args: Dict[str, Any] = {},
|
34 |
-
) -> None:
|
35 |
-
"""
|
36 |
-
Connect to the Diffuser url.
|
37 |
-
|
38 |
-
Arsg:
|
39 |
-
connection_str: connection string.
|
40 |
-
client_args: client arguments.
|
41 |
-
"""
|
42 |
-
self.host = connection_str.rstrip("/")
|
43 |
-
for key in self.PARAMS:
|
44 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
45 |
-
self.model_params = self.get_model_params()
|
46 |
-
|
47 |
-
def to_numpy(self, image: np.ndarray) -> np.ndarray:
|
48 |
-
"""Convert a numpy image to a PIL image.
|
49 |
-
|
50 |
-
Adapted from https://github.com/huggingface/diffusers/blob/src/diffusers/pipelines/pipeline_utils.py#L808 # noqa: E501
|
51 |
-
"""
|
52 |
-
image = (image * 255).round().astype("uint8")
|
53 |
-
return image
|
54 |
-
|
55 |
-
def close(self) -> None:
|
56 |
-
"""Close the client."""
|
57 |
-
pass
|
58 |
-
|
59 |
-
def get_generation_url(self) -> str:
|
60 |
-
"""Get generation URL."""
|
61 |
-
return self.host + "/completions"
|
62 |
-
|
63 |
-
def get_generation_header(self) -> Dict[str, str]:
|
64 |
-
"""
|
65 |
-
Get generation header.
|
66 |
-
|
67 |
-
Returns:
|
68 |
-
header.
|
69 |
-
"""
|
70 |
-
return {}
|
71 |
-
|
72 |
-
def supports_batch_inference(self) -> bool:
|
73 |
-
"""Return whether the client supports batch inference."""
|
74 |
-
return True
|
75 |
-
|
76 |
-
def supports_streaming_inference(self) -> bool:
|
77 |
-
"""Return whether the client supports streaming inference.
|
78 |
-
|
79 |
-
Override in child client class.
|
80 |
-
"""
|
81 |
-
return False
|
82 |
-
|
83 |
-
@lru_cache(maxsize=1)
|
84 |
-
def get_model_params(self) -> Dict:
|
85 |
-
"""
|
86 |
-
Get model params.
|
87 |
-
|
88 |
-
By getting model params from the server, we can add to request
|
89 |
-
and make sure cache keys are unique to model.
|
90 |
-
|
91 |
-
Returns:
|
92 |
-
model params.
|
93 |
-
"""
|
94 |
-
res = requests.post(self.host + "/params").json()
|
95 |
-
res["client_name"] = self.NAME
|
96 |
-
return res
|
97 |
-
|
98 |
-
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
|
99 |
-
"""
|
100 |
-
Format response to dict.
|
101 |
-
|
102 |
-
Args:
|
103 |
-
response: response
|
104 |
-
request: request
|
105 |
-
|
106 |
-
Return:
|
107 |
-
response as dict
|
108 |
-
"""
|
109 |
-
# Convert array to np.array
|
110 |
-
for choice in response["choices"]:
|
111 |
-
choice["array"] = self.to_numpy(np.array(choice["array"]))
|
112 |
-
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/dummy.py
DELETED
@@ -1,251 +0,0 @@
|
|
1 |
-
"""Dummy client."""
|
2 |
-
import hashlib
|
3 |
-
import logging
|
4 |
-
from typing import Any, Dict, List, Optional, Tuple
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import tiktoken
|
8 |
-
|
9 |
-
from manifest.clients.client import Client
|
10 |
-
from manifest.request import LMChatRequest, LMRequest, LMScoreRequest, Request
|
11 |
-
from manifest.response import LMModelChoice, ModelChoices, Response, Usage, Usages
|
12 |
-
|
13 |
-
logger = logging.getLogger(__name__)
|
14 |
-
|
15 |
-
|
16 |
-
class DummyClient(Client):
|
17 |
-
"""Dummy client."""
|
18 |
-
|
19 |
-
# User param -> (client param, default value)
|
20 |
-
PARAMS = {
|
21 |
-
"engine": ("model", "text-davinci-003"),
|
22 |
-
"temperature": ("temperature", 0.0),
|
23 |
-
"max_tokens": ("max_tokens", 10),
|
24 |
-
"n": ("n", 1),
|
25 |
-
"top_p": ("top_p", 1.0),
|
26 |
-
"top_k": ("best_of", 1),
|
27 |
-
"batch_size": ("batch_size", 20),
|
28 |
-
}
|
29 |
-
REQUEST_CLS = LMRequest
|
30 |
-
NAME = "dummy"
|
31 |
-
|
32 |
-
def connect(
|
33 |
-
self,
|
34 |
-
connection_str: Optional[str] = None,
|
35 |
-
client_args: Dict[str, Any] = {},
|
36 |
-
) -> None:
|
37 |
-
"""
|
38 |
-
Connect to dummpy server.
|
39 |
-
|
40 |
-
This is a dummy client that returns identity responses. Used for testing.
|
41 |
-
|
42 |
-
Args:
|
43 |
-
connection_str: connection string.
|
44 |
-
client_args: client arguments.
|
45 |
-
"""
|
46 |
-
# We tiktoken as it is faster than HF for tokenizing
|
47 |
-
# Use any model to create the tokenizer
|
48 |
-
self.encoder = tiktoken.get_encoding("cl100k_base")
|
49 |
-
for key in self.PARAMS:
|
50 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
51 |
-
|
52 |
-
def close(self) -> None:
|
53 |
-
"""Close the client."""
|
54 |
-
pass
|
55 |
-
|
56 |
-
def get_generation_url(self) -> str:
|
57 |
-
"""Get generation URL."""
|
58 |
-
return "dummy"
|
59 |
-
|
60 |
-
def supports_batch_inference(self) -> bool:
|
61 |
-
"""Return whether the client supports batch inference."""
|
62 |
-
return True
|
63 |
-
|
64 |
-
def supports_streaming_inference(self) -> bool:
|
65 |
-
"""Return whether the client supports streaming inference.
|
66 |
-
|
67 |
-
Override in child client class.
|
68 |
-
"""
|
69 |
-
return False
|
70 |
-
|
71 |
-
def get_generation_header(self) -> Dict[str, str]:
|
72 |
-
"""
|
73 |
-
Get generation header.
|
74 |
-
|
75 |
-
Returns:
|
76 |
-
header.
|
77 |
-
"""
|
78 |
-
return {}
|
79 |
-
|
80 |
-
def get_model_params(self) -> Dict:
|
81 |
-
"""
|
82 |
-
Get model params.
|
83 |
-
|
84 |
-
By getting model params from the server, we can add to request
|
85 |
-
and make sure cache keys are unique to model.
|
86 |
-
|
87 |
-
Returns:
|
88 |
-
model params.
|
89 |
-
"""
|
90 |
-
return {"engine": "dummy", "model": getattr(self, "engine")}
|
91 |
-
|
92 |
-
def get_mock_output(
|
93 |
-
self, output_toks: int, is_completion: bool, seed: Optional[int] = None
|
94 |
-
) -> LMModelChoice:
|
95 |
-
"""Return mock model output by generating random tokens."""
|
96 |
-
np.random.seed(seed)
|
97 |
-
random_tokens = np.random.randint(
|
98 |
-
0, self.encoder.max_token_value + 1, output_toks
|
99 |
-
)
|
100 |
-
response = self.encoder.decode(random_tokens) # type: ignore
|
101 |
-
if is_completion:
|
102 |
-
np.random.seed(seed)
|
103 |
-
random_logprobs = np.random.uniform(
|
104 |
-
low=-2, high=-0.00001, size=output_toks
|
105 |
-
).tolist()
|
106 |
-
else:
|
107 |
-
# Return all Nones to mimic chat models
|
108 |
-
# OpenAI chat models do not return logprobs
|
109 |
-
random_logprobs = [None] * output_toks
|
110 |
-
return LMModelChoice(
|
111 |
-
text=response,
|
112 |
-
token_logprobs=random_logprobs,
|
113 |
-
tokens=random_tokens.tolist(),
|
114 |
-
)
|
115 |
-
|
116 |
-
def get_mock_choices(
|
117 |
-
self,
|
118 |
-
prompt_list: List[str],
|
119 |
-
request_params: Dict,
|
120 |
-
is_completion: bool,
|
121 |
-
) -> Tuple[List[LMModelChoice], List[Usage]]:
|
122 |
-
"""Get choices and usages of mock output."""
|
123 |
-
choices = []
|
124 |
-
usages = []
|
125 |
-
for prompt in prompt_list:
|
126 |
-
num_prompt_tokens = len(self.encoder.encode(prompt))
|
127 |
-
if request_params["temperature"] == 0:
|
128 |
-
# Get integer seed from hash of prompt
|
129 |
-
seed = (
|
130 |
-
int(hashlib.sha256(prompt.encode("utf-8")).hexdigest(), 16)
|
131 |
-
% 10**8
|
132 |
-
)
|
133 |
-
else:
|
134 |
-
# Get random seed
|
135 |
-
seed = None
|
136 |
-
for _ in range(int(request_params["n"])):
|
137 |
-
choice = self.get_mock_output(
|
138 |
-
request_params["max_tokens"], is_completion=is_completion, seed=seed
|
139 |
-
)
|
140 |
-
choices.append(choice)
|
141 |
-
usages.append(
|
142 |
-
Usage(
|
143 |
-
prompt_tokens=num_prompt_tokens,
|
144 |
-
completion_tokens=request_params["max_tokens"],
|
145 |
-
total_tokens=num_prompt_tokens + request_params["max_tokens"],
|
146 |
-
)
|
147 |
-
)
|
148 |
-
return choices, usages
|
149 |
-
|
150 |
-
def run_request(self, request: Request) -> Response:
|
151 |
-
"""
|
152 |
-
Get request string function.
|
153 |
-
|
154 |
-
Args:
|
155 |
-
request: request.
|
156 |
-
|
157 |
-
Returns:
|
158 |
-
request function that takes no input.
|
159 |
-
request parameters as dict.
|
160 |
-
"""
|
161 |
-
if isinstance(request.prompt, list):
|
162 |
-
prompt_list = request.prompt
|
163 |
-
else:
|
164 |
-
prompt_list = [request.prompt]
|
165 |
-
request_params = request.to_dict(self.PARAMS)
|
166 |
-
|
167 |
-
choices, usages = self.get_mock_choices(
|
168 |
-
prompt_list, request_params, is_completion=True
|
169 |
-
)
|
170 |
-
return Response(
|
171 |
-
response=ModelChoices(choices=choices), # type: ignore
|
172 |
-
cached=False,
|
173 |
-
request=request,
|
174 |
-
usages=Usages(usages=usages),
|
175 |
-
response_type="text",
|
176 |
-
request_type=self.REQUEST_CLS,
|
177 |
-
)
|
178 |
-
|
179 |
-
async def arun_batch_request(
|
180 |
-
self, request: Request, verbose: bool = False
|
181 |
-
) -> Response:
|
182 |
-
"""
|
183 |
-
Get async request string function.
|
184 |
-
|
185 |
-
Args:
|
186 |
-
request: request.
|
187 |
-
|
188 |
-
Returns:
|
189 |
-
response.
|
190 |
-
"""
|
191 |
-
return self.run_request(request)
|
192 |
-
|
193 |
-
def run_chat_request(
|
194 |
-
self,
|
195 |
-
request: LMChatRequest,
|
196 |
-
) -> Response:
|
197 |
-
"""
|
198 |
-
Get the response from chat model.
|
199 |
-
|
200 |
-
Args:
|
201 |
-
request: request.
|
202 |
-
|
203 |
-
Returns:
|
204 |
-
response.
|
205 |
-
"""
|
206 |
-
prompt_list = ["_".join(pmp["content"] for pmp in request.prompt)]
|
207 |
-
request_params = request.to_dict(self.PARAMS)
|
208 |
-
|
209 |
-
choices, usages = self.get_mock_choices(
|
210 |
-
prompt_list, request_params, is_completion=False
|
211 |
-
)
|
212 |
-
return Response(
|
213 |
-
response=ModelChoices(choices=choices), # type: ignore
|
214 |
-
cached=False,
|
215 |
-
request=request,
|
216 |
-
usages=Usages(usages=usages),
|
217 |
-
response_type="text",
|
218 |
-
request_type=LMChatRequest,
|
219 |
-
)
|
220 |
-
|
221 |
-
def run_score_prompt_request(
|
222 |
-
self,
|
223 |
-
request: LMScoreRequest,
|
224 |
-
) -> Response:
|
225 |
-
"""
|
226 |
-
Get the logit score of the prompt via a forward pass of the model.
|
227 |
-
|
228 |
-
Args:
|
229 |
-
request: request.
|
230 |
-
|
231 |
-
Returns:
|
232 |
-
request function that takes no input.
|
233 |
-
request parameters as dict.
|
234 |
-
"""
|
235 |
-
if isinstance(request.prompt, list):
|
236 |
-
prompt_list = request.prompt
|
237 |
-
else:
|
238 |
-
prompt_list = [request.prompt]
|
239 |
-
request_params = request.to_dict(self.PARAMS)
|
240 |
-
|
241 |
-
choices, usages = self.get_mock_choices(
|
242 |
-
prompt_list, request_params, is_completion=True
|
243 |
-
)
|
244 |
-
return Response(
|
245 |
-
response=ModelChoices(choices=choices), # type: ignore
|
246 |
-
cached=False,
|
247 |
-
request=request,
|
248 |
-
usages=Usages(usages=usages),
|
249 |
-
response_type="text",
|
250 |
-
request_type=LMScoreRequest,
|
251 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/google.py
DELETED
@@ -1,197 +0,0 @@
|
|
1 |
-
"""Google client."""
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
import subprocess
|
5 |
-
from typing import Any, Dict, Optional, Type
|
6 |
-
|
7 |
-
from manifest.clients.client import Client
|
8 |
-
from manifest.request import LMRequest, Request
|
9 |
-
|
10 |
-
logger = logging.getLogger(__name__)
|
11 |
-
|
12 |
-
# https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
|
13 |
-
GOOGLE_ENGINES = {
|
14 |
-
"text-bison",
|
15 |
-
}
|
16 |
-
|
17 |
-
|
18 |
-
def get_project_id() -> Optional[str]:
|
19 |
-
"""Get project ID.
|
20 |
-
|
21 |
-
Run
|
22 |
-
`gcloud config get-value project`
|
23 |
-
"""
|
24 |
-
try:
|
25 |
-
project_id = subprocess.run(
|
26 |
-
["gcloud", "config", "get-value", "project"],
|
27 |
-
stdout=subprocess.PIPE,
|
28 |
-
stderr=subprocess.PIPE,
|
29 |
-
)
|
30 |
-
if project_id.stderr.decode("utf-8").strip():
|
31 |
-
return None
|
32 |
-
return project_id.stdout.decode("utf-8").strip()
|
33 |
-
except Exception:
|
34 |
-
return None
|
35 |
-
|
36 |
-
|
37 |
-
class GoogleClient(Client):
|
38 |
-
"""Google client."""
|
39 |
-
|
40 |
-
# User param -> (client param, default value)
|
41 |
-
PARAMS = {
|
42 |
-
"engine": ("model", "text-bison"),
|
43 |
-
"temperature": ("temperature", 1.0),
|
44 |
-
"max_tokens": ("maxOutputTokens", 10),
|
45 |
-
"top_p": ("topP", 1.0),
|
46 |
-
"top_k": ("topK", 1),
|
47 |
-
"batch_size": ("batch_size", 20),
|
48 |
-
}
|
49 |
-
REQUEST_CLS: Type[Request] = LMRequest
|
50 |
-
NAME = "google"
|
51 |
-
|
52 |
-
def connect(
|
53 |
-
self,
|
54 |
-
connection_str: Optional[str] = None,
|
55 |
-
client_args: Dict[str, Any] = {},
|
56 |
-
) -> None:
|
57 |
-
"""
|
58 |
-
Connect to the GoogleVertex API.
|
59 |
-
|
60 |
-
connection_str is passed as default GOOGLE_API_KEY if variable not set.
|
61 |
-
|
62 |
-
Args:
|
63 |
-
connection_str: connection string.
|
64 |
-
client_args: client arguments.
|
65 |
-
"""
|
66 |
-
connection_parts = connection_str.split("::")
|
67 |
-
if len(connection_parts) == 1:
|
68 |
-
self.api_key = connection_parts[0]
|
69 |
-
self.project_id = None
|
70 |
-
elif len(connection_parts) == 2:
|
71 |
-
self.api_key, self.project_id = connection_parts
|
72 |
-
else:
|
73 |
-
raise ValueError(
|
74 |
-
"Invalid connection string. "
|
75 |
-
"Must be either API_KEY or API_KEY::PROJECT_ID"
|
76 |
-
)
|
77 |
-
self.api_key = self.api_key or os.environ.get("GOOGLE_API_KEY")
|
78 |
-
if self.api_key is None:
|
79 |
-
raise ValueError(
|
80 |
-
"GoogleVertex API key not set. Set GOOGLE_API_KEY environment "
|
81 |
-
"variable or pass through `client_connection`. This can be "
|
82 |
-
"found by running `gcloud auth print-access-token`"
|
83 |
-
)
|
84 |
-
self.project_id = (
|
85 |
-
self.project_id or os.environ.get("GOOGLE_PROJECT_ID") or get_project_id()
|
86 |
-
)
|
87 |
-
if self.project_id is None:
|
88 |
-
raise ValueError("GoogleVertex project ID not set. Set GOOGLE_PROJECT_ID")
|
89 |
-
self.host = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/us-central1/publishers/google/models" # noqa: E501
|
90 |
-
|
91 |
-
for key in self.PARAMS:
|
92 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
93 |
-
if getattr(self, "engine") not in GOOGLE_ENGINES:
|
94 |
-
raise ValueError(
|
95 |
-
f"Invalid engine {getattr(self, 'engine')}. Must be {GOOGLE_ENGINES}."
|
96 |
-
)
|
97 |
-
|
98 |
-
def close(self) -> None:
|
99 |
-
"""Close the client."""
|
100 |
-
pass
|
101 |
-
|
102 |
-
def get_generation_url(self) -> str:
|
103 |
-
"""Get generation URL."""
|
104 |
-
model = getattr(self, "engine")
|
105 |
-
return self.host + f"/{model}:predict"
|
106 |
-
|
107 |
-
def get_generation_header(self) -> Dict[str, str]:
|
108 |
-
"""
|
109 |
-
Get generation header.
|
110 |
-
|
111 |
-
Returns:
|
112 |
-
header.
|
113 |
-
"""
|
114 |
-
return {"Authorization": f"Bearer {self.api_key}"}
|
115 |
-
|
116 |
-
def supports_batch_inference(self) -> bool:
|
117 |
-
"""Return whether the client supports batch inference."""
|
118 |
-
return True
|
119 |
-
|
120 |
-
def supports_streaming_inference(self) -> bool:
|
121 |
-
"""Return whether the client supports streaming inference.
|
122 |
-
|
123 |
-
Override in child client class.
|
124 |
-
"""
|
125 |
-
return False
|
126 |
-
|
127 |
-
def get_model_params(self) -> Dict:
|
128 |
-
"""
|
129 |
-
Get model params.
|
130 |
-
|
131 |
-
By getting model params from the server, we can add to request
|
132 |
-
and make sure cache keys are unique to model.
|
133 |
-
|
134 |
-
Returns:
|
135 |
-
model params.
|
136 |
-
"""
|
137 |
-
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
|
138 |
-
|
139 |
-
def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
140 |
-
"""
|
141 |
-
Preprocess request params.
|
142 |
-
|
143 |
-
Args:
|
144 |
-
request: request params.
|
145 |
-
|
146 |
-
Returns:
|
147 |
-
request params.
|
148 |
-
"""
|
149 |
-
# Refortmat the request params for google
|
150 |
-
prompt = request.pop("prompt")
|
151 |
-
if isinstance(prompt, str):
|
152 |
-
prompt_list = [prompt]
|
153 |
-
else:
|
154 |
-
prompt_list = prompt
|
155 |
-
google_request = {
|
156 |
-
"instances": [{"prompt": prompt} for prompt in prompt_list],
|
157 |
-
"parameters": request,
|
158 |
-
}
|
159 |
-
return super().preprocess_request_params(google_request)
|
160 |
-
|
161 |
-
def postprocess_response(
|
162 |
-
self, response: Dict[str, Any], request: Dict[str, Any]
|
163 |
-
) -> Dict[str, Any]:
|
164 |
-
"""
|
165 |
-
Validate response as dict.
|
166 |
-
|
167 |
-
Assumes response is dict
|
168 |
-
{
|
169 |
-
"predictions": [
|
170 |
-
{
|
171 |
-
"safetyAttributes": {
|
172 |
-
"categories": ["Violent", "Sexual"],
|
173 |
-
"blocked": false,
|
174 |
-
"scores": [0.1, 0.1]
|
175 |
-
},
|
176 |
-
"content": "SELECT * FROM "WWW";"
|
177 |
-
}
|
178 |
-
]
|
179 |
-
}
|
180 |
-
|
181 |
-
Args:
|
182 |
-
response: response
|
183 |
-
request: request
|
184 |
-
|
185 |
-
Return:
|
186 |
-
response as dict
|
187 |
-
"""
|
188 |
-
google_predictions = response.pop("predictions")
|
189 |
-
new_response = {
|
190 |
-
"choices": [
|
191 |
-
{
|
192 |
-
"text": prediction["content"],
|
193 |
-
}
|
194 |
-
for prediction in google_predictions
|
195 |
-
]
|
196 |
-
}
|
197 |
-
return super().postprocess_response(new_response, request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/google_chat.py
DELETED
@@ -1,155 +0,0 @@
|
|
1 |
-
"""Google client."""
|
2 |
-
import copy
|
3 |
-
import logging
|
4 |
-
import os
|
5 |
-
from typing import Any, Dict, Optional, Type
|
6 |
-
|
7 |
-
from manifest.clients.google import GoogleClient, get_project_id
|
8 |
-
from manifest.request import LMRequest, Request
|
9 |
-
|
10 |
-
logger = logging.getLogger(__name__)
|
11 |
-
|
12 |
-
# https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
|
13 |
-
GOOGLE_ENGINES = {
|
14 |
-
"chat-bison",
|
15 |
-
}
|
16 |
-
|
17 |
-
|
18 |
-
class GoogleChatClient(GoogleClient):
|
19 |
-
"""GoogleChat client."""
|
20 |
-
|
21 |
-
# User param -> (client param, default value)
|
22 |
-
PARAMS = {
|
23 |
-
"engine": ("model", "chat-bison"),
|
24 |
-
"temperature": ("temperature", 1.0),
|
25 |
-
"max_tokens": ("maxOutputTokens", 10),
|
26 |
-
"top_p": ("topP", 1.0),
|
27 |
-
"top_k": ("topK", 1),
|
28 |
-
"batch_size": ("batch_size", 20),
|
29 |
-
}
|
30 |
-
REQUEST_CLS: Type[Request] = LMRequest
|
31 |
-
NAME = "googlechat"
|
32 |
-
IS_CHAT = True
|
33 |
-
|
34 |
-
def connect(
|
35 |
-
self,
|
36 |
-
connection_str: Optional[str] = None,
|
37 |
-
client_args: Dict[str, Any] = {},
|
38 |
-
) -> None:
|
39 |
-
"""
|
40 |
-
Connect to the GoogleVertex API.
|
41 |
-
|
42 |
-
connection_str is passed as default GOOGLE_API_KEY if variable not set.
|
43 |
-
|
44 |
-
Args:
|
45 |
-
connection_str: connection string.
|
46 |
-
client_args: client arguments.
|
47 |
-
"""
|
48 |
-
connection_parts = connection_str.split("::")
|
49 |
-
if len(connection_parts) == 1:
|
50 |
-
self.api_key = connection_parts[0]
|
51 |
-
elif len(connection_parts) == 2:
|
52 |
-
self.api_key, self.project_id = connection_parts
|
53 |
-
else:
|
54 |
-
raise ValueError(
|
55 |
-
"Invalid connection string. "
|
56 |
-
"Must be either API_KEY or API_KEY::PROJECT_ID"
|
57 |
-
)
|
58 |
-
self.api_key = self.api_key or os.environ.get("GOOGLE_API_KEY")
|
59 |
-
if self.api_key is None:
|
60 |
-
raise ValueError(
|
61 |
-
"GoogleVertex API key not set. Set GOOGLE_API_KEY environment "
|
62 |
-
"variable or pass through `client_connection`. This can be "
|
63 |
-
"found by running `gcloud auth print-access-token`"
|
64 |
-
)
|
65 |
-
self.project_id = (
|
66 |
-
self.project_id or os.environ.get("GOOGLE_PROJECT_ID") or get_project_id()
|
67 |
-
)
|
68 |
-
if self.project_id is None:
|
69 |
-
raise ValueError("GoogleVertex project ID not set. Set GOOGLE_PROJECT_ID")
|
70 |
-
self.host = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/us-central1/publishers/google/models" # noqa: E501
|
71 |
-
|
72 |
-
for key in self.PARAMS:
|
73 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
74 |
-
if getattr(self, "engine") not in GOOGLE_ENGINES:
|
75 |
-
raise ValueError(
|
76 |
-
f"Invalid engine {getattr(self, 'engine')}. Must be {GOOGLE_ENGINES}."
|
77 |
-
)
|
78 |
-
|
79 |
-
def supports_batch_inference(self) -> bool:
|
80 |
-
"""Return whether the client supports batch inference."""
|
81 |
-
return False
|
82 |
-
|
83 |
-
def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
84 |
-
"""
|
85 |
-
Preprocess request params.
|
86 |
-
|
87 |
-
Args:
|
88 |
-
request: request params.
|
89 |
-
|
90 |
-
Returns:
|
91 |
-
request params.
|
92 |
-
"""
|
93 |
-
# Format for chat model
|
94 |
-
request = copy.deepcopy(request)
|
95 |
-
prompt = request.pop("prompt")
|
96 |
-
if isinstance(prompt, str):
|
97 |
-
messages = [{"author": "user", "content": prompt}]
|
98 |
-
elif isinstance(prompt, list) and isinstance(prompt[0], str):
|
99 |
-
prompt_list = prompt
|
100 |
-
messages = [{"author": "user", "content": prompt} for prompt in prompt_list]
|
101 |
-
elif isinstance(prompt, list) and isinstance(prompt[0], dict):
|
102 |
-
for pmt_dict in prompt:
|
103 |
-
if "author" not in pmt_dict or "content" not in pmt_dict:
|
104 |
-
raise ValueError(
|
105 |
-
"Prompt must be list of dicts with 'author' and 'content' "
|
106 |
-
f"keys. Got {prompt}."
|
107 |
-
)
|
108 |
-
messages = prompt
|
109 |
-
else:
|
110 |
-
raise ValueError(
|
111 |
-
"Prompt must be string, list of strings, or list of dicts."
|
112 |
-
f"Got {prompt}"
|
113 |
-
)
|
114 |
-
new_request = {
|
115 |
-
"instances": [{"messages": messages}],
|
116 |
-
"parameters": request,
|
117 |
-
}
|
118 |
-
return super(GoogleClient, self).preprocess_request_params(new_request)
|
119 |
-
|
120 |
-
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
|
121 |
-
"""
|
122 |
-
Validate response as dict.
|
123 |
-
|
124 |
-
Assumes response is dict
|
125 |
-
{
|
126 |
-
"candidates": [
|
127 |
-
{
|
128 |
-
"safetyAttributes": {
|
129 |
-
"categories": ["Violent", "Sexual"],
|
130 |
-
"blocked": false,
|
131 |
-
"scores": [0.1, 0.1]
|
132 |
-
},
|
133 |
-
"author": "1",
|
134 |
-
"content": "SELECT * FROM "WWW";"
|
135 |
-
}
|
136 |
-
]
|
137 |
-
}
|
138 |
-
|
139 |
-
Args:
|
140 |
-
response: response
|
141 |
-
request: request
|
142 |
-
|
143 |
-
Return:
|
144 |
-
response as dict
|
145 |
-
"""
|
146 |
-
google_predictions = response.pop("predictions")
|
147 |
-
new_response = {
|
148 |
-
"choices": [
|
149 |
-
{
|
150 |
-
"text": prediction["candidates"][0]["content"],
|
151 |
-
}
|
152 |
-
for prediction in google_predictions
|
153 |
-
]
|
154 |
-
}
|
155 |
-
return super(GoogleClient, self).postprocess_response(new_response, request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/huggingface.py
DELETED
@@ -1,137 +0,0 @@
|
|
1 |
-
"""Hugging Face client."""
|
2 |
-
import logging
|
3 |
-
from functools import lru_cache
|
4 |
-
from typing import Any, Dict, Optional
|
5 |
-
|
6 |
-
import requests
|
7 |
-
|
8 |
-
from manifest.clients.client import Client
|
9 |
-
from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, LMScoreRequest
|
10 |
-
from manifest.response import LMModelChoice, ModelChoices, Response
|
11 |
-
|
12 |
-
logger = logging.getLogger(__name__)
|
13 |
-
|
14 |
-
|
15 |
-
class HuggingFaceClient(Client):
|
16 |
-
"""HuggingFace client."""
|
17 |
-
|
18 |
-
# User param -> (client param, default value)
|
19 |
-
PARAMS = {
|
20 |
-
"temperature": ("temperature", 0.1),
|
21 |
-
"max_tokens": ("max_tokens", 10),
|
22 |
-
"n": ("n", 1),
|
23 |
-
"top_p": ("top_p", 1.0),
|
24 |
-
"top_k": ("top_k", 50),
|
25 |
-
"repetition_penalty": ("repetition_penalty", 1.0),
|
26 |
-
"do_sample": ("do_sample", True),
|
27 |
-
}
|
28 |
-
REQUEST_CLS = LMRequest
|
29 |
-
NAME = "huggingface"
|
30 |
-
|
31 |
-
def connect(
|
32 |
-
self,
|
33 |
-
connection_str: Optional[str] = None,
|
34 |
-
client_args: Dict[str, Any] = {},
|
35 |
-
) -> None:
|
36 |
-
"""
|
37 |
-
Connect to the HuggingFace url.
|
38 |
-
|
39 |
-
Arsg:
|
40 |
-
connection_str: connection string.
|
41 |
-
client_args: client arguments.
|
42 |
-
"""
|
43 |
-
if not connection_str:
|
44 |
-
raise ValueError("Must provide connection string")
|
45 |
-
self.host = connection_str.rstrip("/")
|
46 |
-
for key in self.PARAMS:
|
47 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
48 |
-
|
49 |
-
def close(self) -> None:
|
50 |
-
"""Close the client."""
|
51 |
-
pass
|
52 |
-
|
53 |
-
def get_generation_url(self) -> str:
|
54 |
-
"""Get generation URL."""
|
55 |
-
return self.host + "/completions"
|
56 |
-
|
57 |
-
def get_generation_header(self) -> Dict[str, str]:
|
58 |
-
"""
|
59 |
-
Get generation header.
|
60 |
-
|
61 |
-
Returns:
|
62 |
-
header.
|
63 |
-
"""
|
64 |
-
return {}
|
65 |
-
|
66 |
-
def supports_batch_inference(self) -> bool:
|
67 |
-
"""Return whether the client supports batch inference."""
|
68 |
-
return True
|
69 |
-
|
70 |
-
def supports_streaming_inference(self) -> bool:
|
71 |
-
"""Return whether the client supports streaming inference.
|
72 |
-
|
73 |
-
Override in child client class.
|
74 |
-
"""
|
75 |
-
return False
|
76 |
-
|
77 |
-
@lru_cache(maxsize=1)
|
78 |
-
def get_model_params(self) -> Dict:
|
79 |
-
"""
|
80 |
-
Get model params.
|
81 |
-
|
82 |
-
By getting model params from the server, we can add to request
|
83 |
-
and make sure cache keys are unique to model.
|
84 |
-
|
85 |
-
Returns:
|
86 |
-
model params.
|
87 |
-
"""
|
88 |
-
res = requests.post(self.host + "/params").json()
|
89 |
-
res["client_name"] = self.NAME
|
90 |
-
return res
|
91 |
-
|
92 |
-
def run_score_prompt_request(
|
93 |
-
self,
|
94 |
-
request: LMScoreRequest,
|
95 |
-
) -> Response:
|
96 |
-
"""
|
97 |
-
Get the logit score of the prompt via a forward pass of the model.
|
98 |
-
|
99 |
-
Args:
|
100 |
-
request: request.
|
101 |
-
|
102 |
-
Returns:
|
103 |
-
request function that takes no input.
|
104 |
-
request parameters as dict.
|
105 |
-
"""
|
106 |
-
request_params = self._get_request_params(request)
|
107 |
-
retry_timeout = request_params.pop("client_timeout")
|
108 |
-
for key in DEFAULT_REQUEST_KEYS:
|
109 |
-
request_params.pop(key, None)
|
110 |
-
# Do not add params like we do with request as the model isn't sampling
|
111 |
-
request_params = {"prompt": request.prompt}
|
112 |
-
|
113 |
-
post_str = self.host + "/score_sequence"
|
114 |
-
try:
|
115 |
-
res = requests.post(
|
116 |
-
post_str,
|
117 |
-
json=request_params,
|
118 |
-
timeout=retry_timeout,
|
119 |
-
)
|
120 |
-
res.raise_for_status()
|
121 |
-
except requests.Timeout as e:
|
122 |
-
logger.error("HF request timed out. Increase client_timeout.")
|
123 |
-
raise e
|
124 |
-
except requests.exceptions.HTTPError as e:
|
125 |
-
logger.error(res.text)
|
126 |
-
raise e
|
127 |
-
response_dict = res.json()
|
128 |
-
return Response(
|
129 |
-
response=ModelChoices(
|
130 |
-
choices=[LMModelChoice(**choice) for choice in response_dict["choices"]]
|
131 |
-
),
|
132 |
-
cached=False,
|
133 |
-
request=request,
|
134 |
-
usages=None,
|
135 |
-
response_type="text",
|
136 |
-
request_type=LMScoreRequest,
|
137 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/huggingface_embedding.py
DELETED
@@ -1,98 +0,0 @@
|
|
1 |
-
"""Hugging Face client."""
|
2 |
-
import logging
|
3 |
-
from functools import lru_cache
|
4 |
-
from typing import Any, Dict, Optional, Tuple
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import requests
|
8 |
-
|
9 |
-
from manifest.clients.client import Client
|
10 |
-
from manifest.request import EmbeddingRequest
|
11 |
-
|
12 |
-
logger = logging.getLogger(__name__)
|
13 |
-
|
14 |
-
|
15 |
-
class HuggingFaceEmbeddingClient(Client):
|
16 |
-
"""HuggingFaceEmbedding client."""
|
17 |
-
|
18 |
-
# User param -> (client param, default value)
|
19 |
-
PARAMS: Dict[str, Tuple[str, Any]] = {}
|
20 |
-
REQUEST_CLS = EmbeddingRequest
|
21 |
-
NAME = "huggingfaceembedding"
|
22 |
-
|
23 |
-
def connect(
|
24 |
-
self,
|
25 |
-
connection_str: Optional[str] = None,
|
26 |
-
client_args: Dict[str, Any] = {},
|
27 |
-
) -> None:
|
28 |
-
"""
|
29 |
-
Connect to the HuggingFace url.
|
30 |
-
|
31 |
-
Arsg:
|
32 |
-
connection_str: connection string.
|
33 |
-
client_args: client arguments.
|
34 |
-
"""
|
35 |
-
if not connection_str:
|
36 |
-
raise ValueError("Must provide connection string")
|
37 |
-
self.host = connection_str.rstrip("/")
|
38 |
-
for key in self.PARAMS:
|
39 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
40 |
-
|
41 |
-
def close(self) -> None:
|
42 |
-
"""Close the client."""
|
43 |
-
pass
|
44 |
-
|
45 |
-
def get_generation_url(self) -> str:
|
46 |
-
"""Get generation URL."""
|
47 |
-
return self.host + "/embed"
|
48 |
-
|
49 |
-
def get_generation_header(self) -> Dict[str, str]:
|
50 |
-
"""
|
51 |
-
Get generation header.
|
52 |
-
|
53 |
-
Returns:
|
54 |
-
header.
|
55 |
-
"""
|
56 |
-
return {}
|
57 |
-
|
58 |
-
def supports_batch_inference(self) -> bool:
|
59 |
-
"""Return whether the client supports batch inference."""
|
60 |
-
return True
|
61 |
-
|
62 |
-
def supports_streaming_inference(self) -> bool:
|
63 |
-
"""Return whether the client supports streaming inference.
|
64 |
-
|
65 |
-
Override in child client class.
|
66 |
-
"""
|
67 |
-
return False
|
68 |
-
|
69 |
-
@lru_cache(maxsize=1)
|
70 |
-
def get_model_params(self) -> Dict:
|
71 |
-
"""
|
72 |
-
Get model params.
|
73 |
-
|
74 |
-
By getting model params from the server, we can add to request
|
75 |
-
and make sure cache keys are unique to model.
|
76 |
-
|
77 |
-
Returns:
|
78 |
-
model params.
|
79 |
-
"""
|
80 |
-
res = requests.post(self.host + "/params").json()
|
81 |
-
res["client_name"] = self.NAME
|
82 |
-
return res
|
83 |
-
|
84 |
-
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
|
85 |
-
"""
|
86 |
-
Format response to dict.
|
87 |
-
|
88 |
-
Args:
|
89 |
-
response: response
|
90 |
-
request: request
|
91 |
-
|
92 |
-
Return:
|
93 |
-
response as dict
|
94 |
-
"""
|
95 |
-
# Convert array to np.array
|
96 |
-
for choice in response["choices"]:
|
97 |
-
choice["array"] = np.array(choice["array"])
|
98 |
-
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
duckdb-nsql/manifest/manifest/clients/openai.py
DELETED
@@ -1,162 +0,0 @@
|
|
1 |
-
"""OpenAI client."""
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
from typing import Any, Dict, List, Optional, Type
|
5 |
-
|
6 |
-
import tiktoken
|
7 |
-
|
8 |
-
from manifest.clients.client import Client
|
9 |
-
from manifest.request import LMRequest, Request
|
10 |
-
|
11 |
-
logger = logging.getLogger(__name__)
|
12 |
-
|
13 |
-
OPENAI_ENGINES = {
|
14 |
-
"gpt-3.5-turbo-instruct",
|
15 |
-
"text-davinci-003",
|
16 |
-
"text-davinci-002",
|
17 |
-
"text-davinci-001",
|
18 |
-
"davinci",
|
19 |
-
"curie",
|
20 |
-
"ada",
|
21 |
-
"babbage",
|
22 |
-
"text-curie-001",
|
23 |
-
"text-babbage-001",
|
24 |
-
"text-ada-001",
|
25 |
-
"code-davinci-002",
|
26 |
-
"code-cushman-001",
|
27 |
-
}
|
28 |
-
|
29 |
-
|
30 |
-
class OpenAIClient(Client):
|
31 |
-
"""OpenAI client."""
|
32 |
-
|
33 |
-
# User param -> (client param, default value)
|
34 |
-
PARAMS = {
|
35 |
-
"engine": ("model", "text-davinci-003"),
|
36 |
-
"temperature": ("temperature", 1.0),
|
37 |
-
"max_tokens": ("max_tokens", 10),
|
38 |
-
"n": ("n", 1),
|
39 |
-
"top_p": ("top_p", 1.0),
|
40 |
-
"top_k": ("best_of", 1),
|
41 |
-
"logprobs": ("logprobs", None),
|
42 |
-
"stop_sequences": ("stop", None), # OpenAI doesn't like empty lists
|
43 |
-
"presence_penalty": ("presence_penalty", 0.0),
|
44 |
-
"frequency_penalty": ("frequency_penalty", 0.0),
|
45 |
-
"batch_size": ("batch_size", 20),
|
46 |
-
}
|
47 |
-
REQUEST_CLS: Type[Request] = LMRequest
|
48 |
-
NAME = "openai"
|
49 |
-
|
50 |
-
def connect(
|
51 |
-
self,
|
52 |
-
connection_str: Optional[str] = None,
|
53 |
-
client_args: Dict[str, Any] = {},
|
54 |
-
) -> None:
|
55 |
-
"""
|
56 |
-
Connect to the OpenAI server.
|
57 |
-
|
58 |
-
connection_str is passed as default OPENAI_API_KEY if variable not set.
|
59 |
-
|
60 |
-
Args:
|
61 |
-
connection_str: connection string.
|
62 |
-
client_args: client arguments.
|
63 |
-
"""
|
64 |
-
self.api_key = connection_str or os.environ.get("OPENAI_API_KEY")
|
65 |
-
if self.api_key is None:
|
66 |
-
raise ValueError(
|
67 |
-
"OpenAI API key not set. Set OPENAI_API_KEY environment "
|
68 |
-
"variable or pass through `client_connection`."
|
69 |
-
)
|
70 |
-
self.host = "https://api.openai.com/v1"
|
71 |
-
for key in self.PARAMS:
|
72 |
-
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
73 |
-
if getattr(self, "engine") not in OPENAI_ENGINES:
|
74 |
-
raise ValueError(
|
75 |
-
f"Invalid engine {getattr(self, 'engine')}. Must be {OPENAI_ENGINES}."
|
76 |
-
)
|
77 |
-
|
78 |
-
def close(self) -> None:
|
79 |
-
"""Close the client."""
|
80 |
-
pass
|
81 |
-
|
82 |
-
def get_generation_url(self) -> str:
|
83 |
-
"""Get generation URL."""
|
84 |
-
return self.host + "/completions"
|
85 |
-
|
86 |
-
def get_generation_header(self) -> Dict[str, str]:
|
87 |
-
"""
|
88 |
-
Get generation header.
|
89 |
-
|
90 |
-
Returns:
|
91 |
-
header.
|
92 |
-
"""
|
93 |
-
return {"Authorization": f"Bearer {self.api_key}"}
|
94 |
-
|
95 |
-
def supports_batch_inference(self) -> bool:
|
96 |
-
"""Return whether the client supports batch inference."""
|
97 |
-
return True
|
98 |
-
|
99 |
-
def supports_streaming_inference(self) -> bool:
|
100 |
-
"""Return whether the client supports streaming inference.
|
101 |
-
|
102 |
-
Override in child client class.
|
103 |
-
"""
|
104 |
-
return True
|
105 |
-
|
106 |
-
def get_model_params(self) -> Dict:
|
107 |
-
"""
|
108 |
-
Get model params.
|
109 |
-
|
110 |
-
By getting model params from the server, we can add to request
|
111 |
-
and make sure cache keys are unique to model.
|
112 |
-
|
113 |
-
Returns:
|
114 |
-
model params.
|
115 |
-
"""
|
116 |
-
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
|
117 |
-
|
118 |
-
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
|
119 |
-
"""
|
120 |
-
Validate response as dict.
|
121 |
-
|
122 |
-
Args:
|
123 |
-
response: response
|
124 |
-
request: request
|
125 |
-
|
126 |
-
Return:
|
127 |
-
response as dict
|
128 |
-
"""
|
129 |
-
validated_response = super().postprocess_response(response, request)
|
130 |
-
# Handle logprobs
|
131 |
-
for choice in validated_response["choices"]:
|
132 |
-
if "logprobs" in choice:
|
133 |
-
logprobs = choice.pop("logprobs")
|
134 |
-
if logprobs and "token_logprobs" in logprobs:
|
135 |
-
choice["token_logprobs"] = logprobs["token_logprobs"]
|
136 |
-
choice["tokens"] = logprobs["tokens"]
|
137 |
-
return validated_response
|
138 |
-
|
139 |
-
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
|
140 |
-
"""Split usage into list of usages for each prompt."""
|
141 |
-
try:
|
142 |
-
encoding = tiktoken.encoding_for_model(getattr(self, "engine"))
|
143 |
-
except Exception:
|
144 |
-
return []
|
145 |
-
prompt = request["prompt"]
|
146 |
-
# If n > 1 and prompt is a string, we need to split it into a list
|
147 |
-
if isinstance(prompt, str):
|
148 |
-
prompts = [prompt] * len(choices)
|
149 |
-
else:
|
150 |
-
prompts = prompt
|
151 |
-
assert len(prompts) == len(choices)
|
152 |
-
usages = []
|
153 |
-
for pmt, chc in zip(prompts, choices):
|
154 |
-
pmt_tokens = len(encoding.encode(pmt))
|
155 |
-
chc_tokens = len(encoding.encode(chc["text"])) # type: ignore
|
156 |
-
usage = {
|
157 |
-
"prompt_tokens": pmt_tokens,
|
158 |
-
"completion_tokens": chc_tokens,
|
159 |
-
"total_tokens": pmt_tokens + chc_tokens,
|
160 |
-
}
|
161 |
-
usages.append(usage)
|
162 |
-
return usages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|