tdoehmen committed
Commit c635df2
1 Parent(s): 2dc631c

added hf inference api

This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. duckdb-nsql/eval/get_manifest.py +1 -1
  2. duckdb-nsql/eval/predict.py +2 -2
  3. duckdb-nsql/manifest/.flake8 +0 -11
  4. duckdb-nsql/manifest/.pre-commit-config.yaml +0 -23
  5. duckdb-nsql/manifest/CHANGELOG.rst +0 -93
  6. duckdb-nsql/manifest/LICENSE +0 -201
  7. duckdb-nsql/manifest/Makefile +0 -27
  8. duckdb-nsql/manifest/README.md +0 -304
  9. duckdb-nsql/manifest/examples/langchain_chatgpt.ipynb +0 -455
  10. duckdb-nsql/manifest/examples/manifest_async.py +0 -27
  11. duckdb-nsql/manifest/examples/manifest_azure.ipynb +0 -149
  12. duckdb-nsql/manifest/examples/manifest_chatgpt.ipynb +0 -101
  13. duckdb-nsql/manifest/examples/manifest_connection_pool.ipynb +0 -208
  14. duckdb-nsql/manifest/examples/manifest_diffusers.ipynb +0 -0
  15. duckdb-nsql/manifest/examples/manifest_embedding.ipynb +0 -156
  16. duckdb-nsql/manifest/examples/manifest_google.ipynb +0 -117
  17. duckdb-nsql/manifest/examples/manifest_openrouter.ipynb +0 -108
  18. duckdb-nsql/manifest/examples/manifest_streaming.ipynb +0 -105
  19. duckdb-nsql/manifest/examples/manifest_together.ipynb +0 -106
  20. duckdb-nsql/manifest/manifest/__init__.py +0 -6
  21. duckdb-nsql/manifest/manifest/api/__init__.py +0 -1
  22. duckdb-nsql/manifest/manifest/api/app.py +0 -301
  23. duckdb-nsql/manifest/manifest/api/models/__init__.py +0 -1
  24. duckdb-nsql/manifest/manifest/api/models/diffuser.py +0 -123
  25. duckdb-nsql/manifest/manifest/api/models/huggingface.py +0 -671
  26. duckdb-nsql/manifest/manifest/api/models/model.py +0 -91
  27. duckdb-nsql/manifest/manifest/api/models/sentence_transformer.py +0 -113
  28. duckdb-nsql/manifest/manifest/api/response.py +0 -55
  29. duckdb-nsql/manifest/manifest/caches/__init__.py +0 -1
  30. duckdb-nsql/manifest/manifest/caches/array_cache.py +0 -116
  31. duckdb-nsql/manifest/manifest/caches/cache.py +0 -135
  32. duckdb-nsql/manifest/manifest/caches/noop.py +0 -47
  33. duckdb-nsql/manifest/manifest/caches/postgres.py +0 -131
  34. duckdb-nsql/manifest/manifest/caches/redis.py +0 -64
  35. duckdb-nsql/manifest/manifest/caches/serializers.py +0 -204
  36. duckdb-nsql/manifest/manifest/caches/sqlite.py +0 -65
  37. duckdb-nsql/manifest/manifest/clients/__init__.py +0 -1
  38. duckdb-nsql/manifest/manifest/clients/ai21.py +0 -125
  39. duckdb-nsql/manifest/manifest/clients/azureendpoint.py +0 -139
  40. duckdb-nsql/manifest/manifest/clients/azureopenai.py +0 -113
  41. duckdb-nsql/manifest/manifest/clients/azureopenai_chat.py +0 -116
  42. duckdb-nsql/manifest/manifest/clients/client.py +0 -699
  43. duckdb-nsql/manifest/manifest/clients/cohere.py +0 -125
  44. duckdb-nsql/manifest/manifest/clients/diffuser.py +0 -112
  45. duckdb-nsql/manifest/manifest/clients/dummy.py +0 -251
  46. duckdb-nsql/manifest/manifest/clients/google.py +0 -197
  47. duckdb-nsql/manifest/manifest/clients/google_chat.py +0 -155
  48. duckdb-nsql/manifest/manifest/clients/huggingface.py +0 -137
  49. duckdb-nsql/manifest/manifest/clients/huggingface_embedding.py +0 -98
  50. duckdb-nsql/manifest/manifest/clients/openai.py +0 -162
duckdb-nsql/eval/get_manifest.py CHANGED
@@ -9,7 +9,7 @@ def get_manifest(
     manifest_engine: str,
 ) -> Manifest:
     """Get manifest engine."""
-    if manifest_client in {"openai", "openaichat", "openai_mock", "openrouter", "azureendpoint"}:
+    if manifest_client in {"openai", "openaichat", "openai_mock", "openrouter", "azureendpoint", "inference_api"}:
         manifest = Manifest(
             client_name=manifest_client,
             engine=manifest_engine,
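The change registers `inference_api` alongside the other hosted clients, so `get_manifest` builds the Manifest object straight from the configured client and engine names. A minimal sketch of what that construction looks like, assuming `inference_api` is the Hugging Face Inference API client this commit adds, with a placeholder engine id:

```python
from manifest import Manifest

# Hypothetical usage; "bigcode/starcoder" is a placeholder engine id, and the
# HF API token is assumed to come from the environment / client connection,
# as with the other hosted clients.
manifest = Manifest(
    client_name="inference_api",
    engine="bigcode/starcoder",
)
print(manifest.run("Translate to SQL: how many rows are in the users table?"))
```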
duckdb-nsql/eval/predict.py CHANGED
@@ -213,7 +213,7 @@ def predict(
     console.print(f"Running with {manifest_params} manifest.")
     model_name = manifest_params.get("engine", manifest_params["model_name"])
 
-    if manifest_client in {"openai", "openaichat", "openrouter", "azureendpoint"}:
+    if manifest_client in {"openai", "openaichat", "openrouter", "azureendpoint", "inference_api"}:
         tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
     else:
         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
@@ -234,7 +234,7 @@ def predict(
         middleix = manifest_engine
     elif manifest_client in {"huggingface", "ray"}:
         middleix = Path(manifest_params.get("model_path", "")).name.replace("/", "-")
-    elif manifest_client in {"toma", "openrouter", "openaichat", "azureendpoint"}:
+    elif manifest_client in {"toma", "openrouter", "openaichat", "azureendpoint", "inference_api"}:
         middleix = manifest_engine.split("/")[-1]
     else:
         raise ValueError(f"Unknown manifest client {manifest_client}")
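Both hunks extend existing branches rather than adding new behaviour: prompts sent through `inference_api` are length-checked with the generic `gpt2` tokenizer (the remote model's own tokenizer is not loaded), and the prediction file's middle name segment is taken from the engine id. A small sketch of that naming rule, using a hypothetical engine id:

```python
# Hypothetical engine id, purely illustrative.
manifest_engine = "meta-llama/Llama-2-7b-hf"

# Same rule predict.py now applies to inference_api (as it already did for
# toma, openrouter, openaichat, and azureendpoint): keep the last path part.
middleix = manifest_engine.split("/")[-1]
print(middleix)  # Llama-2-7b-hf
```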
duckdb-nsql/manifest/.flake8 DELETED
@@ -1,11 +0,0 @@
-# This is our code-style check. We currently allow the following exceptions:
-# - E731: do not assign a lambda expression, use a def
-# - E402: module level import not at top of file
-# - W503: line break before binary operator
-# - E203: whitespace before :
-
-[flake8]
-exclude = .git
-max-line-length = 88
-ignore = E731, E402, W503, E203, PAI100, PAI101, PAI201, PAI202, PAI203
-per-file-ignores = __init__.py:F401, version.py:D100
duckdb-nsql/manifest/.pre-commit-config.yaml DELETED
@@ -1,23 +0,0 @@
-repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.2.0
-    hooks:
-      - id: trailing-whitespace
-      - id: end-of-file-fixer
-      - id: check-yaml
-      - id: check-toml
-      - id: check-merge-conflict
-      - id: check-added-large-files
-  - repo: https://github.com/timothycrosley/isort
-    rev: 5.13.2
-    hooks:
-      - id: isort
-  - repo: https://github.com/psf/black
-    rev: 22.3.0
-    hooks:
-      - id: black
-        language_version: python3
-  - repo: https://github.com/PyCQA/flake8
-    rev: 6.0.0
-    hooks:
-      - id: flake8
duckdb-nsql/manifest/CHANGELOG.rst DELETED
@@ -1,93 +0,0 @@
-0.1.10 - Unreleased
----------------------
-
-0.1.9 - 2024-01-22
----------------------
-Fixed
-^^^^^
-* Added trust code params HF models
-* Added LRU cache to HF model param calls to avoid extra calls
-* Fixed pydantic type issue HF model return
-* Support for Python 3.10-3.11
-
-0.1.8 - 2023-05-22
----------------------
-Added
-^^^^^
-* Azure model support (completion and chat)
-* Google Vertex API model support (completion and chat)
-* Streaming responses for LM Completions (set stream=True)
-
-Fixed
-^^^^^
-* `run` with batches now acts the same as async run except not async. We will batch requests into appropriate batchs sizes.
-* Refactored client so unified preprocess and postprocess of requests and responses to better support model variants in request/response format.
-
-0.1.7 - 2023-05-17
----------------------
-Fixed
-^^^^^
-* `_run_chat` fixed bug where not passing in kwargs
-
-0.1.6 - 2023-05-16
----------------------
-Fixed
-^^^^^
-* Unified `run` and `run_chat` methods so it's just `run` now.
-* LLama HF models for eval
-
-0.1.5 - 2023-05-03
----------------------
-Added
-^^^^^
-* Added chat input for chat models.
-
-0.1.4 - 2023-04-24
----------------------
-Added
-^^^^^
-* Connection pools to swap between clients
-* Chunksize param for async runs
-
-Fixed
-^^^^^
-* Determine cache and response by request type, not client name
-* Refactor Response to use Pydantic types for Request and Response
-
-0.1.1
----------------------
-Added
-^^^^^
-* Async support in arun_batch
-
-Fixed
-^^^^^
-* Batched runs now caches individual items
-* Score prompt does not truncate outside token
-
-Removed
-^^^^^
-* Deprecated chatGPT in favor of openaichat which uses OpenAI completions
-* Deprecated Sessions
-
-0.1.0 - 2022-01-31
----------------------
-Added
-^^^^^
-* Batched inference support in `manifest.run`. No more separate `manifest.run_batch` method.
-* Standard request base model for all language inputs.
-* ChatGPT client. Requires CHATGPT_SESSION_KEY to be passed in.
-* Diffusion model support
-* Together model support
-
-Removed
-^^^^^^^
-* `Prompt` class
-* `OPT` client - OPT is now available in HuggingFace
-
-0.0.1 - 2022-11-08
--------------------
-First major pip release of Manifest. Install via `pip install manifest-ml`.
-
-
-.. _@lorr1: https://github.com/lorr1
duckdb-nsql/manifest/LICENSE DELETED
@@ -1,201 +0,0 @@
-(standard Apache License, Version 2.0 text, 201 lines)
duckdb-nsql/manifest/Makefile DELETED
@@ -1,27 +0,0 @@
-dev:
-	pip install -e .[all]
-	pre-commit install
-
-test: dev check
-	pytest tests
-
-format:
-	isort --atomic manifest/ tests/ web_app/
-	black manifest/ tests/ web_app/
-
-check:
-	isort -c manifest/ tests/ web_app/
-	black manifest/ tests/ web_app/ --check
-	flake8 manifest/ tests/ web_app/
-	mypy manifest/ tests/ web_app/
-
-clean:
-	pip uninstall -y manifest
-	rm -rf src/manifest.egg-info
-	rm -rf build/ dist/
-
-prune:
-	@bash -c "git fetch -p";
-	@bash -c "for branch in $(git branch -vv | grep ': gone]' | awk '{print $1}'); do git branch -d $branch; done";
-
-.PHONY: dev test clean check prune
duckdb-nsql/manifest/README.md DELETED
@@ -1,304 +0,0 @@
1
- # Manifest
2
- How to make prompt programming with Foundation Models a little easier.
3
-
4
-
5
- # Table of Contents
6
- - [Install](#install)
7
- - [Getting Started](#getting-started)
8
- - [Manifest](#manifest-components)
9
- - [Other Models Types](#other-models)
10
- - [Local HuggingFace Models](#local-huggingface-models)
11
- - [Chat Models](#chat-models)
12
- - [Embedding Models](#embedding-models)
13
- - [Road Map](#road-map)
14
- - [Development](#development)
15
- - [Cite](#cite)
16
-
17
-
18
- # Install
19
- Install:
20
- ```bash
21
- pip install manifest-ml
22
- ```
23
-
24
- Install with diffusion support:
25
- ```bash
26
- pip install manifest-ml[diffusers]
27
- ```
28
-
29
- Install with HuggingFace local model support:
30
- ```bash
31
- pip install manifest-ml[api]
32
- ```
33
-
34
- Dev Install:
35
- ```bash
36
- git clone [email protected]:HazyResearch/manifest.git
37
- cd manifest
38
- make dev
39
- ```
40
-
41
- # Getting Started
42
- Running is simple to get started. If using OpenAI, set `export OPENAI_API_KEY=<OPENAIKEY>` (or pass key in through variable `client_connection`) then run
43
-
44
- ```python
45
- from manifest import Manifest
46
-
47
- # Start a manifest session to OpenAI - default `engine=text-davinci-003`
48
- manifest = Manifest(
49
- client_name = "openai",
50
- )
51
- manifest.run("Why is the grass green?")
52
- ```
53
-
54
- ## Examples
55
- We have example notebook and python scripts located at [examples](examples). These show how to use different models, model types (i.e. text, diffusers, or embedding models), and async running.
56
-
57
- # Manifest Components
58
- Manifest is meant to be a very light weight package to help with prompt design and iteration. Three key design decisions of Manifest are
59
-
60
- * All models are behind APIs
61
- * Supports caching of model inputs/outputs for iteration, reproducibility, and cost saving
62
- * Unified API to support generate, score, and embed
63
-
64
- ## Models
65
- Manifest provides model clients for [OpenAI](https://openai.com/), [AI21](https://studio.ai21.com/), [Cohere](https://cohere.ai/), [Together](https://together.xyz/), and HuggingFace (see [below](#huggingface-models) for how to use locally hosted HuggingFace models). You can toggle between the models by changing `client_name` and `client_connection`. For example, if a HuggingFace model is loaded locally, run
66
- ```python
67
- manifest = Manifest(
68
- client_name = "huggingface",
69
- client_connection = "http://127.0.0.1:5000",
70
- )
71
- ```
72
- If you want to use Cohere, run
73
- ```python
74
- manifest = Manifest(
75
- client_name = "cohere",
76
- client_connection = <COHERE_API_KEY>,
77
- )
78
- ```
79
- You can also just set `export COHERE_API_KEY=<COHERE_API_KEY>` and not use `client_connection`.
80
-
81
- If you want to use AI21 Labs, run
82
- ```python
83
- manifest = Manifest(
84
- client_name = "ai21",
85
- client_connection = <AI21_API_KEY>,
86
- )
87
- ```
88
-
89
- You can see the model details and possible model inputs to `run()` via
90
- ```python
91
- print(manifest.client_pool.get_current_client().get_model_params())
92
- print(manifest.client_pool.get_current_client().get_model_inputs())
93
- ```
94
-
95
- ## Global Cache
96
- We support having queries and results stored in a global cache that can be shared across users. We treat inputs and outputs as key value pairs and support SQLite or Redis backends. To start with global caching using SQLite, run
97
-
98
- ```python
99
- manifest = Manifest(
100
- client_name = "openai",
101
- cache_name = "sqlite",
102
- cache_connection = "mycache.sqlite",
103
- )
104
- ```
105
- The cache will be saved in `mycache.sqlite`.
106
-
107
- We also support Redis backend.
108
- ```python
109
- manifest = Manifest(
110
- client_name = "openai",
111
- cache_name = "redis",
112
- cache_connection = "localhost:6379"
113
- )
114
- ```
115
- As a hint, if you want to get Redis running, see the `docker run` command below under development.
116
-
117
- ## Running Queries
118
- Once you have a session open, you can write and develop prompts.
119
-
120
- ```python
121
- result = manifest.run("Hello, my name is Laurel")
122
- ```
123
-
124
- You can also run over multiple examples if supported by the client.
125
- ```python
126
- results = manifest.run(["Where are the cats?", "Where are the dogs?"])
127
- ```
128
-
129
- We support async queries as well via
130
- ```python
131
- import asyncio
132
- results = asyncio.run(manifest.arun_batch(["Where are the cats?", "Where are the dogs?"]))
133
- ```
134
-
135
- If something doesn't go right, you can also ask to get a raw manifest Response.
136
- ```python
137
- result_object = manifest.run(["Where are the cats?", "Where are the dogs?"], return_response=True)
138
- print(result_object.get_request_obj())
139
- print(result_object.is_cached())
140
- print(result_object.get_response_obj())
141
- ```
142
-
143
- By default, we do not truncate results based on a stop token. You can change this by either passing a new stop token to a Manifest session or to a `run`.
144
- ```python
145
- result = manifest.run(prompt, "Laurel", stop_token="and")
146
- ```
147
-
148
- If you want to change default parameters to a model, we pass those as `kwargs` to the client.
149
- ```python
150
- result = manifest.run(prompt, "Laurel", max_tokens=50)
151
- ```
152
-
153
- ## Streaming Queries
154
- Manifest also supports streaming the model response back, assuming it's supported by the underlying client. When calling `run`, pass `stream=True` to get a streaming iterator in response.
155
-
156
- ```python
157
- result_iterator = manifest.run("Tell me a story. Once upon a time", max_tokens=100, stream=True)
158
- for res_text in result_iterator:
159
- print(res_text)
160
- ```
161
- Streaming responses are only supported for single string queries (not batch mode) for text completion models.
162
-
163
- ## Model Pools
164
- Manifest supports querying multiple models with different schedulers. This is very much a work in progress effort, but Manifest will round robin select (or randomly select) the clients you want. You can use the same client multiple times with different connection strings (e.g. different API keys), or you can mix and match. The only requirement is that all clients are the same request type. I.e. you can't have a pool of generation models and embedding models.
165
-
166
- To query between a local model and OpenAI,
167
- ```python
168
- from manifest.connections.client_pool import ClientConnection
169
- from manifest import Manifest
170
-
171
- client_connection1 = ClientConnection(
172
- client_name="huggingface",
173
- client_connection="http://127.0.0.1:5000",
174
- )
175
- client_connection2 = ClientConnection(client_name="openai", engine="text-ada-001")
176
- manifest = Manifest(
177
- client_pool=[client_connection1, client_connection2],
178
- cache_name="sqlite",
179
- client_connection=sqlite_cache,
180
- )
181
- manifest.run(...)
182
- ```
183
-
184
- The speed benefit comes in with async batched runs. When calling `arun_batch` with a list of prompts, Manifest supports a `chunk_size` param. This will break the prompts into `chunk_size` chunks to spread across the client pool. By default `chunk_size` is `-1` which means only one client will get all the prompts to run asynchronously. You must set `chunk_size > 1` to distribute across the pool. There is a further `batch_size` param which control the individual client `batch_size` to send to the model.
185
-
186
- ```python
187
- responses = asyncio.run(manifest.arun_batch(prompts, max_tokens=30, chunk_size=20))
188
- ```
189
-
190
- # Other Models
191
-
192
- ## Local Huggingface Models
193
- To use a HuggingFace generative model, in `manifest/api` we have a Flask application that hosts the models for you.
194
-
195
- In a separate terminal or Tmux/Screen session, to load 6B parameters models, run
196
- ```bash
197
- python3 -m manifest.api.app \
198
- --model_type huggingface \
199
- --model_name_or_path EleutherAI/gpt-j-6B \
200
- --device 0
201
- ```
202
- You will see the Flask session start and output a URL `http://127.0.0.1:5000`. Pass this in to Manifest. If you want to use a different port, set the `FLASK_PORT` environment variable.
203
-
204
- ```python
205
- manifest = Manifest(
206
- client_name = "huggingface",
207
- client_connection = "http://127.0.0.1:5000",
208
- )
209
- ```
210
-
211
- If you have a custom model you trained, pass the model path to `--model_name_or_path`.
212
-
213
- To help load larger models, we also support using `parallelize()` from HF, [accelerate](https://huggingface.co/docs/accelerate/index), [bitsandbytes](https://github.com/TimDettmers/bitsandbytes), and [deepspeed](https://github.com/microsoft/DeepSpeed). You will need to install these packages first via `pip install manifest-ml[api]`. We list the commands to load larger models below.
214
-
215
- * T0pp
216
- ```bash
217
- python3 -m manifest.api.app \
218
- --model_type huggingface \
219
- --model_name_or_path bigscience/T0pp \
220
- --use_hf_parallelize
221
- ```
222
-
223
- * NeoX 20B (requires at least 60GB of GPU memory)
224
- ```bash
225
- python3 -m manifest.api.app \
226
- --model_type huggingface \
227
- --model_name_or_path EleutherAI/gpt-neox-20b \
228
- --use_accelerate_multigpu \
229
- --percent_max_gpu_mem_reduction 0.75
230
- ```
231
- * Bloom 175B (requires at least 240GB of GPU memory)
232
- ```bash
233
- python3 -m manifest.api.app \
234
- --model_type huggingface \
235
- --model_name_or_path bigscience/bloom \
236
- --use_bitsandbytes \
237
- --percent_max_gpu_mem_reduction 0.85
238
- ```
239
-
240
- ## Chat Models
241
- Manifest has specific support for executing against chat models in the more standard "system" / "user" dialogue. To pass in a dialogue history to Manifest, use the `run` command with a list of dictionary inputs with `role` and `content` keys using an associated chat model such as `openaichat`.
242
-
243
- ```python
244
- manifest = Manifest(client_name="openaichat")
245
- dialogue = [
246
- {"role": "system", "content": "You are a helpful assistant who also responds in rhymes"},
247
- {"role": "user", "content": "What is the date?"},
248
- ]
249
- res = manifest.run(dialogue, max_tokens=100)
250
- ```
251
-
252
- ## Embedding Models
253
- Manifest also supports getting embeddings from models and available APIs. We do this all through changing the `client_name` argument. You still use `run` and `abatch_run`.
254
-
255
- To use OpenAI's embedding models, simply run
256
- ```python
257
- manifest = Manifest(client_name="openaiembedding")
258
- embedding_as_np = manifest.run("Get me an embedding for a bunny")
259
- ```
260
-
261
- As explained above, you can load local HuggingFace models that give you embeddings, too. If you want to use a standard generative model, load the model as above use use `client_name="huggingfaceembedding"`. If you want to use a standard embedding model, like those from SentenceTransformers, load your local model via
262
- ```bash
263
- python3 -m manifest.api.app \
264
- --model_type sentence_transformers \
265
- --model_name_or_path all-mpnet-base-v2 \
266
- --device 0
267
- ```
268
-
269
- # Road Map
270
- Here's what's coming up next
271
- - [ ] Clients
272
- - [ ] HuggingFace Hub
273
- - [x] Azure OpenAI
274
- - [x] Google Vertex
275
- - [ ] Anthropic
276
- - [x] Streaming Support Completions
277
- - [ ] Streaming Support Chat Models
278
- - [ ] Data Types
279
- - [ ] Diffusion Models
280
- - [x] Orchestration
281
- - [x] Connection pools
282
- - [ ] Local Inference
283
- - [ ] FlexGen
284
-
285
- # Development
286
- Before submitting a PR, run
287
- ```bash
288
- export REDIS_PORT="6379" # or whatever PORT local redis is running for those tests
289
- cd <REDIS_PATH>
290
- docker run -d -p 127.0.0.1:${REDIS_PORT}:6379 -v `pwd`:`pwd` -w `pwd` --name manifest_redis_test redis
291
- make test
292
- ```
293
-
294
- # Cite
295
- Please cite Manifest if you used it for any publications. Thanks!!
296
- ```
297
- @misc{orr2022manifest,
298
- author = {Orr, Laurel},
299
- title = {Manifest},
300
- year = {2022},
301
- publisher = {GitHub},
302
- howpublished = {\url{https://github.com/HazyResearch/manifest}},
303
- }
304
- ```
duckdb-nsql/manifest/examples/langchain_chatgpt.ipynb DELETED
@@ -1,455 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "attachments": {},
5
- "cell_type": "markdown",
6
- "id": "b253f4d5",
7
- "metadata": {},
8
- "source": [
9
- "# ChatGPT Clone using TOMA GPT-JT-6B\n",
10
- "(adopted from ChatGPT Clone [notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/chains/chatgpt_clone.ipynb))"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": 1,
16
- "id": "b0302886",
17
- "metadata": {},
18
- "outputs": [
19
- {
20
- "name": "stdout",
21
- "output_type": "stream",
22
- "text": [
23
- "env: TOMA_URL=https://staging.together.xyz/api\n"
24
- ]
25
- }
26
- ],
27
- "source": [
28
- "%env TOMA_URL=https://staging.together.xyz/api"
29
- ]
30
- },
31
- {
32
- "attachments": {},
33
- "cell_type": "markdown",
34
- "id": "93a18ea6",
35
- "metadata": {},
36
- "source": [
37
- "Make sure you have langchain installed and manifest. For the most recent versions, run\n",
38
- "```\n",
39
- "pip install git+https://github.com/hwchase17/langchain.git\n",
40
- "pip install git+https://github.com/HazyResearch/manifest.git\n",
41
- "```"
42
- ]
43
- },
44
- {
45
- "cell_type": "code",
46
- "execution_count": 35,
47
- "id": "a99acd89",
48
- "metadata": {},
49
- "outputs": [
50
- {
51
- "name": "stdout",
52
- "output_type": "stream",
53
- "text": [
54
- "\n",
55
- "\n",
56
- "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
57
- "Prompt after formatting:\n",
58
- "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
59
- "\n",
60
- "\n",
61
- "Input: Classes are \"positive\" and \"negative\". For example given\n",
62
- "Input: I love this product!\n",
63
- "Output: positive.\n",
64
- "I think this movie was one of the worst of the year. Script was boring!\n",
65
- "Output:\u001b[0m\n",
66
- "\n",
67
- "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
68
- "negative.\n"
69
- ]
70
- }
71
- ],
72
- "source": [
73
- "from manifest import Manifest\n",
74
- "from langchain.llms.manifest import ManifestWrapper\n",
75
- "from langchain import ConversationChain, LLMChain, PromptTemplate\n",
76
- "from langchain.chains.conversation.memory import ConversationalBufferWindowMemory\n",
77
- "\n",
78
- "\n",
79
- "template = \"\"\"I am a classification model. It will try to classify your input.\n",
80
- "\n",
81
- "{history}\n",
82
- "Input: {human_input}\n",
83
- "Output:\"\"\"\n",
84
- "\n",
85
- "prompt = PromptTemplate(\n",
86
- " input_variables=[\"history\", \"human_input\"], \n",
87
- " template=template\n",
88
- ")\n",
89
- "\n",
90
- "manifest = Manifest(\n",
91
- " client_name=\"toma\",\n",
92
- " engine=\"Together-gpt-JT-6B-v1\",\n",
93
- " max_tokens=150,\n",
94
- " top_p=0.9,\n",
95
- " top_k=40,\n",
96
- " stop_sequences=[\"\\n\"],\n",
97
- ")\n",
98
- "\n",
99
- "chatgpt_chain = LLMChain(\n",
100
- " llm=ManifestWrapper(client=manifest), \n",
101
- " prompt=prompt, \n",
102
- " verbose=True, \n",
103
- " memory=ConversationalBufferWindowMemory(k=8),\n",
104
- ")\n",
105
- "\n",
106
- "output = chatgpt_chain.predict(human_input=\"Classes are \\\"positive\\\" and \\\"negative\\\". For example given\\nInput: I love this product!\\nOutput: positive.\\nI think this movie was one of the worst of the year. Script was boring!\")\n",
107
- "print(output)"
108
- ]
109
- },
110
- {
111
- "cell_type": "code",
112
- "execution_count": 36,
113
- "id": "4ef711d6",
114
- "metadata": {},
115
- "outputs": [
116
- {
117
- "name": "stdout",
118
- "output_type": "stream",
119
- "text": [
120
- "\n",
121
- "\n",
122
- "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
123
- "Prompt after formatting:\n",
124
- "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
125
- "\n",
126
- "Human: Classes are \"positive\" and \"negative\". For example given\n",
127
- "Input: I love this product!\n",
128
- "Output: positive.\n",
129
- "I think this movie was one of the worst of the year. Script was boring!\n",
130
- "AI: negative.\n",
131
- "Input: So awesome! I wish I could have gone\n",
132
- "Output:\u001b[0m\n",
133
- "\n",
134
- "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
135
- "positive.\n"
136
- ]
137
- }
138
- ],
139
- "source": [
140
- "output = chatgpt_chain.predict(human_input=\"So awesome! I wish I could have gone\")\n",
141
- "print(output)"
142
- ]
143
- },
144
- {
145
- "cell_type": "code",
146
- "execution_count": 37,
147
- "id": "a5d6dac2",
148
- "metadata": {},
149
- "outputs": [
150
- {
151
- "name": "stdout",
152
- "output_type": "stream",
153
- "text": [
154
- "\n",
155
- "\n",
156
- "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
157
- "Prompt after formatting:\n",
158
- "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
159
- "\n",
160
- "Human: Classes are \"positive\" and \"negative\". For example given\n",
161
- "Input: I love this product!\n",
162
- "Output: positive.\n",
163
- "I think this movie was one of the worst of the year. Script was boring!\n",
164
- "AI: negative.\n",
165
- "Human: So awesome! I wish I could have gone\n",
166
- "AI: positive.\n",
167
- "Input: Hate it.\n",
168
- "Output:\u001b[0m\n",
169
- "\n",
170
- "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
171
- "negative.\n"
172
- ]
173
- }
174
- ],
175
- "source": [
176
- "output = chatgpt_chain.predict(human_input=\"Hate it.\")\n",
177
- "print(output)"
178
- ]
179
- },
180
- {
181
- "cell_type": "code",
182
- "execution_count": 43,
183
- "id": "b9283077",
184
- "metadata": {},
185
- "outputs": [
186
- {
187
- "name": "stdout",
188
- "output_type": "stream",
189
- "text": [
190
- "\n",
191
- "\n",
192
- "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
193
- "Prompt after formatting:\n",
194
- "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
195
- "\n",
196
- "\n",
197
- "Input: Classes are fruits \"apple\", \"banana\", \"orange\", \"pear\". For example given\n",
198
- "Input: This fruit rippens off of the tree.\n",
199
- "Output: banana.\n",
200
- "Often comes in bosc and bartlett varieties.\n",
201
- "Output:\u001b[0m\n",
202
- "\n",
203
- "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
204
- "apple.\n"
205
- ]
206
- }
207
- ],
208
- "source": [
209
- "chatgpt_chain.memory.clear()\n",
210
- "output = chatgpt_chain.predict(human_input=\"Classes are fruits \\\"apple\\\", \\\"banana\\\", \\\"orange\\\", \\\"pear\\\". For example given\\nInput: This fruit rippens off of the tree.\\nOutput: banana.\\nOften comes in bosc and bartlett varieties.\")\n",
211
- "print(output)"
212
- ]
213
- },
214
- {
215
- "cell_type": "code",
216
- "execution_count": 44,
217
- "id": "cd0a23d9",
218
- "metadata": {
219
- "scrolled": true
220
- },
221
- "outputs": [
222
- {
223
- "name": "stdout",
224
- "output_type": "stream",
225
- "text": [
226
- "\n",
227
- "\n",
228
- "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
229
- "Prompt after formatting:\n",
230
- "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
231
- "\n",
232
- "Human: Classes are fruits \"apple\", \"banana\", \"orange\", \"pear\". For example given\n",
233
- "Input: This fruit rippens off of the tree.\n",
234
- "Output: banana.\n",
235
- "Often comes in bosc and bartlett varieties.\n",
236
- "AI: apple.\n",
237
- "Input: Often associated with monkeys\n",
238
- "Output:\u001b[0m\n",
239
- "\n",
240
- "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
241
- "banana.\n"
242
- ]
243
- }
244
- ],
245
- "source": [
246
- "output = chatgpt_chain.predict(human_input=\"Often associated with monkeys\")\n",
247
- "print(output)"
248
- ]
249
- },
250
- {
251
- "cell_type": "code",
252
- "execution_count": 45,
253
- "id": "90db6eb2",
254
- "metadata": {},
255
- "outputs": [
256
- {
257
- "name": "stdout",
258
- "output_type": "stream",
259
- "text": [
260
- "\n",
261
- "\n",
262
- "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
263
- "Prompt after formatting:\n",
264
- "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
265
- "\n",
266
- "Human: Classes are fruits \"apple\", \"banana\", \"orange\", \"pear\". For example given\n",
267
- "Input: This fruit rippens off of the tree.\n",
268
- "Output: banana.\n",
269
- "Often comes in bosc and bartlett varieties.\n",
270
- "AI: apple.\n",
271
- "Human: Often associated with monkeys\n",
272
- "AI: banana.\n",
273
- "Input: Is the color red and often delicious.\n",
274
- "Output:\u001b[0m\n",
275
- "\n",
276
- "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
277
- "apple.\n"
278
- ]
279
- }
280
- ],
281
- "source": [
282
- "output = chatgpt_chain.predict(human_input=\"Is the color red and often delicious.\")\n",
283
- "print(output)"
284
- ]
285
- },
286
- {
287
- "cell_type": "code",
288
- "execution_count": 48,
289
- "id": "c3806f89",
290
- "metadata": {},
291
- "outputs": [
292
- {
293
- "name": "stdout",
294
- "output_type": "stream",
295
- "text": [
296
- "\n",
297
- "\n",
298
- "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
299
- "Prompt after formatting:\n",
300
- "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
301
- "\n",
302
- "\n",
303
- "Input: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
304
- "Input: The color of a school bus.\n",
305
- "Output: yellow.\n",
306
- "Is the color of the sky\n",
307
- "Output:\u001b[0m\n",
308
- "\n",
309
- "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
310
- "blue.\n"
311
- ]
312
- }
313
- ],
314
- "source": [
315
- "chatgpt_chain.memory.clear()\n",
316
- "output = chatgpt_chain.predict(human_input=\"Classes are colors \\\"red\\\", \\\"green\\\", \\\"blue\\\", \\\"yellow\\\". For example given\\nInput: The color of a school bus.\\nOutput: yellow.\\nIs the color of the sky\")\n",
317
- "print(output)"
318
- ]
319
- },
320
- {
321
- "cell_type": "code",
322
- "execution_count": 49,
323
- "id": "f508f597",
324
- "metadata": {},
325
- "outputs": [
326
- {
327
- "name": "stdout",
328
- "output_type": "stream",
329
- "text": [
330
- "\n",
331
- "\n",
332
- "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
333
- "Prompt after formatting:\n",
334
- "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
335
- "\n",
336
- "Human: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
337
- "Input: The color of a school bus.\n",
338
- "Output: yellow.\n",
339
- "Is the color of the sky\n",
340
- "AI: blue.\n",
341
- "Input: Color of a banana.\n",
342
- "Output:\u001b[0m\n",
343
- "\n",
344
- "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
345
- "yellow.\n"
346
- ]
347
- }
348
- ],
349
- "source": [
350
- "output = chatgpt_chain.predict(human_input=\"Color of a banana.\")\n",
351
- "print(output)"
352
- ]
353
- },
354
- {
355
- "cell_type": "code",
356
- "execution_count": 50,
357
- "id": "cbd607f4",
358
- "metadata": {},
359
- "outputs": [
360
- {
361
- "name": "stdout",
362
- "output_type": "stream",
363
- "text": [
364
- "\n",
365
- "\n",
366
- "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
367
- "Prompt after formatting:\n",
368
- "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
369
- "\n",
370
- "Human: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
371
- "Input: The color of a school bus.\n",
372
- "Output: yellow.\n",
373
- "Is the color of the sky\n",
374
- "AI: blue.\n",
375
- "Human: Color of a banana.\n",
376
- "AI: yellow.\n",
377
- "Input: When someone is sick they are this color.\n",
378
- "Output:\u001b[0m\n",
379
- "\n",
380
- "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
381
- "green.\n"
382
- ]
383
- }
384
- ],
385
- "source": [
386
- "output = chatgpt_chain.predict(human_input=\"When someone is sick they are this color.\")\n",
387
- "print(output)"
388
- ]
389
- },
390
- {
391
- "cell_type": "code",
392
- "execution_count": 51,
393
- "id": "d33e0e28",
394
- "metadata": {},
395
- "outputs": [
396
- {
397
- "name": "stdout",
398
- "output_type": "stream",
399
- "text": [
400
- "\n",
401
- "\n",
402
- "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
403
- "Prompt after formatting:\n",
404
- "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
405
- "\n",
406
- "Human: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
407
- "Input: The color of a school bus.\n",
408
- "Output: yellow.\n",
409
- "Is the color of the sky\n",
410
- "AI: blue.\n",
411
- "Human: Color of a banana.\n",
412
- "AI: yellow.\n",
413
- "Human: When someone is sick they are this color.\n",
414
- "AI: green.\n",
415
- "Input: Color of anger.\n",
416
- "Output:\u001b[0m\n",
417
- "\n",
418
- "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
419
- "red.\n"
420
- ]
421
- }
422
- ],
423
- "source": [
424
- "output = chatgpt_chain.predict(human_input=\"Color of anger.\")\n",
425
- "print(output)"
426
- ]
427
- }
428
- ],
429
- "metadata": {
430
- "kernelspec": {
431
- "display_name": "bootleg",
432
- "language": "python",
433
- "name": "python3"
434
- },
435
- "language_info": {
436
- "codemirror_mode": {
437
- "name": "ipython",
438
- "version": 3
439
- },
440
- "file_extension": ".py",
441
- "mimetype": "text/x-python",
442
- "name": "python",
443
- "nbconvert_exporter": "python",
444
- "pygments_lexer": "ipython3",
445
- "version": "3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:36:06) \n[Clang 11.1.0 ]"
446
- },
447
- "vscode": {
448
- "interpreter": {
449
- "hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
450
- }
451
- }
452
- },
453
- "nbformat": 4,
454
- "nbformat_minor": 5
455
- }
duckdb-nsql/manifest/examples/manifest_async.py DELETED
@@ -1,27 +0,0 @@
-import asyncio
-import time
-
-from manifest import Manifest
-
-
-def main():
-
-    manifest = Manifest(
-        client_name="openaichat",
-    )
-
-    print("Running in serial")
-    prompts = [f"Tell me something interesting about {i}" for i in range(50)]
-    st = time.time()
-    for pmt in prompts:
-        _ = manifest.run(pmt)
-    print(f"For loop: {time.time() - st :.2f}")
-
-    print("Running with async")
-    st = time.time()
-    _ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30))
-    print(f"Async loop: {time.time() - st :.2f}")
-
-
-if __name__ == "__main__":
-    main()
duckdb-nsql/manifest/examples/manifest_azure.ipynb DELETED
@@ -1,149 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "%load_ext autoreload\n",
10
- "%autoreload 2"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": null,
16
- "metadata": {},
17
- "outputs": [],
18
- "source": [
19
- "AZURE_KEY = \"API_KEY::URL\"\n",
20
- "OPENAI_KEY = \"sk-XXX\""
21
- ]
22
- },
23
- {
24
- "attachments": {},
25
- "cell_type": "markdown",
26
- "metadata": {},
27
- "source": [
28
- "## Use Azure and OpenAI models"
29
- ]
30
- },
31
- {
32
- "cell_type": "code",
33
- "execution_count": null,
34
- "metadata": {},
35
- "outputs": [],
36
- "source": [
37
- "from manifest import Manifest\n",
38
- "from manifest.connections.client_pool import ClientConnection\n",
39
- "from pathlib import Path\n",
40
- "\n",
41
- "cache_path = Path(\"manifest.db\")\n",
42
- "if cache_path.exists():\n",
43
- " cache_path.unlink()\n",
44
- "\n",
45
- "\n",
46
- "azure = ClientConnection(\n",
47
- " client_name=\"azureopenai\",\n",
48
- " client_connection=AZURE_KEY,\n",
49
- " engine=\"text-davinci-003\",\n",
50
- ")\n",
51
- "\n",
52
- "manifest = Manifest(client_pool=[azure], \n",
53
- " cache_name=\"sqlite\",\n",
54
- " cache_connection=\"manifest.db\"\n",
55
- ")\n",
56
- "\n",
57
- "\n",
58
- "openai = ClientConnection(\n",
59
- " client_name=\"openai\",\n",
60
- " client_connection=OPENAI_KEY,\n",
61
- " engine=\"text-davinci-003\",\n",
62
- ")\n",
63
- "\n",
64
- "manifest_openai_nocache = Manifest(client_pool=[openai])\n",
65
- "\n",
66
- "manifest_openai = Manifest(client_pool=[openai], \n",
67
- " cache_name=\"sqlite\",\n",
68
- " cache_connection=\"manifest.db\"\n",
69
- ")"
70
- ]
71
- },
72
- {
73
- "cell_type": "code",
74
- "execution_count": null,
75
- "metadata": {},
76
- "outputs": [],
77
- "source": [
78
- "# Show caches are the same\n",
79
- "text = \"What is the meaning of life?\"\n",
80
- "res = manifest.run(text, max_tokens=100, temperature=0.7, return_response=True)\n",
81
- "print(res.get_response())\n",
82
- "print(res.is_cached())\n",
83
- "res2 = manifest_openai.run(text, max_tokens=100, temperature=0.7, return_response=True)\n",
84
- "print(res2.is_cached())\n",
85
- "\n",
86
- "assert res2.get_response() == res.get_response()"
87
- ]
88
- },
89
- {
90
- "cell_type": "code",
91
- "execution_count": null,
92
- "metadata": {},
93
- "outputs": [],
94
- "source": [
95
- "azure_chat = ClientConnection(\n",
96
- " client_name=\"azureopenaichat\",\n",
97
- " client_connection=AZURE_KEY,\n",
98
- " engine=\"gpt-3.5-turbo\",\n",
99
- ")\n",
100
- "\n",
101
- "manifest = Manifest(client_pool=[azure_chat])"
102
- ]
103
- },
104
- {
105
- "cell_type": "code",
106
- "execution_count": null,
107
- "metadata": {},
108
- "outputs": [],
109
- "source": [
110
- "print(manifest.run(\"What do you think is the best food?\", max_tokens=100))\n",
111
- "\n",
112
- "chat_dict = [\n",
113
- " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
114
- " {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
115
- " {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
116
- " {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
117
- "]\n",
118
- "print(manifest.run(chat_dict, max_tokens=100))"
119
- ]
120
- }
121
- ],
122
- "metadata": {
123
- "kernelspec": {
124
- "display_name": "manifest",
125
- "language": "python",
126
- "name": "python3"
127
- },
128
- "language_info": {
129
- "codemirror_mode": {
130
- "name": "ipython",
131
- "version": 3
132
- },
133
- "file_extension": ".py",
134
- "mimetype": "text/x-python",
135
- "name": "python",
136
- "nbconvert_exporter": "python",
137
- "pygments_lexer": "ipython3",
138
- "version": "3.10.4"
139
- },
140
- "orig_nbformat": 4,
141
- "vscode": {
142
- "interpreter": {
143
- "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
144
- }
145
- }
146
- },
147
- "nbformat": 4,
148
- "nbformat_minor": 2
149
- }
duckdb-nsql/manifest/examples/manifest_chatgpt.ipynb DELETED
@@ -1,101 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "%load_ext autoreload\n",
10
- "%autoreload 2"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": null,
16
- "metadata": {},
17
- "outputs": [],
18
- "source": [
19
- "OPENAI_KEY = \"sk-XXX\""
20
- ]
21
- },
22
- {
23
- "attachments": {},
24
- "cell_type": "markdown",
25
- "metadata": {},
26
- "source": [
27
- "## Use ChatOpenAI\n",
28
- "\n",
29
- "Set you `OPENAI_API_KEY` environment variable."
30
- ]
31
- },
32
- {
33
- "cell_type": "code",
34
- "execution_count": null,
35
- "metadata": {},
36
- "outputs": [],
37
- "source": [
38
- "from manifest import Manifest\n",
39
- "from manifest.connections.client_pool import ClientConnection\n",
40
- "\n",
41
- "openai_chat = ClientConnection(\n",
42
- " client_name=\"openaichat\",\n",
43
- " client_connection=OPENAI_KEY,\n",
44
- " engine=\"gpt-3.5-turbo\"\n",
45
- ")\n",
46
- "\n",
47
- "manifest = Manifest(client_pool=[openai_chat])"
48
- ]
49
- },
50
- {
51
- "cell_type": "code",
52
- "execution_count": null,
53
- "metadata": {},
54
- "outputs": [],
55
- "source": [
56
- "# Simple question\n",
57
- "chat_dict = [\n",
58
- " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
59
- " {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
60
- " {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
61
- " {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
62
- "]\n",
63
- "print(manifest.run(chat_dict, max_tokens=100))"
64
- ]
65
- },
66
- {
67
- "cell_type": "code",
68
- "execution_count": null,
69
- "metadata": {},
70
- "outputs": [],
71
- "source": []
72
- }
73
- ],
74
- "metadata": {
75
- "kernelspec": {
76
- "display_name": "manifest",
77
- "language": "python",
78
- "name": "python3"
79
- },
80
- "language_info": {
81
- "codemirror_mode": {
82
- "name": "ipython",
83
- "version": 3
84
- },
85
- "file_extension": ".py",
86
- "mimetype": "text/x-python",
87
- "name": "python",
88
- "nbconvert_exporter": "python",
89
- "pygments_lexer": "ipython3",
90
- "version": "3.10.4"
91
- },
92
- "orig_nbformat": 4,
93
- "vscode": {
94
- "interpreter": {
95
- "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
96
- }
97
- }
98
- },
99
- "nbformat": 4,
100
- "nbformat_minor": 2
101
- }
 
duckdb-nsql/manifest/examples/manifest_connection_pool.ipynb DELETED
@@ -1,208 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 2,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "%load_ext autoreload\n",
10
- "%autoreload 2"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": 1,
16
- "metadata": {},
17
- "outputs": [],
18
- "source": [
19
- "OPENAI_KEY1 = \"sk-XXX\"\n",
20
- "OPENAI_KEY2 = \"sk-XX\""
21
- ]
22
- },
23
- {
24
- "attachments": {},
25
- "cell_type": "markdown",
26
- "metadata": {},
27
- "source": [
28
- "## Use OpenAI\n",
29
- "\n",
30
- "Set your `OPENAI_API_KEY` environment variable."
31
- ]
32
- },
33
- {
34
- "cell_type": "code",
35
- "execution_count": 2,
36
- "metadata": {},
37
- "outputs": [],
38
- "source": [
39
- "from manifest import Manifest\n",
40
- "from manifest.connections.client_pool import ClientConnection\n",
41
- "\n",
42
- "openai_ada = ClientConnection(\n",
43
- " client_name=\"openai\",\n",
44
- " client_connection=OPENAI_KEY1,\n",
45
- " engine=\"text-ada-001\"\n",
46
- ")\n",
47
- "\n",
48
- "openai_curie = ClientConnection(\n",
49
- " client_name=\"openai\",\n",
50
- " client_connection=OPENAI_KEY2,\n",
51
- " engine=\"text-curie-001\"\n",
52
- ")\n",
53
- "\n",
54
- "manifest = Manifest(client_pool=[openai_ada, openai_curie], client_pool_schedule=\"round_robin\")"
55
- ]
56
- },
57
- {
58
- "cell_type": "code",
59
- "execution_count": 3,
60
- "metadata": {},
61
- "outputs": [
62
- {
63
- "name": "stdout",
64
- "output_type": "stream",
65
- "text": [
66
- "0\n",
67
- "I am a model.\n",
68
- "1\n",
69
- "I am a MacBook Pro with a retina\n"
70
- ]
71
- }
72
- ],
73
- "source": [
74
- "res = manifest.run(\"What model are you?\", temperature=0.0)\n",
75
- "print(manifest.client_pool.current_client_id)\n",
76
- "print(res)\n",
77
- "res = manifest.run(\"What model are you?\", temperature=0.0)\n",
78
- "print(manifest.client_pool.current_client_id)\n",
79
- "print(res)"
80
- ]
81
- },
82
- {
83
- "attachments": {},
84
- "cell_type": "markdown",
85
- "metadata": {},
86
- "source": [
87
- "## With Async"
88
- ]
89
- },
90
- {
91
- "cell_type": "code",
92
- "execution_count": 4,
93
- "metadata": {},
94
- "outputs": [],
95
- "source": [
96
- "import nest_asyncio\n",
97
- "# This is required for asyncio.run(...) to work in Jupyter notebooks.\n",
98
- "nest_asyncio.apply()"
99
- ]
100
- },
101
- {
102
- "cell_type": "code",
103
- "execution_count": 5,
104
- "metadata": {},
105
- "outputs": [],
106
- "source": [
107
- "from manifest import Manifest\n",
108
- "from manifest.connections.client_pool import ClientConnection\n",
109
- "\n",
110
- "openai_ada = ClientConnection(\n",
111
- " client_name=\"openai\",\n",
112
- " client_connection=OPENAI_KEY1,\n",
113
- " engine=\"text-ada-001\"\n",
114
- ")\n",
115
- "\n",
116
- "openai_babbage = ClientConnection(\n",
117
- " client_name=\"openai\",\n",
118
- " client_connection=OPENAI_KEY2,\n",
119
- " engine=\"text-babbage-001\"\n",
120
- ")\n",
121
- "\n",
122
- "openai_curie = ClientConnection(\n",
123
- " client_name=\"openai\",\n",
124
- " client_connection=OPENAI_KEY2,\n",
125
- " engine=\"text-curie-001\"\n",
126
- ")\n",
127
- "\n",
128
- "manifest = Manifest(client_pool=[openai_ada, openai_babbage, openai_curie], client_pool_schedule=\"round_robin\")\n",
129
- "manifest_single_client = Manifest(client_pool=[openai_babbage])"
130
- ]
131
- },
132
- {
133
- "cell_type": "code",
134
- "execution_count": 6,
135
- "metadata": {},
136
- "outputs": [
137
- {
138
- "name": "stdout",
139
- "output_type": "stream",
140
- "text": [
141
- "For loop: 128.68\n",
142
- "Running with async single client\n",
143
- "Running 1 tasks across all clients.\n",
144
- "Async loop: 4.02\n",
145
- "Running with async two clients but not chunking\n",
146
- "Running 1 tasks across all clients.\n",
147
- "Async loop: 3.92\n",
148
- "Running with async two clients and chunk size\n",
149
- "Running 20 tasks across all clients.\n",
150
- "Async loop: 1.44\n"
151
- ]
152
- }
153
- ],
154
- "source": [
155
- "import time\n",
156
- "import asyncio\n",
157
- "\n",
158
- "prompts = [f\"Tell me something interesting about {i}\" for i in range(400)]\n",
159
- "st = time.time()\n",
160
- "for pmt in prompts:\n",
161
- " _ = manifest_single_client.run(pmt, max_tokens=30)\n",
162
- "print(f\"For loop: {time.time() - st :.2f}\")\n",
163
- "\n",
164
- "print(\"Running with async single client\")\n",
165
- "st = time.time()\n",
166
- "_ = asyncio.run(manifest_single_client.arun_batch(prompts, max_tokens=30, chunk_size=-1))\n",
167
- "print(f\"Async loop: {time.time() - st :.2f}\")\n",
168
- "\n",
169
- "print(\"Running with async two clients but not chunking\")\n",
170
- "st = time.time()\n",
171
- "_ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30, chunk_size=-1))\n",
172
- "print(f\"Async loop: {time.time() - st :.2f}\")\n",
173
- "\n",
174
- "print(\"Running with async two clients and chunk size\")\n",
175
- "st = time.time()\n",
176
- "_ = asyncio.run(manifest.arun_batch(prompts, max_tokens=30, chunk_size=20))\n",
177
- "print(f\"Async loop: {time.time() - st :.2f}\")"
178
- ]
179
- }
180
- ],
181
- "metadata": {
182
- "kernelspec": {
183
- "display_name": "manifest",
184
- "language": "python",
185
- "name": "python3"
186
- },
187
- "language_info": {
188
- "codemirror_mode": {
189
- "name": "ipython",
190
- "version": 3
191
- },
192
- "file_extension": ".py",
193
- "mimetype": "text/x-python",
194
- "name": "python",
195
- "nbconvert_exporter": "python",
196
- "pygments_lexer": "ipython3",
197
- "version": "3.10.4"
198
- },
199
- "orig_nbformat": 4,
200
- "vscode": {
201
- "interpreter": {
202
- "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
203
- }
204
- }
205
- },
206
- "nbformat": 4,
207
- "nbformat_minor": 2
208
- }
 
duckdb-nsql/manifest/examples/manifest_diffusers.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
duckdb-nsql/manifest/examples/manifest_embedding.ipynb DELETED
@@ -1,156 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "%load_ext autoreload\n",
10
- "%autoreload 2"
11
- ]
12
- },
13
- {
14
- "attachments": {},
15
- "cell_type": "markdown",
16
- "metadata": {},
17
- "source": [
18
- "## Use OpenAI\n",
19
- "\n",
20
- "Set your `OPENAI_API_KEY` environment variable."
21
- ]
22
- },
23
- {
24
- "cell_type": "code",
25
- "execution_count": 2,
26
- "metadata": {},
27
- "outputs": [
28
- {
29
- "name": "stdout",
30
- "output_type": "stream",
31
- "text": [
32
- "{'model_name': 'openaiembedding', 'engine': 'text-embedding-ada-002'}\n"
33
- ]
34
- }
35
- ],
36
- "source": [
37
- "from manifest import Manifest\n",
38
- "\n",
39
- "manifest = Manifest(client_name=\"openaiembedding\")\n",
40
- "print(manifest.client_pool.get_next_client().get_model_params())"
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": 3,
46
- "metadata": {},
47
- "outputs": [
48
- {
49
- "name": "stdout",
50
- "output_type": "stream",
51
- "text": [
52
- "(1536,)\n"
53
- ]
54
- }
55
- ],
56
- "source": [
57
- "emb = manifest.run(\"Is this an embedding?\")\n",
58
- "print(emb.shape)"
59
- ]
60
- },
61
- {
62
- "attachments": {},
63
- "cell_type": "markdown",
64
- "metadata": {},
65
- "source": [
66
- "### Using Locally Hosted Huggingface LM\n",
67
- "\n",
68
- "Run\n",
69
- "```\n",
70
- "python3 manifest/api/app.py --model_type huggingface --model_name_or_path EleutherAI/gpt-neo-125M --device 0\n",
71
- "```\n",
72
- "or\n",
73
- "```\n",
74
- "python3 manifest/api/app.py --model_type sentence_transformers --model_name_or_path all-mpnet-base-v2 --device 0\n",
75
- "```\n",
76
- "\n",
77
- "in a separate `screen` or `tmux`. Make sure to note the port. You can change this with `export FLASK_PORT=<port>`."
78
- ]
79
- },
80
- {
81
- "cell_type": "code",
82
- "execution_count": 1,
83
- "metadata": {},
84
- "outputs": [
85
- {
86
- "name": "stdout",
87
- "output_type": "stream",
88
- "text": [
89
- "{'model_name': 'all-mpnet-base-v2', 'model_path': 'all-mpnet-base-v2', 'client_name': 'huggingfaceembedding'}\n"
90
- ]
91
- }
92
- ],
93
- "source": [
94
- "from manifest import Manifest\n",
95
- "\n",
96
- "# Local hosted GPT Neo 125M\n",
97
- "manifest = Manifest(\n",
98
- " client_name=\"huggingfaceembedding\",\n",
99
- " client_connection=\"http://127.0.0.1:6000\",\n",
100
- " cache_name=\"sqlite\",\n",
101
- " cache_connection=\"my_sqlite_manifest.sqlite\"\n",
102
- ")\n",
103
- "print(manifest.client_pool.get_next_client().get_model_params())"
104
- ]
105
- },
106
- {
107
- "cell_type": "code",
108
- "execution_count": 4,
109
- "metadata": {},
110
- "outputs": [
111
- {
112
- "name": "stdout",
113
- "output_type": "stream",
114
- "text": [
115
- "(768,)\n",
116
- "(768,) (768,)\n"
117
- ]
118
- }
119
- ],
120
- "source": [
121
- "emb = manifest.run(\"Is this an embedding?\")\n",
122
- "print(emb.shape)\n",
123
- "\n",
124
- "emb = manifest.run([\"Is this an embedding?\", \"Bananas!!!\"])\n",
125
- "print(emb[0].shape, emb[1].shape)"
126
- ]
127
- }
128
- ],
129
- "metadata": {
130
- "kernelspec": {
131
- "display_name": "manifest",
132
- "language": "python",
133
- "name": "python3"
134
- },
135
- "language_info": {
136
- "codemirror_mode": {
137
- "name": "ipython",
138
- "version": 3
139
- },
140
- "file_extension": ".py",
141
- "mimetype": "text/x-python",
142
- "name": "python",
143
- "nbconvert_exporter": "python",
144
- "pygments_lexer": "ipython3",
145
- "version": "3.10.4"
146
- },
147
- "orig_nbformat": 4,
148
- "vscode": {
149
- "interpreter": {
150
- "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
151
- }
152
- }
153
- },
154
- "nbformat": 4,
155
- "nbformat_minor": 2
156
- }
 
duckdb-nsql/manifest/examples/manifest_google.ipynb DELETED
@@ -1,117 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "%load_ext autoreload\n",
10
- "%autoreload 2"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": null,
16
- "metadata": {},
17
- "outputs": [],
18
- "source": [
19
- "GOOGLE_KEY = \"KEY::PROJECT_ID\""
20
- ]
21
- },
22
- {
23
- "attachments": {},
24
- "cell_type": "markdown",
25
- "metadata": {},
26
- "source": [
27
- "## Use GoogleVertexAPI"
28
- ]
29
- },
30
- {
31
- "cell_type": "code",
32
- "execution_count": null,
33
- "metadata": {},
34
- "outputs": [],
35
- "source": [
36
- "from manifest import Manifest\n",
37
- "from manifest.connections.client_pool import ClientConnection\n",
38
- "\n",
39
- "google_bison = ClientConnection(\n",
40
- " client_name=\"google\",\n",
41
- " client_connection=GOOGLE_KEY\n",
42
- ")\n",
43
- "\n",
44
- "manifest = Manifest(client_pool=[google_bison])"
45
- ]
46
- },
47
- {
48
- "cell_type": "code",
49
- "execution_count": null,
50
- "metadata": {},
51
- "outputs": [],
52
- "source": [
53
- "# Simple question\n",
54
- "print(manifest.run(\"What is your name\", max_tokens=40))"
55
- ]
56
- },
57
- {
58
- "cell_type": "code",
59
- "execution_count": null,
60
- "metadata": {},
61
- "outputs": [],
62
- "source": [
63
- "from manifest import Manifest\n",
64
- "from manifest.connections.client_pool import ClientConnection\n",
65
- "\n",
66
- "google_bison = ClientConnection(\n",
67
- " client_name=\"googlechat\",\n",
68
- " client_connection=GOOGLE_KEY\n",
69
- ")\n",
70
- "\n",
71
- "manifest = Manifest(client_pool=[google_bison])"
72
- ]
73
- },
74
- {
75
- "cell_type": "code",
76
- "execution_count": null,
77
- "metadata": {},
78
- "outputs": [],
79
- "source": [
80
- "chat_dict = [\n",
81
- " # {\"author\": \"bot\", \"content\": \"You are a helpful assistant.\"},\n",
82
- " {\"author\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
83
- " {\"author\": \"bot\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
84
- " {\"author\": \"user\", \"content\": \"Where was it played?\"}\n",
85
- "]\n",
86
- "print(manifest.run(chat_dict, max_tokens=8))"
87
- ]
88
- }
89
- ],
90
- "metadata": {
91
- "kernelspec": {
92
- "display_name": "manifest",
93
- "language": "python",
94
- "name": "python3"
95
- },
96
- "language_info": {
97
- "codemirror_mode": {
98
- "name": "ipython",
99
- "version": 3
100
- },
101
- "file_extension": ".py",
102
- "mimetype": "text/x-python",
103
- "name": "python",
104
- "nbconvert_exporter": "python",
105
- "pygments_lexer": "ipython3",
106
- "version": "3.10.4"
107
- },
108
- "orig_nbformat": 4,
109
- "vscode": {
110
- "interpreter": {
111
- "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
112
- }
113
- }
114
- },
115
- "nbformat": 4,
116
- "nbformat_minor": 2
117
- }
 
duckdb-nsql/manifest/examples/manifest_openrouter.ipynb DELETED
@@ -1,108 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "%load_ext autoreload\n",
10
- "%autoreload 2"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": 4,
16
- "metadata": {},
17
- "outputs": [],
18
- "source": [
19
- "OPENROUTER_API_KEY = \"sk-...\""
20
- ]
21
- },
22
- {
23
- "attachments": {},
24
- "cell_type": "markdown",
25
- "metadata": {},
26
- "source": [
27
- "## Use OpenRouter\n",
28
- "\n",
29
- "Set you `OPENROUTER_API_KEY` environment variable."
30
- ]
31
- },
32
- {
33
- "cell_type": "code",
34
- "execution_count": 5,
35
- "metadata": {},
36
- "outputs": [],
37
- "source": [
38
- "from manifest import Manifest\n",
39
- "from manifest.connections.client_pool import ClientConnection\n",
40
- "\n",
41
- "openai_chat = ClientConnection(\n",
42
- " client_name=\"openrouter\",\n",
43
- " client_connection=OPENROUTER_API_KEY,\n",
44
- " engine=\"meta-llama/codellama-70b-instruct\"\n",
45
- ")\n",
46
- "\n",
47
- "manifest = Manifest(client_pool=[openai_chat])"
48
- ]
49
- },
50
- {
51
- "cell_type": "code",
52
- "execution_count": 6,
53
- "metadata": {},
54
- "outputs": [
55
- {
56
- "name": "stdout",
57
- "output_type": "stream",
58
- "text": [
59
- "2020 World Series was played at the Globe Life Field in Arlington, Texas.\n"
60
- ]
61
- }
62
- ],
63
- "source": [
64
- "# Simple question\n",
65
- "chat_dict = [\n",
66
- " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
67
- " {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
68
- " {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
69
- " {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
70
- "]\n",
71
- "print(manifest.run(chat_dict, max_tokens=100))"
72
- ]
73
- },
74
- {
75
- "cell_type": "code",
76
- "execution_count": null,
77
- "metadata": {},
78
- "outputs": [],
79
- "source": []
80
- }
81
- ],
82
- "metadata": {
83
- "kernelspec": {
84
- "display_name": "Python 3 (ipykernel)",
85
- "language": "python",
86
- "name": "python3"
87
- },
88
- "language_info": {
89
- "codemirror_mode": {
90
- "name": "ipython",
91
- "version": 3
92
- },
93
- "file_extension": ".py",
94
- "mimetype": "text/x-python",
95
- "name": "python",
96
- "nbconvert_exporter": "python",
97
- "pygments_lexer": "ipython3",
98
- "version": "3.11.5"
99
- },
100
- "vscode": {
101
- "interpreter": {
102
- "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
103
- }
104
- }
105
- },
106
- "nbformat": 4,
107
- "nbformat_minor": 4
108
- }
 
duckdb-nsql/manifest/examples/manifest_streaming.ipynb DELETED
@@ -1,105 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "%load_ext autoreload\n",
10
- "%autoreload 2"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": null,
16
- "metadata": {},
17
- "outputs": [],
18
- "source": [
19
- "OPENAI_KEY = \"sk-XXX\""
20
- ]
21
- },
22
- {
23
- "attachments": {},
24
- "cell_type": "markdown",
25
- "metadata": {},
26
- "source": [
27
- "## Use ChatOpenAI\n",
28
- "\n",
29
- "Set your `OPENAI_API_KEY` environment variable."
30
- ]
31
- },
32
- {
33
- "cell_type": "code",
34
- "execution_count": null,
35
- "metadata": {},
36
- "outputs": [],
37
- "source": [
38
- "from manifest import Manifest\n",
39
- "from manifest.connections.client_pool import ClientConnection\n",
40
- "\n",
41
- "openai_chat = ClientConnection(\n",
42
- " client_name=\"openaichat\",\n",
43
- " client_connection=OPENAI_KEY,\n",
44
- " engine=\"gpt-3.5-turbo\"\n",
45
- ")\n",
46
- "\n",
47
- "manifest = Manifest(client_pool=[openai_chat])"
48
- ]
49
- },
50
- {
51
- "cell_type": "code",
52
- "execution_count": null,
53
- "metadata": {},
54
- "outputs": [],
55
- "source": [
56
- "manifest_iterator = manifest.run(\"Tell me a story about a fat cat.\\n\\nOnce upon a time\", max_tokens=200, stream=True)"
57
- ]
58
- },
59
- {
60
- "cell_type": "code",
61
- "execution_count": null,
62
- "metadata": {},
63
- "outputs": [],
64
- "source": [
65
- "import sys\n",
66
- "\n",
67
- "cur_line_length = 0\n",
68
- "# Iterate over stream\n",
69
- "for res in manifest_iterator:\n",
70
- " sys.stdout.write(res)\n",
71
- " cur_line_length += len(res)\n",
72
- " if cur_line_length > 80:\n",
73
- " sys.stdout.write(\"\\n\")\n",
74
- " cur_line_length = 0"
75
- ]
76
- }
77
- ],
78
- "metadata": {
79
- "kernelspec": {
80
- "display_name": "manifest",
81
- "language": "python",
82
- "name": "python3"
83
- },
84
- "language_info": {
85
- "codemirror_mode": {
86
- "name": "ipython",
87
- "version": 3
88
- },
89
- "file_extension": ".py",
90
- "mimetype": "text/x-python",
91
- "name": "python",
92
- "nbconvert_exporter": "python",
93
- "pygments_lexer": "ipython3",
94
- "version": "3.10.4"
95
- },
96
- "orig_nbformat": 4,
97
- "vscode": {
98
- "interpreter": {
99
- "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
100
- }
101
- }
102
- },
103
- "nbformat": 4,
104
- "nbformat_minor": 2
105
- }
 
duckdb-nsql/manifest/examples/manifest_together.ipynb DELETED
@@ -1,106 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {},
7
- "outputs": [
8
- {
9
- "name": "stdout",
10
- "output_type": "stream",
11
- "text": [
12
- "env: TOMA_URL=<TOMA_URL>\n"
13
- ]
14
- }
15
- ],
16
- "source": [
17
- "%load_ext autoreload\n",
18
- "%autoreload 2\n",
19
- "\n",
20
- "%env TOMA_URL=<TOMA_URL>"
21
- ]
22
- },
23
- {
24
- "cell_type": "code",
25
- "execution_count": null,
26
- "metadata": {},
27
- "outputs": [],
28
- "source": [
29
- "from manifest import Manifest\n",
30
- "\n",
31
- "# The responses are not fast\n",
32
- "manifest = Manifest(\n",
33
- " client_name=\"toma\",\n",
34
- ")\n",
35
- "\n",
36
- "print(manifest.run(\"What is the color of an apple?\"))"
37
- ]
38
- },
39
- {
40
- "attachments": {},
41
- "cell_type": "markdown",
42
- "metadata": {},
43
- "source": [
44
- "With a cache"
45
- ]
46
- },
47
- {
48
- "cell_type": "code",
49
- "execution_count": null,
50
- "metadata": {},
51
- "outputs": [],
52
- "source": [
53
- "from manifest import Manifest\n",
54
- "\n",
55
- "# The responses are not fast\n",
56
- "manifest = Manifest(\n",
57
- " client_name=\"toma\",\n",
58
- " cache_name=\"sqlite\",\n",
59
- " cache_connection=\"my_manifest_cache.sqlite\",\n",
60
- ")\n",
61
- "\n",
62
- "res = manifest.run(\"What is the color of an apple?\", return_response=True)\n",
63
- "print(res.get_response())\n",
64
- "print(\"Is Cached?\", res.is_cached())\n",
65
- "\n",
66
- "res = manifest.run(\"What is the color of an apple?\", return_response=True)\n",
67
- "print(res.get_response())\n",
68
- "print(\"Is Cached?\", res.is_cached())"
69
- ]
70
- },
71
- {
72
- "cell_type": "code",
73
- "execution_count": null,
74
- "metadata": {},
75
- "outputs": [],
76
- "source": []
77
- }
78
- ],
79
- "metadata": {
80
- "kernelspec": {
81
- "display_name": "manifest",
82
- "language": "python",
83
- "name": "python3"
84
- },
85
- "language_info": {
86
- "codemirror_mode": {
87
- "name": "ipython",
88
- "version": 3
89
- },
90
- "file_extension": ".py",
91
- "mimetype": "text/x-python",
92
- "name": "python",
93
- "nbconvert_exporter": "python",
94
- "pygments_lexer": "ipython3",
95
- "version": "3.10.4"
96
- },
97
- "orig_nbformat": 4,
98
- "vscode": {
99
- "interpreter": {
100
- "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
101
- }
102
- }
103
- },
104
- "nbformat": 4,
105
- "nbformat_minor": 2
106
- }
 
duckdb-nsql/manifest/manifest/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- """Manifest init."""
2
- from manifest.manifest import Manifest
3
- from manifest.request import Request
4
- from manifest.response import Response
5
-
6
- __all__ = ["Manifest", "Response", "Request"]
 
duckdb-nsql/manifest/manifest/api/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Api init."""
 
 
duckdb-nsql/manifest/manifest/api/app.py DELETED
@@ -1,301 +0,0 @@
1
- """Flask app."""
2
- import argparse
3
- import io
4
- import json
5
- import logging
6
- import os
7
- import socket
8
- from typing import Dict
9
-
10
- import pkg_resources
11
- from flask import Flask, Response, request
12
-
13
- from manifest.api.models.diffuser import DiffuserModel
14
- from manifest.api.models.huggingface import (
15
- MODEL_GENTYPE_REGISTRY,
16
- CrossModalEncoderModel,
17
- TextGenerationModel,
18
- )
19
- from manifest.api.models.sentence_transformer import SentenceTransformerModel
20
- from manifest.api.response import ModelResponse
21
-
22
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
-
24
- logger = logging.getLogger(__name__)
25
- app = Flask(__name__) # define app using Flask
26
- # Will be global
27
- model = None
28
- model_type = None
29
- PORT = int(os.environ.get("FLASK_PORT", 5000))
30
- MODEL_CONSTRUCTORS = {
31
- "huggingface": TextGenerationModel,
32
- "sentence_transformers": SentenceTransformerModel,
33
- "huggingface_crossmodal": CrossModalEncoderModel,
34
- "diffuser": DiffuserModel,
35
- }
36
-
37
-
38
- def parse_args() -> argparse.Namespace:
39
- """Generate args."""
40
- parser = argparse.ArgumentParser(description="Model args")
41
- parser.add_argument(
42
- "--model_type",
43
- default=None,
44
- type=str,
45
- required=True,
46
- help="Model type used for finding constructor.",
47
- choices=MODEL_CONSTRUCTORS.keys(),
48
- )
49
- parser.add_argument(
50
- "--model_generation_type",
51
- default=None,
52
- type=str,
53
- help="Model generation type.",
54
- choices=MODEL_GENTYPE_REGISTRY.keys(),
55
- )
56
- parser.add_argument(
57
- "--model_name_or_path",
58
- default=None,
59
- type=str,
60
- help="Name of model or path to model. Used in initialize of model class.",
61
- )
62
- parser.add_argument(
63
- "--cache_dir", default=None, type=str, help="Cache directory for models."
64
- )
65
- parser.add_argument(
66
- "--device", type=int, default=0, help="Model device. -1 for CPU."
67
- )
68
- parser.add_argument(
69
- "--fp16", action="store_true", help="Force use fp16 for model params."
70
- )
71
- parser.add_argument(
72
- "--percent_max_gpu_mem_reduction",
73
- type=float,
74
- default=0.85,
75
- help="Used with accelerate multigpu. Scales down max memory.",
76
- )
77
- parser.add_argument(
78
- "--use_bitsandbytes",
79
- action="store_true",
80
- help=("Use bits and bytes. " "This will override --device parameter."),
81
- )
82
- parser.add_argument(
83
- "--use_accelerate_multigpu",
84
- action="store_true",
85
- help=(
86
- "Use accelerate for multi gpu inference. "
87
- "This will override --device parameter."
88
- ),
89
- )
90
- parser.add_argument(
91
- "--use_hf_parallelize",
92
- action="store_true",
93
- help=(
94
- "Use HF parallelize for multi gpu inference. "
95
- "This will override --device parameter."
96
- ),
97
- )
98
- parser.add_argument(
99
- "--use_deepspeed",
100
- action="store_true",
101
- help=("Use deepspeed. This will override --device parameter."),
102
- )
103
- args = parser.parse_args()
104
- return args
105
-
106
-
107
- def is_port_in_use(port: int) -> bool:
108
- """Check if port is in use."""
109
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
110
- return s.connect_ex(("localhost", port)) == 0
111
-
112
-
113
- def main() -> None:
114
- """Run main."""
115
- kwargs = parse_args()
116
- if is_port_in_use(PORT):
117
- raise ValueError(f"Port {PORT} is already in use.")
118
- global model_type
119
- model_type = kwargs.model_type
120
- model_gen_type = kwargs.model_generation_type
121
- model_name_or_path = kwargs.model_name_or_path
122
- if not model_name_or_path:
123
- raise ValueError("Must provide model_name_or_path.")
124
- if kwargs.use_accelerate_multigpu:
125
- logger.info("Using accelerate. Overriding --device argument.")
126
- if (
127
- kwargs.percent_max_gpu_mem_reduction <= 0
128
- or kwargs.percent_max_gpu_mem_reduction > 1
129
- ):
130
- raise ValueError("percent_max_gpu_mem_reduction must be in (0, 1].")
131
- if (
132
- sum(
133
- [
134
- kwargs.use_accelerate_multigpu,
135
- kwargs.use_hf_parallelize,
136
- kwargs.use_bitsandbytes,
137
- kwargs.use_deepspeed,
138
- ]
139
- )
140
- > 1
141
- ):
142
- raise ValueError(
143
- "Only one of use_accelerate_multigpu, use_hf_parallelize, "
144
- "use_bitsandbytes, and use_deepspeed can be set."
145
- )
146
- # Global model
147
- global model
148
- model = MODEL_CONSTRUCTORS[model_type](
149
- model_name_or_path,
150
- model_type=model_gen_type,
151
- cache_dir=kwargs.cache_dir,
152
- device=kwargs.device,
153
- use_accelerate=kwargs.use_accelerate_multigpu,
154
- use_parallelize=kwargs.use_hf_parallelize,
155
- use_bitsandbytes=kwargs.use_bitsandbytes,
156
- use_deepspeed=kwargs.use_deepspeed,
157
- perc_max_gpu_mem_red=kwargs.percent_max_gpu_mem_reduction,
158
- use_fp16=kwargs.fp16,
159
- )
160
- app.run(host="0.0.0.0", port=PORT)
161
-
162
-
163
- @app.route("/completions", methods=["POST"])
164
- def completions() -> Response:
165
- """Get completions for generation."""
166
- prompt = request.json["prompt"]
167
- del request.json["prompt"]
168
- generation_args = request.json
169
-
170
- if not isinstance(prompt, (str, list)):
171
- raise ValueError("Prompt must be a str or list of str")
172
- try:
173
- result_gens = []
174
- for generations in model.generate(prompt, **generation_args):
175
- result_gens.append(generations)
176
- if model_type == "diffuser":
177
- # Assign None logprob as it's not supported in diffusers
178
- results = [
179
- {"array": r[0], "logprob": None, "tokens": None, "token_logprobs": None}
180
- for r in result_gens
181
- ]
182
- res_type = "image_generation"
183
- else:
184
- results = [
185
- {"text": r[0], "logprob": r[1], "tokens": r[2], "token_logprobs": r[3]}
186
- for r in result_gens
187
- ]
188
- res_type = "text_completion"
189
- # transform the result into the openai format
190
- return Response(
191
- json.dumps(ModelResponse(results, response_type=res_type).__dict__()),
192
- status=200,
193
- )
194
- except Exception as e:
195
- logger.error(e)
196
- return Response(
197
- json.dumps({"message": str(e)}),
198
- status=400,
199
- )
200
-
201
-
202
- @app.route("/embed", methods=["POST"])
203
- def embed() -> Response:
204
- """Get embed for generation."""
205
- if "modality" in request.json:
206
- modality = request.json["modality"]
207
- else:
208
- modality = "text"
209
- if modality == "text":
210
- prompts = request.json["prompt"]
211
- elif modality == "image":
212
- import base64
213
-
214
- from PIL import Image
215
-
216
- prompts = [
217
- Image.open(io.BytesIO(base64.b64decode(data)))
218
- for data in request.json["prompt"]
219
- ]
220
- else:
221
- raise ValueError("modality must be text or image")
222
-
223
- try:
224
- results = []
225
- embeddings = model.embed(prompts)
226
- for embedding in embeddings:
227
- results.append(
228
- {
229
- "array": embedding,
230
- "logprob": None,
231
- "tokens": None,
232
- "token_logprobs": None,
233
- }
234
- )
235
-
236
- return Response(
237
- json.dumps(
238
- ModelResponse(results, response_type="embedding_generation").__dict__()
239
- ),
240
- status=200,
241
- )
242
- except Exception as e:
243
- logger.error(e)
244
- return Response(
245
- json.dumps({"message": str(e)}),
246
- status=400,
247
- )
248
-
249
-
250
- @app.route("/score_sequence", methods=["POST"])
251
- def score_sequence() -> Response:
252
- """Get logprob of prompt."""
253
- prompt = request.json["prompt"]
254
- del request.json["prompt"]
255
- generation_args = request.json
256
-
257
- if not isinstance(prompt, (str, list)):
258
- raise ValueError("Prompt must be a str or list of str")
259
-
260
- try:
261
- score_list = model.score_sequence(prompt, **generation_args)
262
- results = [
263
- {
264
- "text": prompt if isinstance(prompt, str) else prompt[i],
265
- "logprob": r[0],
266
- "tokens": r[1],
267
- "token_logprobs": r[2],
268
- }
269
- for i, r in enumerate(score_list)
270
- ]
271
- # transform the result into the openai format
272
- return Response(
273
- json.dumps(
274
- ModelResponse(results, response_type="prompt_logit_score").__dict__()
275
- ),
276
- status=200,
277
- )
278
- except Exception as e:
279
- logger.error(e)
280
- return Response(
281
- json.dumps({"message": str(e)}),
282
- status=400,
283
- )
284
-
285
-
286
- @app.route("/params", methods=["POST"])
287
- def params() -> Dict:
288
- """Get model params."""
289
- return model.get_init_params()
290
-
291
-
292
- @app.route("/")
293
- def index() -> str:
294
- """Get index completion."""
295
- fn = pkg_resources.resource_filename("metaseq", "service/index.html")
296
- with open(fn) as f:
297
- return f.read()
298
-
299
-
300
- if __name__ == "__main__":
301
- main()
 
duckdb-nsql/manifest/manifest/api/models/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Models init."""
 
 
duckdb-nsql/manifest/manifest/api/models/diffuser.py DELETED
@@ -1,123 +0,0 @@
1
- """Diffuser model."""
2
- from pathlib import Path
3
- from typing import Any, Dict, List, Optional, Tuple, Union
4
-
5
- import numpy as np
6
- import torch
7
- from diffusers import StableDiffusionPipeline
8
-
9
- from manifest.api.models.model import Model
10
-
11
-
12
- class DiffuserModel(Model):
13
- """Diffuser model."""
14
-
15
- def __init__(
16
- self,
17
- model_name_or_path: str,
18
- model_type: Optional[str] = None,
19
- model_config: Optional[str] = None,
20
- cache_dir: Optional[str] = None,
21
- device: int = 0,
22
- use_accelerate: bool = False,
23
- use_parallelize: bool = False,
24
- use_bitsandbytes: bool = False,
25
- use_deepspeed: bool = False,
26
- perc_max_gpu_mem_red: float = 1.0,
27
- use_fp16: bool = False,
28
- ):
29
- """
30
- Initialize model.
31
-
32
- All arguments will be passed in the request from Manifest.
33
-
34
- Args:
35
- model_name_or_path: model name string.
36
- model_config: model config string.
37
- cache_dir: cache directory for model.
38
- device: device to use for model.
39
- use_accelerate: whether to use accelerate for multi-gpu inference.
40
- use_parallelize: use HF default parallelize
41
- use_bitsandbytes: use HF bits and bytes
42
- use_deepspeed: use deepspeed
43
- perc_max_gpu_mem_red: percent max memory reduction in accelerate
44
- use_fp16: use fp16 for model weights.
45
- """
46
- if use_accelerate or use_parallelize or use_bitsandbytes or use_deepspeed:
47
- raise ValueError(
48
- "Cannot use accelerate or parallelize or "
49
- "bitsandbytes or deepspeed with diffusers"
50
- )
51
- # Check if providing path
52
- self.model_path = model_name_or_path
53
- if Path(self.model_path).exists() and Path(self.model_path).is_dir():
54
- model_name_or_path = Path(self.model_path).name
55
- self.model_name = model_name_or_path
56
- print("Model Name:", self.model_name, "Model Path:", self.model_path)
57
- dtype = torch.float16 if use_fp16 else None
58
- torch_device = (
59
- torch.device("cpu")
60
- if (device == -1 or not torch.cuda.is_available())
61
- else torch.device(f"cuda:{device}")
62
- )
63
- self.pipeline = StableDiffusionPipeline.from_pretrained(
64
- self.model_path,
65
- torch_dtype=dtype,
66
- revision="fp16" if str(dtype) == "float16" else None,
67
- )
68
- self.pipeline.safety_checker = None
69
- self.pipeline.to(torch_device)
70
-
71
- def get_init_params(self) -> Dict:
72
- """Return init params to determine what model is being used."""
73
- return {"model_name": self.model_name, "model_path": self.model_path}
74
-
75
- @torch.no_grad()
76
- def generate(
77
- self, prompt: Union[str, List[str]], **kwargs: Any
78
- ) -> List[Tuple[Any, float, List[str], List[float]]]:
79
- """
80
- Generate the prompt from model.
81
-
82
- Outputs must be generated text and score, not including prompt.
83
-
84
- Args:
85
- prompt: prompt to generate from.
86
-
87
- Returns:
88
- list of generated text (list of length 1 for 1 generation).
89
- """
90
- # TODO: Is this correct for getting arguments in?
91
- if isinstance(prompt, str):
92
- prompt = [prompt]
93
- result = self.pipeline(prompt, output_type="np.array", **kwargs)
94
- # Return None for logprobs and token logprobs
95
- return [(im, None, None, None) for im in result["images"]]
96
-
97
- @torch.no_grad()
98
- def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
99
- """
100
- Embed the prompt from model.
101
-
102
- Args:
103
- prompt: prompt to embed from.
104
-
105
- Returns:
106
- list of embeddings (list of length 1 for 1 embedding).
107
- """
108
- raise NotImplementedError("Embed not supported for diffusers")
109
-
110
- @torch.no_grad()
111
- def score_sequence(
112
- self, prompt: Union[str, List[str]], **kwargs: Any
113
- ) -> List[Tuple[float, List[int], List[float]]]:
114
- """
115
- Score a sequence of choices.
116
-
117
- Args:
118
- prompt (:obj:`str` or :obj:`List[str]`):
119
- The prompt to score the choices against.
120
- **kwargs:
121
- Additional keyword arguments passed along to the :obj:`__call__` method.
122
- """
123
- raise NotImplementedError("Score sequence not supported for diffusers")
 
duckdb-nsql/manifest/manifest/api/models/huggingface.py DELETED
@@ -1,671 +0,0 @@
1
- """Huggingface model."""
2
- import json
3
- from pathlib import Path
4
- from typing import Any, Dict, List, Optional, Tuple, Union, cast
5
-
6
- import deepspeed
7
- import numpy as np
8
- import PIL
9
- import torch
10
- from accelerate import dispatch_model, infer_auto_device_map
11
- from accelerate.utils.modeling import get_max_memory as acc_get_max_memory
12
- from transformers import (
13
- AutoModelForCausalLM,
14
- AutoModelForSeq2SeqLM,
15
- AutoTokenizer,
16
- BloomForCausalLM,
17
- CLIPModel,
18
- CLIPProcessor,
19
- GPT2LMHeadModel,
20
- GPTJForCausalLM,
21
- GPTNeoForCausalLM,
22
- GPTNeoXForCausalLM,
23
- LlamaForCausalLM,
24
- LlamaTokenizer,
25
- OPTForCausalLM,
26
- PreTrainedModel,
27
- PreTrainedTokenizer,
28
- )
29
-
30
- from manifest.api.models.model import Model
31
-
32
- MODEL_REGISTRY = {
33
- "EleutherAI/gpt-neo-125M": GPTNeoForCausalLM,
34
- "EleutherAI/gpt-neo-1.3B": GPTNeoForCausalLM,
35
- "EleutherAI/gpt-neo-2.7B": GPTNeoForCausalLM,
36
- "EleutherAI/gpt-j-6B": GPTJForCausalLM,
37
- "EleutherAI/gpt-neox-20b": GPTNeoXForCausalLM,
38
- "facebook/opt-125m": OPTForCausalLM,
39
- "facebook/opt-350m": OPTForCausalLM,
40
- "Salesforce/codegen-2B-mono": AutoModelForCausalLM,
41
- "Salesforce/codegen-6B-mono": AutoModelForCausalLM,
42
- "facebook/opt-1.3b": OPTForCausalLM,
43
- "facebook/opt-2.7b": OPTForCausalLM,
44
- "facebook/opt-6.7b": OPTForCausalLM,
45
- "facebook/opt-13b": OPTForCausalLM,
46
- "facebook/opt-30b": OPTForCausalLM,
47
- "gpt2": GPT2LMHeadModel,
48
- "openai/clip-vit-base-patch32": CLIPModel,
49
- "bigscience/bloom-560m": BloomForCausalLM,
50
- "bigscience/bloom-1b7": BloomForCausalLM,
51
- "bigscience/bloom-3b": BloomForCausalLM,
52
- "bigscience/bloom-7b1": BloomForCausalLM,
53
- "chainyo/alpaca-lora-7b": LlamaForCausalLM,
54
- "bigscience/bloom": AutoModelForCausalLM,
55
- "bigscience/T0pp": AutoModelForSeq2SeqLM,
56
- "bigscience/T0_3B": AutoModelForSeq2SeqLM,
57
- "google/t5-small-lm-adapt": AutoModelForSeq2SeqLM, # 220M
58
- "google/t5-l-lm-adapt": AutoModelForSeq2SeqLM, # 800M
59
- "google/t5-xl-lm-adapt": AutoModelForSeq2SeqLM, # 3B
60
- "google/t5-xxl-lm-adapt": AutoModelForSeq2SeqLM, # 11B
61
- "google/t5-v1_1-l": AutoModelForSeq2SeqLM, # 800M
62
- "google/t5-v1_1-xl": AutoModelForSeq2SeqLM, # 3B
63
- "google/t5-v1_1-xxl": AutoModelForSeq2SeqLM, # 11B
64
- "google/flan-t5-l": AutoModelForSeq2SeqLM, # 800M
65
- "google/flan-t5-xl": AutoModelForSeq2SeqLM, # 3B
66
- "google/flan-t5-xxl": AutoModelForSeq2SeqLM, # 11B
67
- }
68
-
69
- MODEL_GENTYPE_REGISTRY = {
70
- "text-generation": AutoModelForCausalLM,
71
- "llama-text-generation": LlamaForCausalLM,
72
- "text2text-generation": AutoModelForSeq2SeqLM,
73
- }
74
-
75
-
76
- def get_max_memory(gpu_reduction: float) -> Dict[int, str]:
77
- """Get max memory in GB times reduction."""
78
- free_in_gb = int(torch.cuda.mem_get_info()[0] / 1024**3) # type: ignore
79
- max_mem = f"{int(gpu_reduction*free_in_gb)}GB"
80
-
81
- n_gpus = torch.cuda.device_count()
82
- max_mem_dict = {i: max_mem for i in range(n_gpus)}
83
- return max_mem_dict
84
-
85
-
86
- class GenerationPipeline:
87
- """
88
- Custom Pipeline.
89
-
90
- HF pipelines do not handle devices well in a multi-gpu setting.
91
- Create our own generation pipeline.
92
- """
93
-
94
- def __init__(
95
- self,
96
- model: Union[PreTrainedModel, deepspeed.InferenceEngine],
97
- tokenizer: PreTrainedTokenizer,
98
- device: int = None,
99
- bitsandbytes: bool = False,
100
- is_encdec: bool = False,
101
- ):
102
- """Initialize."""
103
- # Use to turn off sampling
104
- # https://github.com/TimDettmers/bitsandbytes/issues/42
105
- self.bitsandbytes = bitsandbytes
106
- self.model = model
107
- self.is_encdec = is_encdec
108
- config = model.config # type: ignore
109
- # Used for GPT
110
- self.max_length = getattr(config, "max_position_embeddings", None)
111
- if self.max_length is None:
112
- # Used for Bloom
113
- self.max_length = getattr(config, "seq_length", None)
114
- if self.max_length is None:
115
- # Used for T0
116
- self.max_length = getattr(config, "d_model", None)
117
- if self.max_length is None:
118
- # Default
119
- self.max_length = 2048
120
-
121
- print(f"Using max_length: {self.max_length}")
122
-
123
- self.tokenizer = tokenizer
124
- # self.device = device
125
- # With bits and bytes, do not want to place inputs on any device
126
- # if self.device:
127
- self.device = (
128
- torch.device("cpu")
129
- if (device == -1 or not torch.cuda.is_available())
130
- else torch.device(f"cuda:{device}")
131
- )
132
-
133
- def __call__(
134
- self, text: Union[str, List[str]], **kwargs: Any
135
- ) -> List[Dict[str, Union[str, List[float], List[str]]]]:
136
- """Generate from text.
137
-
138
- Args:
139
- text: text to generate.
140
-
141
- Returns:
142
- generated text.
143
- """
144
- # If text is longer than max model length, we reduce max input length to ensure
145
- # the user-indicated number of generation tokens is preserved.
146
- max_input_len = (
147
- self.max_length - kwargs.get("max_new_tokens")
148
- if not self.is_encdec
149
- else self.max_length
150
- )
151
- encoded_prompt = self.tokenizer(
152
- text,
153
- max_length=max_input_len,
154
- truncation=True,
155
- padding=True,
156
- return_tensors="pt",
157
- )
158
- encoded_prompt = encoded_prompt.to(self.device)
159
- kwargs_to_pass = dict(
160
- temperature=kwargs.get("temperature"),
161
- top_k=kwargs.get("top_k"),
162
- top_p=kwargs.get("top_p"),
163
- repetition_penalty=kwargs.get("repetition_penalty"),
164
- num_return_sequences=kwargs.get("num_return_sequences"),
165
- do_sample=kwargs.get("do_sample"),
166
- )
167
- kwargs_to_pass = {k: v for k, v in kwargs_to_pass.items() if v is not None}
168
- output_dict = self.model.generate( # type: ignore
169
- **encoded_prompt,
170
- **kwargs_to_pass,
171
- max_new_tokens=kwargs.get("max_new_tokens"),
172
- eos_token_id=self.tokenizer.eos_token_id,
173
- pad_token_id=self.tokenizer.pad_token_id,
174
- output_scores=True,
175
- return_dict_in_generate=True,
176
- )
177
- # logits/scores from the output always correspond to the generated tokens.
178
- # shape (num_tokens, num_return_sequences, vocab_size)
179
- logits = torch.stack(output_dict.scores)
180
- logits = torch.nn.functional.log_softmax(logits, dim=-1)
181
- num_generated_tokens = logits.shape[0]
182
- generated_sequences = [
183
- {
184
- "generated_text": self.tokenizer.decode(
185
- output_seq[-num_generated_tokens:], skip_special_tokens=True
186
- ),
187
- "logprobs": logits[
188
- range(num_generated_tokens), i, output_seq[-num_generated_tokens:]
189
- ].tolist(),
190
- "tokens": self.tokenizer.convert_ids_to_tokens(
191
- output_seq[-num_generated_tokens:].tolist()
192
- ),
193
- }
194
- for i, output_seq in enumerate(output_dict.sequences)
195
- ]
196
- return generated_sequences
197
-
198
-
199
- class HuggingFaceModel(Model):
200
- """HuggingFace Model."""
201
-
202
- def __init__(
203
- self,
204
- model_name_or_path: str,
205
- model_type: Optional[str] = None,
206
- model_config: Optional[str] = None,
207
- cache_dir: Optional[str] = None,
208
- device: int = 0,
209
- use_accelerate: bool = False,
210
- use_parallelize: bool = False,
211
- use_bitsandbytes: bool = False,
212
- use_deepspeed: bool = False,
213
- perc_max_gpu_mem_red: float = 1.0,
214
- use_fp16: bool = False,
215
- ):
216
- """
217
- Initialize model.
218
-
219
- All arguments will be passed in the request from Manifest.
220
-
221
- Args:
222
- model_name_or_path: model name string.
223
- model_config: model config string.
224
- cache_dir: cache directory for model.
225
- device: device to use for model.
226
- use_accelerate: whether to use accelerate for multi-gpu inference.
227
- use_parallelize: use HF default parallelize
228
- use_bitsandbytes: use HF bits and bytes
229
- use_deepspeed: use deepspeed
230
- perc_max_gpu_mem_red: percent max memory reduction in accelerate
231
- use_fp16: use fp16 for model weights.
232
- """
233
- if sum([use_accelerate, use_parallelize, use_bitsandbytes, use_deepspeed]) > 1:
234
- raise ValueError(
235
- "Only one of use_accelerate, use_parallelize, "
236
- "use_bitsandbytes, use_deepspeed can be set to True"
237
- )
238
- # Check if providing path
239
- self.model_path = model_name_or_path
240
- if Path(self.model_path).exists() and Path(self.model_path).is_dir():
241
- # Try to find config
242
- if (Path(self.model_path) / "config.json").exists():
243
- config = json.load(open(Path(self.model_path) / "config.json"))
244
- model_name_or_path = config["_name_or_path"]
245
- self.model_name = model_name_or_path
246
- self.model_type = model_type
247
- if self.model_name not in MODEL_REGISTRY and self.model_type is None:
248
- raise ValueError(
249
- f"{self.model_name} is not in our registry. Please specify "
250
- "--model_generation_type as either text-generation (for Causal)"
251
- " or text2text-generation (for Seq2Seq)"
252
- )
253
- print("Model Name:", self.model_name, "Model Path:", self.model_path)
254
-
255
- def get_init_params(self) -> Dict:
256
- """Return init params to determine what model is being used."""
257
- return {"model_name": self.model_name, "model_path": self.model_path}
258
-
259
- def _dispatch_deepspeed_model(
260
- self, model: PreTrainedModel
261
- ) -> deepspeed.InferenceEngine:
262
- """
263
- Load model with deepspeed.
264
-
265
- Adapted from https://www.deepspeed.ai/tutorials/inference-tutorial/
266
-
267
- Args:
268
- model: loaded hugging face model
269
- """
270
- model = deepspeed.init_inference(
271
- model=model,
272
- mp_size=1,
273
- dtype=model.dtype,
274
- replace_method="auto",
275
- replace_with_kernel_inject=True,
276
- )
277
- return model
278
-
279
- def _dispatch_accelerate_model(
280
- self, model: PreTrainedModel, perc_max_gpu_mem_red: float
281
- ) -> None:
282
- """
283
- Load model with accelerate.
284
-
285
- Adapted from https://colab.research.google.com/drive/14wnxMvD9zsiBQo2FtT
286
- pxn6w2cpXCcb-7#scrollTo=y8Ne7jJdaF9F&uniqifier=1
287
-
288
- Args:
289
- model: loaded hugging face model
290
- perc_max_gpu_mem_red: percent memory reduction
291
- """
292
- model.tie_weights() # type: ignore
293
- # Get the model where we can infer devices from
294
- if hasattr(model, "model"):
295
- # OPT
296
- main_model = model.model # type: ignore
297
- model_getter = "model."
298
- else:
299
- # Eleuther Neo and J
300
- main_model = model
301
- model_getter = ""
302
- # Decrease max mem
303
- max_memory = {
304
- k: int(perc_max_gpu_mem_red * v) for k, v in acc_get_max_memory().items()
305
- }
306
- raw_device_map = infer_auto_device_map(
307
- main_model,
308
- max_memory=max_memory,
309
- no_split_module_classes=[
310
- "OPTDecoderLayer",
311
- "GPTNeoBlock",
312
- "GPTJBlock",
313
- "GPTNeoXLayer",
314
- "T5Block",
315
- ],
316
- dtype=model.dtype, # type: ignore
317
- )
318
- # Hacky fix for Eleuther getting the "weight" of embeddings
319
- device_map = {}
320
- for k, v in raw_device_map.items():
321
- if k in {"wte", "wpe"}:
322
- device_map[f"{model_getter}{k}.weight"] = v
323
- else:
324
- device_map[f"{model_getter}{k}"] = v
325
- # For OPT models
326
- if "lm_head" not in device_map:
327
- try:
328
- device_map["lm_head"] = max(device_map.values())
329
- except TypeError:
330
- device_map["lm_head"] = "cpu"
331
- print("Device Map", device_map)
332
- dispatch_model(model, device_map=device_map)
333
- return
334
-
335
-
336
- class CrossModalEncoderModel(HuggingFaceModel):
337
- """CrossModalEncoderModel."""
338
-
339
- def __init__(
340
- self,
341
- model_name_or_path: str,
342
- model_type: Optional[str] = None,
343
- model_config: Optional[str] = None,
344
- cache_dir: Optional[str] = None,
345
- device: int = 0,
346
- use_accelerate: bool = False,
347
- use_parallelize: bool = False,
348
- use_bitsandbytes: bool = False,
349
- use_deepspeed: bool = False,
350
- perc_max_gpu_mem_red: float = 1.0,
351
- use_fp16: bool = False,
352
- ):
353
- """
354
- Initialize model.
355
-
356
- All arguments will be passed in the request from Manifest.
357
-
358
- Args:
359
- model_name_or_path: model name string.
360
- model_config: model config string.
361
- cache_dir: cache directory for model.
362
- device: device to use for model.
363
- use_accelerate: whether to use accelerate for multi-gpu inference.
364
- use_parallelize: use HF default parallelize
365
- use_bitsandbytes: use HF bits and bytes
366
- use_deepspeed: use deepspeed
367
- perc_max_gpu_mem_red: percent max memory reduction in accelerate
368
- use_fp16: use fp16 for model weights.
369
- """
370
- super().__init__(
371
- model_name_or_path,
372
- model_type,
373
- model_config,
374
- cache_dir,
375
- device,
376
- use_accelerate,
377
- use_parallelize,
378
- use_bitsandbytes,
379
- use_deepspeed,
380
- perc_max_gpu_mem_red,
381
- use_fp16,
382
- )
383
-
384
- # TODO: make this generalizable
385
- self.processor = CLIPProcessor.from_pretrained(self.model_path)
386
-
387
- model = MODEL_REGISTRY.get(
388
- self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
389
- ).from_pretrained(
390
- self.model_path,
391
- cache_dir=cache_dir,
392
- trust_remote_code=True,
393
- )
394
- model.eval()
395
-
396
- torch_device = (
397
- torch.device("cpu")
398
- if (device == -1 or not torch.cuda.is_available())
399
- else torch.device(f"cuda:{device}")
400
- )
401
- self.model = model.to(torch_device) # type: ignore
402
-
403
- @torch.no_grad()
404
- def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
405
- """
406
- Compute embedding for prompts.
407
-
408
- Args:
409
- prompt: prompt to embed from.
410
-
411
- Returns:
412
- embedding
413
- """
414
- if isinstance(prompt, str):
415
- inputs = self.processor(text=prompt, return_tensors="pt", padding=True)
416
- elif isinstance(prompt, PIL.Image.Image):
417
- inputs = self.processor(images=prompt, return_tensors="pt", padding=True)
418
- else:
419
- raise ValueError("Prompt must be a string or an image")
420
-
421
- outputs = self.model(**inputs)
422
- return outputs
423
-
424
-
425
- class TextGenerationModel(HuggingFaceModel):
426
- """Huggingface model."""
427
-
428
- def __init__(
429
- self,
430
- model_name_or_path: str,
431
- model_type: Optional[str] = None,
432
- model_config: Optional[str] = None,
433
- cache_dir: Optional[str] = None,
434
- device: int = 0,
435
- use_accelerate: bool = False,
436
- use_parallelize: bool = False,
437
- use_bitsandbytes: bool = False,
438
- use_deepspeed: bool = False,
439
- perc_max_gpu_mem_red: float = 1.0,
440
- use_fp16: bool = False,
441
- ):
442
- """
443
- Initialize model.
444
-
445
- All arguments will be passed in the request from Manifest.
446
-
447
- Args:
448
- model_name_or_path: model name string.
449
- model_config: model config string.
450
- cache_dir: cache directory for model.
451
- device: device to use for model.
452
- use_accelerate: whether to use accelerate for multi-gpu inference.
453
- use_parallelize: use HF default parallelize
454
- use_bitsandbytes: use HF bits and bytes
455
- use_deepspeed: use deepspeed
456
- perc_max_gpu_mem_red: percent max memory reduction in accelerate
457
- use_fp16: use fp16 for model weights.
458
- """
459
- super().__init__(
460
- model_name_or_path,
461
- model_type,
462
- model_config,
463
- cache_dir,
464
- device,
465
- use_accelerate,
466
- use_parallelize,
467
- use_bitsandbytes,
468
- use_deepspeed,
469
- perc_max_gpu_mem_red,
470
- use_fp16,
471
- )
472
- if (
473
- MODEL_REGISTRY.get(
474
- self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
475
- )
476
- == LlamaForCausalLM
477
- ):
478
- tokenizer = LlamaTokenizer.from_pretrained(self.model_name)
479
- else:
480
- try:
481
- tokenizer = AutoTokenizer.from_pretrained(
482
- self.model_name, truncation_side="left", padding_side="left"
483
- )
484
- except ValueError:
485
- tokenizer = AutoTokenizer.from_pretrained(
486
- self.model_name,
487
- truncation_side="left",
488
- padding_side="left",
489
- use_fast=False,
490
- )
491
- dtype = torch.float16 if use_fp16 else "auto"
492
- if use_bitsandbytes:
493
- print("WARNING!!! Cannot use sampling with bitsandbytes.")
494
- max_memory = get_max_memory(perc_max_gpu_mem_red)
495
- model = MODEL_REGISTRY.get(
496
- self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
497
- ).from_pretrained( # type: ignore
498
- self.model_path,
499
- cache_dir=cache_dir,
500
- load_in_8bit=True,
501
- device_map="auto",
502
- max_memory=max_memory,
503
- trust_remote_code=True,
504
- )
505
- else:
506
- try:
507
- # Try to explicitly find an fp16 copy (gpt-j-6B for example)
508
- model = MODEL_REGISTRY.get(
509
- self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
510
- ).from_pretrained( # type: ignore
511
- self.model_path,
512
- cache_dir=cache_dir,
513
- revision="float16",
514
- torch_dtype=torch.float16,
515
- trust_remote_code=True,
516
- )
517
- except Exception:
518
- model = MODEL_REGISTRY.get(
519
- self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
520
- ).from_pretrained( # type: ignore
521
- self.model_path,
522
- cache_dir=cache_dir,
523
- torch_dtype=dtype,
524
- trust_remote_code=True,
525
- )
526
- model.eval()
527
- print(f"Loaded Model DType {model.dtype}")
528
- self.is_encdec = model.config.is_encoder_decoder
529
- if not self.is_encdec:
530
- tokenizer.pad_token = tokenizer.eos_token
531
- tokenizer.pad_token_id = tokenizer.eos_token_id
532
- if not use_bitsandbytes:
533
- if use_accelerate:
534
- self._dispatch_accelerate_model(model, perc_max_gpu_mem_red)
535
- device = 0
536
- elif use_parallelize:
537
- model.parallelize()
538
- device = 0
539
- elif use_deepspeed:
540
- self._dispatch_deepspeed_model(model)
541
- device = 0
542
- else:
543
- if device > -1:
544
- torch_device = (
545
- torch.device("cpu")
546
- if (device == -1 or not torch.cuda.is_available())
547
- else torch.device(f"cuda:{device}")
548
- )
549
- model = model.to(torch_device) # type: ignore
550
- self.pipeline = GenerationPipeline( # type: ignore
551
- model=model,
552
- tokenizer=tokenizer,
553
- device=device,
554
- bitsandbytes=use_bitsandbytes,
555
- is_encdec=self.is_encdec,
556
- )
557
-
558
- @torch.no_grad()
559
- def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
560
- """
561
- Embed the prompt from model.
562
-
563
- Args:
564
- prompt: prompt to embed from.
565
-
566
- Returns:
567
- list of embeddings (list of length 1 for 1 embedding).
568
- """
569
- if isinstance(prompt, str):
570
- prompt = [prompt]
571
- encoded_prompt = self.pipeline.tokenizer(
572
- prompt,
573
- max_length=self.pipeline.max_length,
574
- truncation=True,
575
- padding=True,
576
- return_tensors="pt",
577
- )
578
- encoded_prompt = encoded_prompt.to(self.pipeline.device)
579
- # Get last hidden state
580
- output = self.pipeline.model( # type: ignore
581
- **encoded_prompt,
582
- output_hidden_states=True,
583
- return_dict=True,
584
- )
585
- last_hidden_state = output["hidden_states"][-1][:, -1, :]
586
- return last_hidden_state.cpu().numpy()
587
-
588
- @torch.no_grad()
589
- def generate(
590
- self, prompt: Union[str, List[str]], **kwargs: Any
591
- ) -> List[Tuple[Any, float, List[str], List[float]]]:
592
- """
593
- Generate the prompt from model.
594
-
595
- Outputs must be generated text and score, not including prompt.
596
-
597
- Args:
598
- prompt: prompt to generate from.
599
-
600
- Returns:
601
- list of generated text (list of length 1 for 1 generation).
602
- """
603
- num_return = kwargs.get("n", 1)
604
- if isinstance(prompt, list) and num_return > 1:
605
- raise ValueError("In batch generate, n must be 1.")
606
- result = self.pipeline(
607
- prompt,
608
- max_new_tokens=kwargs.get("max_tokens"),
609
- temperature=kwargs.get("temperature"),
610
- repetition_penalty=kwargs.get("repetition_penalty"),
611
- top_k=kwargs.get("top_k"),
612
- top_p=kwargs.get("top_p"),
613
- do_sample=kwargs.get("do_sample"),
614
- num_return_sequences=num_return,
615
- )
616
- final_results = [
617
- (
618
- cast(str, r["generated_text"]),
619
- sum(cast(List[float], r["logprobs"])),
620
- cast(List[str], r["tokens"]),
621
- cast(List[float], r["logprobs"]),
622
- )
623
- for r in result
624
- ]
625
- return final_results
626
-
627
- @torch.no_grad()
628
- def score_sequence(
629
- self, prompt: Union[str, List[str]], **kwargs: Any
630
- ) -> List[Tuple[float, List[int], List[float]]]:
631
- """
632
- Score a sequence of choices.
633
-
634
- Args:
635
- prompt (:obj:`str` or :obj:`List[str]`):
636
- The prompt to score the choices against.
637
- **kwargs:
638
- Additional keyword arguments passed along to the :obj:`__call__` method.
639
- """
640
- if isinstance(prompt, str):
641
- prompt = [prompt]
642
- encoded_prompt = self.pipeline.tokenizer(
643
- prompt,
644
- max_length=self.pipeline.max_length,
645
- truncation=True,
646
- padding=True,
647
- return_tensors="pt",
648
- )
649
- encoded_prompt["labels"] = encoded_prompt["input_ids"].clone()
650
- encoded_prompt = encoded_prompt.to(self.pipeline.device)
651
- logits = self.pipeline.model( # type: ignore
652
- **encoded_prompt,
653
- ).logits
654
- # For causal decoders, shift logits and labels
655
- labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)
656
- masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
657
- logits.float(), dim=-1
658
- )
659
- seq_token_log_probs = torch.gather(
660
- masked_log_probs, -1, encoded_prompt["labels"].unsqueeze(-1)
661
- )
662
- seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
663
- seq_log_prob = seq_token_log_probs.sum(dim=-1)
664
- return [
665
- (seq, tokens, seq_token)
666
- for seq, tokens, seq_token in zip(
667
- seq_log_prob.tolist(),
668
- encoded_prompt["input_ids"].tolist(),
669
- seq_token_log_probs.tolist(),
670
- )
671
- ]
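For reference, a minimal standalone sketch of the log-prob gathering that score_sequence above performs. The model name ("gpt2") and the prompts are illustrative assumptions, not taken from the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # same padding trick as in the loading code above
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

enc = tokenizer(["SELECT 1;", "SELECT count(*) FROM t;"], return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(**enc).logits  # (batch, seq_len, vocab_size)

# Mask padding, take log-softmax over the vocabulary, then gather the log-prob
# of each observed token, mirroring score_sequence above.
mask = enc["attention_mask"].unsqueeze(-1).float()
log_probs = mask * torch.log_softmax(logits.float(), dim=-1)
token_log_probs = torch.gather(log_probs, -1, enc["input_ids"].unsqueeze(-1)).squeeze(-1)
print(token_log_probs.sum(dim=-1).tolist())  # one total log-prob per prompt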
duckdb-nsql/manifest/manifest/api/models/model.py DELETED
@@ -1,91 +0,0 @@
1
- """Model class."""
2
- from typing import Any, Dict, List, Tuple, Union
3
-
4
- import numpy as np
5
-
6
-
7
- class Model:
8
- """Model class."""
9
-
10
- def __init__(
11
- self,
12
- model_name_or_path: str,
13
- model_type: str,
14
- cache_dir: str,
15
- device: int,
16
- use_accelerate: bool,
17
- use_parallelize: bool,
18
- use_bitsandbytes: bool,
19
- use_deepspeed: bool,
20
- perc_max_gpu_mem_red: float,
21
- use_fp16: bool,
22
- ):
23
- """
24
- Initialize model.
25
-
26
- All arguments will be passed in the request from Manifest.
27
-
28
- Args:
29
- model_name_or_path: model name string.
30
- model_type: model type string for when model_name not in registry.
31
- cache_dir: cache directory for model.
32
- device: device to use for model.
33
- use_accelerate: whether to use accelerate for multi-gpu inference.
34
- use_parallelize: use HF default parallelize
35
- use_bitsandbytes: use HF bits and bytes
36
- use_deepspeed: use deepspeed
37
- perc_max_gpu_mem_red: percent max memory reduction in accelerate
38
- use_fp16: use fp16 for model weights.
39
- """
40
- raise NotImplementedError()
41
-
42
- def get_init_params(self) -> Dict:
43
- """Return init params to determine what model is being used."""
44
- raise NotImplementedError()
45
-
46
- def generate(
47
- self, prompt: Union[str, List[str]], **kwargs: Any
48
- ) -> List[Tuple[Any, float, List[str], List[float]]]:
49
- """
50
- Generate the prompt from model.
51
-
52
- Outputs must be generated text and score, not including prompt.
53
-
54
- Args:
55
- prompt: prompt to generate from.
56
-
57
- Returns:
58
- list of generated text (list of length 1 for 1 generation).
59
- Each item is the response, answer logprob, list of tokens,
60
- and list of logprobs for each token.
61
- """
62
- raise NotImplementedError()
63
-
64
- def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
65
- """
66
- Embed the prompt from model.
67
-
68
- Args:
69
- prompt: prompt to embed from.
70
-
71
- Returns:
72
- list of embeddings (list of length 1 for 1 embedding).
73
- """
74
- raise NotImplementedError()
75
-
76
- def score_sequence(
77
- self, prompt: Union[str, List[str]], **kwargs: Any
78
- ) -> List[Tuple[float, List[int], List[float]]]:
79
- """
80
- Score a sequence of choices.
81
-
82
- Args:
83
- prompt (:obj:`str` or :obj:`List[str]`):
84
- The prompt to score the choices against.
85
- **kwargs:
86
- Additional keyword arguments passed along to the :obj:`__call__` method.
87
-
88
- Returns:
89
- Tuple of total score, tokens, and probs per token.
90
- """
91
- raise NotImplementedError()
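A hedged sketch of how this abstract interface is meant to be subclassed; the EchoModel below is purely illustrative and not part of the repository.

from typing import Any, Dict, List, Tuple, Union

from manifest.api.models.model import Model


class EchoModel(Model):
    """Toy model that echoes the prompt back as its generation."""

    def __init__(self, **kwargs: Any):  # the full signature is documented above
        self.params = {"model_name": "echo", "model_path": "echo"}

    def get_init_params(self) -> Dict:
        return self.params

    def generate(
        self, prompt: Union[str, List[str]], **kwargs: Any
    ) -> List[Tuple[Any, float, List[str], List[float]]]:
        prompts = [prompt] if isinstance(prompt, str) else prompt
        # (generated text, total logprob, tokens, per-token logprobs) per prompt
        return [(p, 0.0, p.split(), [0.0] * len(p.split())) for p in prompts]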
duckdb-nsql/manifest/manifest/api/models/sentence_transformer.py DELETED
@@ -1,113 +0,0 @@
1
- """Sentence transformer model."""
2
- from typing import Any, Dict, List, Optional, Tuple, Union
3
-
4
- import numpy as np
5
- import torch
6
- from sentence_transformers import SentenceTransformer
7
-
8
- from manifest.api.models.model import Model
9
-
10
-
11
- class SentenceTransformerModel(Model):
12
- """SentenceTransformer model."""
13
-
14
- def __init__(
15
- self,
16
- model_name_or_path: str,
17
- model_type: Optional[str] = None,
18
- model_config: Optional[str] = None,
19
- cache_dir: Optional[str] = None,
20
- device: int = 0,
21
- use_accelerate: bool = False,
22
- use_parallelize: bool = False,
23
- use_bitsandbytes: bool = False,
24
- use_deepspeed: bool = False,
25
- perc_max_gpu_mem_red: float = 1.0,
26
- use_fp16: bool = False,
27
- ):
28
- """
29
- Initialize model.
30
-
31
- All arguments will be passed in the request from Manifest.
32
-
33
- Args:
34
- model_name_or_path: model name string.
35
- model_config: model config string.
36
- cache_dir: cache directory for model.
37
- device: device to use for model.
38
- use_accelerate: whether to use accelerate for multi-gpu inference.
39
- use_parallelize: use HF default parallelize
40
- use_bitsandbytes: use HF bits and bytes
41
- use_deepspeed: use deepspeed
42
- perc_max_gpu_mem_red: percent max memory reduction in accelerate
43
- use_fp16: use fp16 for model weights.
44
- """
45
- if use_accelerate or use_parallelize or use_bitsandbytes or use_deepspeed:
46
- raise ValueError(
47
- "Cannot use accelerate or parallelize or "
48
- "bitsandbytes or deepspeeed with sentence transformers"
49
- )
50
- # Check if providing path
51
- self.model_name = model_name_or_path
52
- print("Model Name:", self.model_name)
53
- torch_device = (
54
- torch.device("cpu")
55
- if (device == -1 or not torch.cuda.is_available())
56
- else torch.device(f"cuda:{device}")
57
- )
58
- self.embedding_model = SentenceTransformer(self.model_name, device=torch_device)
59
- self.embedding_model.to(torch_device)
60
- self.embedding_model.eval()
61
-
62
- def get_init_params(self) -> Dict:
63
- """Return init params to determine what model is being used."""
64
- return {"model_name": self.model_name, "model_path": self.model_name}
65
-
66
- @torch.no_grad()
67
- def generate(
68
- self, prompt: Union[str, List[str]], **kwargs: Any
69
- ) -> List[Tuple[Any, float, List[str], List[float]]]:
70
- """
71
- Generate the prompt from model.
72
-
73
- Outputs must be generated text and score, not including prompt.
74
-
75
- Args:
76
- prompt: prompt to generate from.
77
-
78
- Returns:
79
- list of generated text (list of length 1 for 1 generation).
80
- """
81
- raise NotImplementedError("Generate not supported for sentence transformers")
82
-
83
- @torch.no_grad()
84
- def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
85
- """
86
- Embed the prompt from model.
87
-
88
- Args:
89
- prompt: prompt to embed from.
90
-
91
- Returns:
92
- list of embeddings (list of length 1 for 1 embedding).
93
- """
94
- if isinstance(prompt, str):
95
- prompt = [prompt]
96
- return self.embedding_model.encode(prompt)
97
-
98
- @torch.no_grad()
99
- def score_sequence(
100
- self, prompt: Union[str, List[str]], **kwargs: Any
101
- ) -> List[Tuple[float, List[int], List[float]]]:
102
- """
103
- Score a sequence of choices.
104
-
105
- Args:
106
- prompt (:obj:`str` or :obj:`List[str]`):
107
- The prompt to score the choices against.
108
- **kwargs:
109
- Additional keyword arguments passed along to the :obj:`__call__` method.
110
- """
111
- raise NotImplementedError(
112
- "Score sequence not supported for sentence transformers"
113
- )
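A usage sketch, assuming sentence-transformers is installed; the model name below is an illustrative choice, not one required by the class.

from manifest.api.models.sentence_transformer import SentenceTransformerModel

st_model = SentenceTransformerModel(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    device=-1,  # -1 (or no GPU available) falls back to CPU in the constructor above
)
vectors = st_model.embed(["SELECT 42;", "How many rows does t have?"])
print(vectors.shape)  # (2, embedding dimension)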
duckdb-nsql/manifest/manifest/api/response.py DELETED
@@ -1,55 +0,0 @@
1
- """Response."""
2
-
3
- import time
4
- import uuid
5
- from typing import Any, Dict, List
6
-
7
-
8
- class ModelResponse:
9
- """ModelResponse."""
10
-
11
- def __init__(self, results: List[Dict[str, Any]], response_type: str) -> None:
12
- """Initialize response."""
13
- self.results = results
14
- self.response_type = response_type
15
- if self.response_type not in {
16
- "text_completion",
17
- "prompt_logit_score",
18
- "image_generation",
19
- "embedding_generation",
20
- }:
21
- raise ValueError(
22
- f"Invalid response type: {self.response_type}. "
23
- "Must be one of: text_completion, prompt_logit_score, "
24
- "image_generation, embedding_generation."
25
- )
26
- self.response_id = str(uuid.uuid4())
27
- self.created = int(time.time())
28
-
29
- def __dict__(self) -> Dict[str, Any]: # type: ignore
30
- """Return dictionary representation of response."""
31
- key = (
32
- "text"
33
- if self.response_type not in {"image_generation", "embedding_generation"}
34
- else "array"
35
- )
36
- return {
37
- "id": self.response_id,
38
- "object": self.response_type,
39
- "created": self.created,
40
- "model": "flask_model",
41
- "choices": [
42
- {
43
- key: result[key],
44
- "logprob": result["logprob"],
45
- "tokens": result["tokens"],
46
- "token_logprobs": result["token_logprobs"],
47
- }
48
- if key == "text"
49
- else {
50
- key: result[key].tolist(),
51
- "logprob": result["logprob"],
52
- }
53
- for result in self.results
54
- ],
55
- }
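A small sketch of the payload ModelResponse produces for a text completion; all values are illustrative.

from manifest.api.response import ModelResponse

resp = ModelResponse(
    results=[
        {
            "text": "SELECT 1;",
            "logprob": -0.42,
            "tokens": ["SELECT", " 1", ";"],
            "token_logprobs": [-0.10, -0.20, -0.12],
        }
    ],
    response_type="text_completion",
)
# __dict__() yields an OpenAI-style dict with id, object, created, model and choices keys.
print(resp.__dict__())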
duckdb-nsql/manifest/manifest/caches/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Cache init."""
duckdb-nsql/manifest/manifest/caches/array_cache.py DELETED
@@ -1,116 +0,0 @@
1
- """Array cache."""
2
- from pathlib import Path
3
- from typing import Union
4
-
5
- import numpy as np
6
- from sqlitedict import SqliteDict
7
-
8
-
9
- def open_mmap_arr(file: Union[Path, str], size: float) -> np.memmap:
10
- """Open memmap."""
11
- if not Path(file).exists():
12
- mode = "w+"
13
- else:
14
- mode = "r+"
15
- arr = np.memmap( # type: ignore
16
- str(file),
17
- dtype=np.float32, # This means we only support float 32
18
- mode=mode,
19
- shape=size,
20
- )
21
- return arr
22
-
23
-
24
- class ArrayCache:
25
- """Array cache."""
26
-
27
- def __init__(self, folder: Union[str, Path]) -> None:
28
- """
29
- Initialize the array writer.
30
-
31
- Args:
32
- folder: folder to write to.
33
- """
34
- self.folder = Path(folder)
35
- self.folder.mkdir(exist_ok=True, parents=True)
36
- self.hash2arrloc = SqliteDict(
37
- self.folder / "hash2arrloc.sqlite", autocommit=True
38
- )
39
- # 20,480,000 float32 values per memmap file (roughly 82 MB)
40
- self.max_memmap_size = 20480000
41
- self.cur_file_idx = 0
42
- # Get the last file idx used
43
- for key in self.hash2arrloc:
44
- file_data = self.hash2arrloc[key]
45
- if file_data["file_idx"] > self.cur_file_idx:
46
- self.cur_file_idx = file_data["file_idx"]
47
- self.cur_memmap = open_mmap_arr(
48
- self.folder / f"{self.cur_file_idx}.npy",
49
- self.max_memmap_size,
50
- )
51
- # Make sure there is space left in the memmap
52
- non_zero = np.nonzero(self.cur_memmap)[0]
53
- if len(non_zero) > 0:
54
- self.cur_offset = int(np.max(non_zero) + 1)
55
- else:
56
- self.cur_offset = 0
57
- # If no space, make a new memmap
58
- if self.cur_offset == self.max_memmap_size:
59
- self.cur_file_idx += 1
60
- self.cur_memmap = open_mmap_arr(
61
- self.folder / f"{self.cur_file_idx}.npy",
62
- self.max_memmap_size,
63
- )
64
- self.cur_offset = 0
65
-
66
- def contains_key(self, key: str) -> bool:
67
- """
68
- Check if the key is in the cache.
69
-
70
- Args:
71
- key: key to check.
72
-
73
- Returns:
74
- True if the key is in the cache.
75
- """
76
- return key in self.hash2arrloc
77
-
78
- def put(self, key: str, arr: np.ndarray) -> None:
79
- """Save array in store and associate location with key."""
80
- # Check if there is space in the memmap
81
- arr_shape = arr.shape
82
- arr = arr.flatten()
83
- if len(arr) > self.max_memmap_size:
84
- raise ValueError(
85
- f"Array is too large to be cached. Max is {self.max_memmap_size}"
86
- )
87
- if self.cur_offset + len(arr) > self.max_memmap_size:
88
- self.cur_file_idx += 1
89
- self.cur_memmap = open_mmap_arr(
90
- self.folder / f"{self.cur_file_idx}.npy",
91
- self.max_memmap_size,
92
- )
93
- self.cur_offset = 0
94
- self.cur_memmap[self.cur_offset : self.cur_offset + len(arr)] = arr
95
- self.cur_memmap.flush()
96
- self.hash2arrloc[key] = {
97
- "file_idx": self.cur_file_idx,
98
- "offset": self.cur_offset,
99
- "flatten_size": len(arr),
100
- "shape": arr_shape,
101
- "dtype": arr.dtype,
102
- }
103
- self.cur_offset += len(arr)
104
- return
105
-
106
- def get(self, key: str) -> np.ndarray:
107
- """Get array associated with location from key."""
108
- file_data = self.hash2arrloc[key]
109
- memmap = open_mmap_arr(
110
- self.folder / f"{file_data['file_idx']}.npy",
111
- self.max_memmap_size,
112
- )
113
- arr = memmap[
114
- file_data["offset"] : file_data["offset"] + file_data["flatten_size"]
115
- ]
116
- return arr.reshape(file_data["shape"]).astype(file_data["dtype"])
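A round-trip sketch for ArrayCache; the folder path and key are illustrative.

import numpy as np

from manifest.caches.array_cache import ArrayCache

cache = ArrayCache("/tmp/manifest_array_cache")
embedding = np.random.rand(4, 16).astype(np.float32)  # arrays are stored as float32
cache.put("example-hash", embedding)
assert cache.contains_key("example-hash")
assert cache.get("example-hash").shape == (4, 16)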
duckdb-nsql/manifest/manifest/caches/cache.py DELETED
@@ -1,135 +0,0 @@
1
- """Cache for queries and responses."""
2
- from abc import ABC, abstractmethod
3
- from typing import Any, Dict, Type, Union
4
-
5
- from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer, Serializer
6
- from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest, Request
7
- from manifest.response import Response
8
-
9
- # Non-text return type caches
10
- ARRAY_CACHE_TYPES = {EmbeddingRequest, DiffusionRequest}
11
-
12
-
13
- class Cache(ABC):
14
- """A cache for request/response pairs."""
15
-
16
- def __init__(
17
- self,
18
- connection_str: str,
19
- request_type: Type[Request] = LMRequest,
20
- cache_args: Dict[str, Any] = {},
21
- ):
22
- """
23
- Initialize cache.
24
-
25
- Args:
26
- connection_str: connection string.
27
- request_type: request type.
28
- cache_args: arguments for cache.
29
-
30
- cache_args are any arguments needed to initialize the cache.
31
-
32
- Further, cache_args can contain `array_serializer` as a string
33
- for embedding or image return types (e.g. diffusers) with values
34
- as `local_file` or `byte_string`. `local_file` will save the
35
- array in a local file and cache a pointer to the file.
36
- `byte_string` will convert the array to a byte string and cache
37
- the entire byte string. `byte_string` is default.
38
-
39
- Args:
40
- connection_str: connection string for cache.
41
- cache_args: cache arguments.
42
- """
43
- self.request_type = request_type
44
- self.connect(connection_str, cache_args)
45
- if self.request_type in ARRAY_CACHE_TYPES:
46
- array_serializer = cache_args.pop("array_serializer", "byte_string")
47
- if array_serializer not in ["local_file", "byte_string"]:
48
- raise ValueError(
49
- "array_serializer must be local_file or byte_string,"
50
- f" not {array_serializer}"
51
- )
52
- self.serializer = (
53
- ArraySerializer()
54
- if array_serializer == "local_file"
55
- else NumpyByteSerializer()
56
- )
57
- else:
58
- # If user has array_serializer type, it will throw an error as
59
- # it is not recognized for non-array return types.
60
- self.serializer = Serializer()
61
-
62
- @abstractmethod
63
- def close(self) -> None:
64
- """Close the cache."""
65
- raise NotImplementedError()
66
-
67
- @abstractmethod
68
- def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
69
- """
70
- Connect to cache.
71
-
72
- Args:
73
- connection_str: connection string.
74
- """
75
- raise NotImplementedError()
76
-
77
- @abstractmethod
78
- def get_key(self, key: str, table: str = "default") -> Union[str, None]:
79
- """
80
- Get the key for a request.
81
-
82
- Will return None if key is not in cache.
83
-
84
- Args:
85
- key: key for cache.
86
- table: table to get key in.
87
- """
88
- raise NotImplementedError()
89
-
90
- @abstractmethod
91
- def set_key(self, key: str, value: str, table: str = "default") -> None:
92
- """
93
- Set the value for the key.
94
-
95
- Will override old value.
96
-
97
- Args:
98
- key: key for cache.
99
- value: new value for key.
100
- table: table to set key in.
101
- """
102
- raise NotImplementedError()
103
-
104
- @abstractmethod
105
- def commit(self) -> None:
106
- """Commit any results."""
107
- raise NotImplementedError()
108
-
109
- def get(self, request: Dict) -> Union[Response, None]:
110
- """Get the result of request (by calling compute as needed).
111
-
112
- Args:
113
- request: request to get.
114
- response: response to get.
115
-
116
- Returns:
117
- Response object or None if not in cache.
118
- """
119
- key = self.serializer.request_to_key(request)
120
- cached_response = self.get_key(key)
121
- if cached_response:
122
- response = self.serializer.key_to_response(cached_response)
123
- response["cached"] = True
124
- return Response.from_dict(response, request_dict=request)
125
- return None
126
-
127
- def set(self, request: Dict, response: Dict) -> None:
128
- """Set the value for the key.
129
-
130
- Args:
131
- request: request to set.
132
- response: response to set.
133
- """
134
- key = self.serializer.request_to_key(request)
135
- self.set_key(key, self.serializer.response_to_key(response))
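A hedged sketch of the smallest useful Cache subclass, an in-memory dict; only the abstract hooks above need implementing, and the base class's get/set then handle serialization.

from typing import Any, Dict, Union

from manifest.caches.cache import Cache


class DictCache(Cache):
    """Illustrative in-memory cache (not part of the repository)."""

    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
        self._store: Dict[str, str] = {}

    def close(self) -> None:
        self._store.clear()

    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        return self._store.get(f"{table}:{key}")

    def set_key(self, key: str, value: str, table: str = "default") -> None:
        self._store[f"{table}:{key}"] = value

    def commit(self) -> None:
        pass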
duckdb-nsql/manifest/manifest/caches/noop.py DELETED
@@ -1,47 +0,0 @@
1
- """Noop cache."""
2
- from typing import Any, Dict, Union
3
-
4
- from manifest.caches.cache import Cache
5
-
6
-
7
- class NoopCache(Cache):
8
- """A Noop cache that caches nothing for request/response pairs."""
9
-
10
- def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
11
- """
12
- Connect to client.
13
-
14
- Args:
15
- connection_str: connection string.
16
- cache_args: arguments for cache.
17
- """
18
- pass
19
-
20
- def close(self) -> None:
21
- """Close the client."""
22
- pass
23
-
24
- def get_key(self, key: str, table: str = "default") -> Union[str, None]:
25
- """
26
- Always return None, since nothing is ever cached.
27
-
28
- Args:
29
- key: key for cache.
30
- table: table to get key in.
31
- """
32
- return None
33
-
34
- def set_key(self, key: str, value: str, table: str = "default") -> None:
35
- """
36
- Do not set anything as no cache.
37
-
38
- Args:
39
- key: key for cache.
40
- value: new value for key.
41
- table: table to set key in.
42
- """
43
- pass
44
-
45
- def commit(self) -> None:
46
- """Commit any results."""
47
- pass
duckdb-nsql/manifest/manifest/caches/postgres.py DELETED
@@ -1,131 +0,0 @@
1
- """Postgres cache."""
2
- import hashlib
3
- import logging
4
- from typing import Any, Dict, Union
5
-
6
- logger = logging.getLogger("postgresql")
7
- logger.setLevel(logging.WARNING)
8
-
9
- from ..caches.cache import Cache
10
-
11
- try:
12
- import sqlalchemy # type: ignore
13
- from google.cloud.sql.connector import Connector # type: ignore
14
- from sqlalchemy import Column, String # type: ignore
15
- from sqlalchemy.ext.declarative import declarative_base # type: ignore
16
- from sqlalchemy.orm import sessionmaker # type: ignore
17
-
18
- Base = declarative_base()
19
-
20
- class Request(Base): # type: ignore
21
- """The request table."""
22
-
23
- __tablename__ = "requests"
24
- key = Column(String, primary_key=True)
25
- response = Column(
26
- String
27
- ) # FIXME: ideally should be an hstore, but I don't want to set it up on GCP
28
-
29
- missing_dependencies = None
30
-
31
- except ImportError as e:
32
- missing_dependencies = e
33
-
34
-
35
- class PostgresCache(Cache):
36
- """A PostgreSQL cache for request/response pairs."""
37
-
38
- def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
39
- """
40
- Connect to client.
41
-
42
- Args:
43
- connection_str: connection string.
44
- cache_args: arguments for cache should include the following fields:
45
- {
46
- "cache_user": "",
47
- "cache_password": "",
48
- "cache_db": ""
49
- }
50
- """
51
- if missing_dependencies:
52
- raise ValueError(
53
- "Missing dependencies for GCP PostgreSQL cache. "
54
- "Install with `pip install manifest[gcp]`",
55
- missing_dependencies,
56
- )
57
-
58
- connector = Connector()
59
-
60
- def getconn() -> Any:
61
- conn = connector.connect(
62
- connection_str,
63
- "pg8000",
64
- user=cache_args.pop("cache_user"),
65
- password=cache_args.pop("cache_password"),
66
- db=cache_args.pop("cache_db"),
67
- )
68
- return conn
69
-
70
- engine = sqlalchemy.create_engine(
71
- "postgresql+pg8000://",
72
- creator=getconn,
73
- )
74
- engine.dialect.description_encoding = None # type: ignore
75
-
76
- db_exists = len(sqlalchemy.inspect(engine).get_table_names()) > 0
77
- if not db_exists:
78
- logger.info("Creating database...")
79
- Base.metadata.create_all(engine)
80
-
81
- self.session = sessionmaker(bind=engine)()
82
-
83
- def close(self) -> None:
84
- """Close the client."""
85
- self.session.close()
86
-
87
- @staticmethod
88
- def _hash_key(key: str, table: str) -> str:
89
- """Compute MD5 hash of the key."""
90
- return hashlib.md5(f"{key}:{table}".encode("utf-8")).hexdigest()
91
-
92
- def get_key(self, key: str, table: str = "default") -> Union[str, None]:
93
- """
94
- Get the key for a request.
95
-
96
- Will return None if key is not in cache.
97
-
98
- Args:
99
- key: key for cache.
100
- table: table to get key in.
101
- """
102
- request = (
103
- self.session.query(Request) # type: ignore
104
- .filter_by(key=self._hash_key(key, table))
105
- .first()
106
- )
107
- out = request.response if request else None
108
- return out # type: ignore
109
-
110
- def set_key(self, key: str, value: str, table: str = "default") -> None:
111
- """
112
- Set the value for the key.
113
-
114
- Will override old value.
115
-
116
- Args:
117
- key: key for cache.
118
- value: new value for key.
119
- table: table to set key in.
120
- """
121
- key = self._hash_key(key, table)
122
- request = self.session.query(Request).filter_by(key=key).first() # type: ignore
123
- if request:
124
- request.response = value # type: ignore
125
- else:
126
- self.session.add(Request(key=key, response=value))
127
- self.commit()
128
-
129
- def commit(self) -> None:
130
- """Commit any results."""
131
- self.session.commit()
duckdb-nsql/manifest/manifest/caches/redis.py DELETED
@@ -1,64 +0,0 @@
1
- """Redis cache."""
2
- from typing import Any, Dict, Union
3
-
4
- import redis
5
-
6
- from manifest.caches.cache import Cache
7
-
8
-
9
- class RedisCache(Cache):
10
- """A Redis cache for request/response pairs."""
11
-
12
- def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
13
- """
14
- Connect to client.
15
-
16
- Args:
17
- connection_str: connection string.
18
- cache_args: arguments for cache.
19
- """
20
- host, port = connection_str.split(":")
21
- self.redis = redis.Redis(host=host, port=int(port), db=0)
22
- return
23
-
24
- def close(self) -> None:
25
- """Close the client."""
26
- self.redis.close()
27
-
28
- def _normalize_table_key(self, key: str, table: str) -> str:
29
- """Cast key for prompt key."""
30
- return f"{table}:{key}"
31
-
32
- def get_key(self, key: str, table: str = "default") -> Union[str, None]:
33
- """
34
- Get the key for a request.
35
-
36
- Will return None if key is not in cache.
37
-
38
- Args:
39
- key: key for cache.
40
- table: table to get key in.
41
- """
42
- norm_key = self._normalize_table_key(key, table)
43
- if self.redis.exists(norm_key):
44
- return self.redis.get(norm_key).decode("utf-8")
45
- else:
46
- return None
47
-
48
- def set_key(self, key: str, value: str, table: str = "default") -> None:
49
- """
50
- Set the value for the key.
51
-
52
- Will override old value.
53
-
54
- Args:
55
- key: key for cache.
56
- value: new value for key.
57
- table: table to set key in.
58
- """
59
- self.redis.set(self._normalize_table_key(key, table), value)
60
- self.commit()
61
-
62
- def commit(self) -> None:
63
- """Commit any results."""
64
- pass
duckdb-nsql/manifest/manifest/caches/serializers.py DELETED
@@ -1,204 +0,0 @@
1
- """Serializer."""
2
-
3
- import io
4
- import json
5
- import os
6
- from pathlib import Path
7
- from typing import Dict
8
-
9
- import numpy as np
10
- import xxhash
11
-
12
- from manifest.caches.array_cache import ArrayCache
13
-
14
-
15
- class Serializer:
16
- """Serializer."""
17
-
18
- def request_to_key(self, request: Dict) -> str:
19
- """
20
- Normalize a request into a key.
21
-
22
- Args:
23
- request: request to normalize.
24
-
25
- Returns:
26
- normalized key.
27
- """
28
- return json.dumps(request, sort_keys=True)
29
-
30
- def key_to_request(self, key: str) -> Dict:
31
- """
32
- Convert the normalized version to the request.
33
-
34
- Args:
35
- key: normalized key to convert.
36
-
37
- Returns:
38
- unnormalized request dict.
39
- """
40
- return json.loads(key)
41
-
42
- def response_to_key(self, response: Dict) -> str:
43
- """
44
- Normalize a response into a key.
45
-
46
- Args:
47
- response: response to normalize.
48
-
49
- Returns:
50
- normalized key.
51
- """
52
- return json.dumps(response, sort_keys=True)
53
-
54
- def key_to_response(self, key: str) -> Dict:
55
- """
56
- Convert the normalized version to the response.
57
-
58
- Args:
59
- key: normalized key to convert.
60
-
61
- Returns:
62
- unnormalized response dict.
63
- """
64
- return json.loads(key)
65
-
66
-
67
- class NumpyByteSerializer(Serializer):
68
- """Serializer by casting array to byte string."""
69
-
70
- def response_to_key(self, response: Dict) -> str:
71
- """
72
- Normalize a response into a key.
73
-
74
- Args:
75
- response: response to normalize.
76
-
77
- Returns:
78
- normalized key.
79
- """
80
- sub_response = response["response"]
81
- # Assume response is a dict with keys "choices" -> List dicts
82
- # with keys "array".
83
- choices = sub_response["choices"]
84
- # We don't want to modify the response in place
85
- # but we want to avoid calling deepcopy on an array
86
- del sub_response["choices"]
87
- response_copy = sub_response.copy()
88
- sub_response["choices"] = choices
89
- response_copy["choices"] = []
90
- for choice in choices:
91
- if "array" not in choice:
92
- raise ValueError(
93
- f"Choice with keys {choice.keys()} does not have array key."
94
- )
95
- arr = choice["array"]
96
- # Avoid copying an array
97
- del choice["array"]
98
- new_choice = choice.copy()
99
- choice["array"] = arr
100
- with io.BytesIO() as f:
101
- np.savez_compressed(f, data=arr)
102
- hash_str = f.getvalue().hex()
103
- new_choice["array"] = hash_str
104
- response_copy["choices"].append(new_choice)
105
- response["response"] = response_copy
106
- return json.dumps(response, sort_keys=True)
107
-
108
- def key_to_response(self, key: str) -> Dict:
109
- """
110
- Convert the normalized version to the response.
111
-
112
- Args:
113
- key: normalized key to convert.
114
-
115
- Returns:
116
- unnormalized response dict.
117
- """
118
- response = json.loads(key)
119
- for choice in response["response"]["choices"]:
120
- hash_str = choice["array"]
121
- byte_str = bytes.fromhex(hash_str)
122
- with io.BytesIO(byte_str) as f:
123
- choice["array"] = np.load(f)["data"]
124
- return response
125
-
126
-
127
- class ArraySerializer(Serializer):
128
- """Serializer for array."""
129
-
130
- def __init__(self) -> None:
131
- """
132
- Initialize array serializer.
133
-
134
- We don't want to cache the array. We hash the value and
135
- store the array in a memmap file. Store filename/offsets
136
- in sqlitedict to keep track of hash -> array.
137
- """
138
- super().__init__()
139
-
140
- self.hash = xxhash.xxh64()
141
- manifest_home = Path(os.environ.get("MANIFEST_HOME", Path.home()))
142
- cache_folder = manifest_home / ".manifest" / "array_cache"
143
- self.writer = ArrayCache(cache_folder)
144
-
145
- def response_to_key(self, response: Dict) -> str:
146
- """
147
- Normalize a response into a key.
148
-
149
- Convert arrays to hash string for cache key.
150
-
151
- Args:
152
- response: response to normalize.
153
-
154
- Returns:
155
- normalized key.
156
- """
157
- sub_response = response["response"]
158
- # Assume response is a dict with keys "choices" -> List dicts
159
- # with keys "array".
160
- choices = sub_response["choices"]
161
- # We don't want to modify the response in place
162
- # but we want to avoid calling deepcopy on an array
163
- del sub_response["choices"]
164
- response_copy = sub_response.copy()
165
- sub_response["choices"] = choices
166
- response_copy["choices"] = []
167
- for choice in choices:
168
- if "array" not in choice:
169
- raise ValueError(
170
- f"Choice with keys {choice.keys()} does not have array key."
171
- )
172
- arr = choice["array"]
173
- # Avoid copying an array
174
- del choice["array"]
175
- new_choice = choice.copy()
176
- choice["array"] = arr
177
-
178
- self.hash.update(arr)
179
- hash_str = self.hash.hexdigest()
180
- self.hash.reset()
181
- new_choice["array"] = hash_str
182
- response_copy["choices"].append(new_choice)
183
- if not self.writer.contains_key(hash_str):
184
- self.writer.put(hash_str, arr)
185
- response["response"] = response_copy
186
- return json.dumps(response, sort_keys=True)
187
-
188
- def key_to_response(self, key: str) -> Dict:
189
- """
190
- Convert the normalized version to the response.
191
-
192
- Convert the hash string keys to the arrays.
193
-
194
- Args:
195
- key: normalized key to convert.
196
-
197
- Returns:
198
- unnormalized response dict.
199
- """
200
- response = json.loads(key)
201
- for choice in response["response"]["choices"]:
202
- hash_str = choice["array"]
203
- choice["array"] = self.writer.get(hash_str)
204
- return response
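A standalone sketch of the byte-string round trip NumpyByteSerializer relies on: compress the array with np.savez_compressed, hex-encode the bytes so they fit in a JSON cache value, then reverse both steps.

import io

import numpy as np

arr = np.arange(6, dtype=np.float32).reshape(2, 3)

with io.BytesIO() as f:
    np.savez_compressed(f, data=arr)
    hex_str = f.getvalue().hex()  # JSON-safe string stored in the cache

restored = np.load(io.BytesIO(bytes.fromhex(hex_str)))["data"]
assert np.array_equal(arr, restored)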
duckdb-nsql/manifest/manifest/caches/sqlite.py DELETED
@@ -1,65 +0,0 @@
1
- """SQLite cache."""
2
- import logging
3
- from typing import Any, Dict, Union
4
-
5
- from sqlitedict import SqliteDict
6
-
7
- from manifest.caches.cache import Cache
8
-
9
- logging.getLogger("sqlitedict").setLevel(logging.WARNING)
10
-
11
-
12
- class SQLiteCache(Cache):
13
- """A SQLite cache for request/response pairs."""
14
-
15
- def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
16
- """
17
- Connect to client.
18
-
19
- Args:
20
- connection_str: connection string.
21
- cache_args: arguments for cache.
22
- """
23
- self.cache_file = connection_str
24
- if not self.cache_file:
25
- self.cache_file = ".sqlite.cache"
26
- self.cache = SqliteDict(self.cache_file, autocommit=True)
27
- return
28
-
29
- def close(self) -> None:
30
- """Close the client."""
31
- self.cache.close()
32
-
33
- def _normalize_table_key(self, key: str, table: str) -> str:
34
- """Cast key for prompt key."""
35
- return f"{table}:{key}"
36
-
37
- def get_key(self, key: str, table: str = "default") -> Union[str, None]:
38
- """
39
- Get the key for a request.
40
-
41
- Will return None if key is not in cache.
42
-
43
- Args:
44
- key: key for cache.
45
- table: table to get key in.
46
- """
47
- return self.cache.get(self._normalize_table_key(key, table))
48
-
49
- def set_key(self, key: str, value: str, table: str = "default") -> None:
50
- """
51
- Set the value for the key.
52
-
53
- Will override old value.
54
-
55
- Args:
56
- key: key for cache.
57
- value: new value for key.
58
- table: table to set key in.
59
- """
60
- self.cache[self._normalize_table_key(key, table)] = value
61
- self.commit()
62
-
63
- def commit(self) -> None:
64
- """Commit any results."""
65
- self.cache.commit()
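A usage sketch wiring the SQLite cache into Manifest; the cache file path is illustrative, and the keyword names assume the Manifest constructor described in the project README.

from manifest import Manifest

manifest = Manifest(
    client_name="openai",
    cache_name="sqlite",
    cache_connection=".manifest.sqlite",  # becomes SQLiteCache.cache_file above
)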
duckdb-nsql/manifest/manifest/clients/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Client init."""
duckdb-nsql/manifest/manifest/clients/ai21.py DELETED
@@ -1,125 +0,0 @@
1
- """AI21 client."""
2
- import logging
3
- import os
4
- from typing import Any, Dict, Optional
5
-
6
- from manifest.clients.client import Client
7
- from manifest.request import LMRequest
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
- AI21_ENGINES = {
12
- "j2-ultra",
13
- "j2-mid",
14
- "j2-light",
15
- }
16
-
17
-
18
- class AI21Client(Client):
19
- """AI21Client client."""
20
-
21
- # User param -> (client param, default value)
22
- PARAMS = {
23
- "engine": ("engine", "j2-ultra"),
24
- "temperature": ("temperature", 0.7),
25
- "max_tokens": ("maxTokens", 40),
26
- "top_k": ("topKReturn", 0),
27
- "n": ("numResults", 1),
28
- "top_p": ("topP", 1.0),
29
- "stop_sequences": ("stopSequences", []),
30
- }
31
- REQUEST_CLS = LMRequest
32
- NAME = "ai21"
33
-
34
- def connect(
35
- self,
36
- connection_str: Optional[str] = None,
37
- client_args: Dict[str, Any] = {},
38
- ) -> None:
39
- """
40
- Connect to the AI21 server.
41
-
42
- connection_str is passed as default AI21_API_KEY if variable not set.
43
-
44
- Args:
45
- connection_str: connection string.
46
- client_args: client arguments.
47
- """
48
- # Taken from https://docs.ai21.com/
49
- self.host = "https://api.ai21.com/studio/v1"
50
- self.api_key = connection_str or os.environ.get("AI21_API_KEY")
51
- if self.api_key is None:
52
- raise ValueError(
53
- "AI21 API key not set. Set AI21_API_KEY environment "
54
- "variable or pass through `client_connection`."
55
- )
56
-
57
- for key in self.PARAMS:
58
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
59
- if getattr(self, "engine") not in AI21_ENGINES:
60
- raise ValueError(
61
- f"Invalid engine {getattr(self, 'engine')}. Must be {AI21_ENGINES}."
62
- )
63
-
64
- def close(self) -> None:
65
- """Close the client."""
66
- pass
67
-
68
- def get_generation_url(self) -> str:
69
- """Get generation URL."""
70
- return self.host + "/" + getattr(self, "engine") + "/complete"
71
-
72
- def get_generation_header(self) -> Dict[str, str]:
73
- """
74
- Get generation header.
75
-
76
- Returns:
77
- header.
78
- """
79
- return {"Authorization": f"Bearer {self.api_key}"}
80
-
81
- def supports_batch_inference(self) -> bool:
82
- """Return whether the client supports batch inference."""
83
- return False
84
-
85
- def supports_streaming_inference(self) -> bool:
86
- """Return whether the client supports streaming inference.
87
-
88
- Override in child client class.
89
- """
90
- return False
91
-
92
- def get_model_params(self) -> Dict:
93
- """
94
- Get model params.
95
-
96
- By getting model params from the server, we can add to request
97
- and make sure cache keys are unique to model.
98
-
99
- Returns:
100
- model params.
101
- """
102
- return {"model_name": self.NAME, "engine": getattr(self, "engine")}
103
-
104
- def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
105
- """
106
- Format response to dict.
107
-
108
- Args:
109
- response: response
110
- request: request
111
-
112
- Return:
113
- response as dict
114
- """
115
- return {
116
- "object": "text_completion",
117
- "model": getattr(self, "engine"),
118
- "choices": [
119
- {
120
- "text": item["data"]["text"],
121
- "token_logprobs": item["data"]["tokens"],
122
- }
123
- for item in response["completions"]
124
- ],
125
- }
duckdb-nsql/manifest/manifest/clients/azureendpoint.py DELETED
@@ -1,139 +0,0 @@
1
- """OpenRouter client."""
2
-
3
- import copy
4
- import logging
5
- import os
6
- from typing import Any, Dict, Optional
7
- import time
8
- from manifest.clients.client import Client
9
- from manifest.request import LMRequest
10
- import urllib.request
11
- import json
12
- import os
13
- import ssl
14
-
15
- logger = logging.getLogger(__name__)
16
- def allowSelfSignedHttps(allowed):
17
- # bypass the server certificate verification on client side
18
- if allowed and not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(ssl, '_create_unverified_context', None):
19
- ssl._create_default_https_context = ssl._create_unverified_context
20
-
21
- allowSelfSignedHttps(True) # this line is needed if you use self-signed certificate in your scoring service.
22
-
23
-
24
- class AzureEndpointClient(Client):
25
- """OpenRouter client."""
26
-
27
- # Params follow the OpenRouter parameter names (https://openrouter.ai/docs/parameters)
28
- PARAMS = {
29
- "engine": ("model", "meta-llama/codellama-70b-instruct"),
30
- "max_tokens": ("max_tokens", 1000),
31
- "temperature": ("temperature", 0.1),
32
- "top_k": ("k", 0),
33
- "frequency_penalty": ("frequency_penalty", 0.0),
34
- "presence_penalty": ("presence_penalty", 0.0),
35
- "stop_sequences": ("stop", None),
36
- }
37
- REQUEST_CLS = LMRequest
38
- NAME = "azureendpoint"
39
- IS_CHAT = True
40
-
41
- def connect(
42
- self,
43
- connection_str: Optional[str] = None,
44
- client_args: Dict[str, Any] = {},
45
- ) -> None:
46
- """
47
- Connect to the Azure ML online endpoint.
48
-
49
- The endpoint host and key are read from the AZURE_HOST and AZURE_API_KEY environment variables.
50
-
51
- Args:
52
- connection_str: connection string.
53
- client_args: client arguments.
54
- """
55
-
56
- self.host = os.environ.get("AZURE_HOST")
57
- # Replace this with the primary/secondary key, AMLToken, or Microsoft Entra ID token for the endpoint
58
- self.api_key = os.environ.get("AZURE_API_KEY")
59
- if not self.api_key:
60
- raise Exception("A key should be provided to invoke the endpoint")
61
- for key in self.PARAMS:
62
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
63
-
64
- def close(self) -> None:
65
- """Close the client."""
66
-
67
- def get_generation_header(self) -> Dict[str, str]:
68
- """
69
- Get generation header.
70
-
71
- Returns:
72
- header.
73
- """
74
- return {'Content-Type':'application/json', 'Authorization':('Bearer '+ self.api_key), 'azureml-model-deployment': 'duckdb-nsql-v2-phi-medium-1' }
75
-
76
- def get_generation_url(self) -> str:
77
- """Get generation URL."""
78
- return self.host + "/score"
79
-
80
- def supports_batch_inference(self) -> bool:
81
- """Return whether the client supports batch inference."""
82
- return False
83
-
84
- def supports_streaming_inference(self) -> bool:
85
- """Return whether the client supports streaming inference.
86
-
87
- Override in child client class.
88
- """
89
- return True
90
-
91
- def get_model_params(self) -> Dict:
92
- """
93
- Get model params.
94
-
95
- By getting model params from the server, we can add to request
96
- and make sure cache keys are unique to model.
97
-
98
- Returns:
99
- model params.
100
- """
101
- return {"model_name": AzureEndpointClient.NAME, "engine": getattr(self, 'engine')}
102
-
103
- def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
104
- """
105
- Preprocess request params.
106
-
107
- Args:
108
- request: request params.
109
-
110
- Returns:
111
- request params.
112
- """
113
- # Format for chat model
114
- request = copy.deepcopy(request)
115
- prompt = request.pop("prompt")
116
- data = {"input_data": {"input_string": [{"role": "user", "content": prompt}], "parameters": {"stop":"\n```", "max_tokens": 500}}}
117
-
118
- #body = str(str.encode(json.dumps(data)))
119
- return super().preprocess_request_params(data)
120
-
121
- def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
122
- """
123
- Format response to dict.
124
-
125
- Args:
126
- response: response
127
- request: request
128
-
129
- Return:
130
- response as dict
131
- """
132
- new_choices = []
133
- response = copy.deepcopy(response)
134
- if "output" in response:
135
- new_choices.append({"text": response["output"]})
136
- else:
137
- new_choices.append({"text": ""})
138
- response["choices"] = new_choices
139
- return super().postprocess_response(response, request)
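A shape sketch of what preprocess_request_params above sends to the Azure ML scoring endpoint and what postprocess_response expects back; the prompt and SQL are illustrative.

request_body = {
    "input_data": {
        "input_string": [{"role": "user", "content": "How many taxi rides were there in 2023?"}],
        "parameters": {"stop": "\n```", "max_tokens": 500},
    }
}
endpoint_reply = {"output": "SELECT count(*) FROM rides WHERE year = 2023;"}
# postprocess_response lifts "output" into the standard choices list:
choices = [{"text": endpoint_reply.get("output", "")}]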
duckdb-nsql/manifest/manifest/clients/azureopenai.py DELETED
@@ -1,113 +0,0 @@
1
- """Azure client."""
2
- import logging
3
- import os
4
- from typing import Any, Dict, Optional, Type
5
-
6
- from manifest.clients.openai import OPENAI_ENGINES, OpenAIClient
7
- from manifest.request import LMRequest, Request
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
- # Azure deployment name can only use letters and numbers, no spaces. Hyphens ("-") and
12
- # underscores ("_") may be used, except as ending characters. We create this mapping to
13
- # handle difference between Azure and OpenAI
14
- AZURE_DEPLOYMENT_NAME_MAPPING = {
15
- "gpt-3.5-turbo": "gpt-35-turbo",
16
- "gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
17
- }
18
- OPENAI_DEPLOYMENT_NAME_MAPPING = {
19
- "gpt-35-turbo": "gpt-3.5-turbo",
20
- "gpt-35-turbo-0301": "gpt-3.5-turbo-0301",
21
- }
22
-
23
-
24
- class AzureClient(OpenAIClient):
25
- """Azure client."""
26
-
27
- PARAMS = OpenAIClient.PARAMS
28
- REQUEST_CLS: Type[Request] = LMRequest
29
- NAME = "azureopenai"
30
-
31
- def connect(
32
- self,
33
- connection_str: Optional[str] = None,
34
- client_args: Dict[str, Any] = {},
35
- ) -> None:
36
- """
37
- Connect to the AzureOpenAI server.
38
-
39
- connection_str is passed as default AZURE_OPENAI_KEY if variable not set.
40
-
41
- Args:
42
- connection_str: connection string.
43
- client_args: client arguments.
44
- """
45
- self.api_key, self.host = None, None
46
- if connection_str:
47
- connection_parts = connection_str.split("::")
48
- if len(connection_parts) == 1:
49
- self.api_key = connection_parts[0]
50
- elif len(connection_parts) == 2:
51
- self.api_key, self.host = connection_parts
52
- else:
53
- raise ValueError(
54
- "Invalid connection string. "
55
- "Must be either AZURE_OPENAI_KEY or "
56
- "AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
57
- )
58
- self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
59
- if self.api_key is None:
60
- raise ValueError(
61
- "AzureOpenAI API key not set. Set AZURE_OPENAI_KEY environment "
62
- "variable or pass through `client_connection`."
63
- )
64
- self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
65
- if self.host is None:
66
- raise ValueError(
67
- "Azure Service URL not set "
68
- "(e.g. https://openai-azure-service.openai.azure.com/)."
69
- " Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
70
- " as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
71
- )
72
- self.host = self.host.rstrip("/")
73
- for key in self.PARAMS:
74
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
75
- if getattr(self, "engine") not in OPENAI_ENGINES:
76
- raise ValueError(
77
- f"Invalid engine {getattr(self, 'engine')}. Must be {OPENAI_ENGINES}."
78
- )
79
-
80
- def get_generation_url(self) -> str:
81
- """Get generation URL."""
82
- engine = getattr(self, "engine")
83
- deployment_name = AZURE_DEPLOYMENT_NAME_MAPPING.get(engine, engine)
84
- return (
85
- self.host
86
- + "/openai/deployments/"
87
- + deployment_name
88
- + "/completions?api-version=2023-05-15"
89
- )
90
-
91
- def get_generation_header(self) -> Dict[str, str]:
92
- """
93
- Get generation header.
94
-
95
- Returns:
96
- header.
97
- """
98
- return {"api-key": f"{self.api_key}"}
99
-
100
- def get_model_params(self) -> Dict:
101
- """
102
- Get model params.
103
-
104
- By getting model params from the server, we can add to request
105
- and make sure cache keys are unique to model.
106
-
107
- Returns:
108
- model params.
109
- """
110
- # IMPORTANT!!!
111
- # Azure models are the same as openai models. So we want to unify their
112
- # caches. Make sure we return the OpenAI name here.
113
- return {"model_name": OpenAIClient.NAME, "engine": getattr(self, "engine")}
duckdb-nsql/manifest/manifest/clients/azureopenai_chat.py DELETED
@@ -1,116 +0,0 @@
1
- """Azure client."""
2
- import logging
3
- import os
4
- from typing import Any, Dict, Optional
5
-
6
- from manifest.clients.openai_chat import OPENAICHAT_ENGINES, OpenAIChatClient
7
- from manifest.request import LMRequest
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
- # Azure deployment name can only use letters and numbers, no spaces. Hyphens ("-") and
12
- # underscores ("_") may be used, except as ending characters. We create this mapping to
13
- # handle difference between Azure and OpenAI
14
- AZURE_DEPLOYMENT_NAME_MAPPING = {
15
- "gpt-3.5-turbo": "gpt-35-turbo",
16
- "gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
17
- }
18
- OPENAI_DEPLOYMENT_NAME_MAPPING = {
19
- "gpt-35-turbo": "gpt-3.5-turbo",
20
- "gpt-35-turbo-0301": "gpt-3.5-turbo-0301",
21
- }
22
-
23
-
24
- class AzureChatClient(OpenAIChatClient):
25
- """Azure chat client."""
26
-
27
- # User param -> (client param, default value)
28
- PARAMS = OpenAIChatClient.PARAMS
29
- REQUEST_CLS = LMRequest
30
- NAME = "azureopenaichat"
31
- IS_CHAT = True
32
-
33
- def connect(
34
- self,
35
- connection_str: Optional[str] = None,
36
- client_args: Dict[str, Any] = {},
37
- ) -> None:
38
- """
39
- Connect to the AzureOpenAI server.
40
-
41
- connection_str is passed as default AZURE_OPENAI_KEY if variable not set.
42
-
43
- Args:
44
- connection_str: connection string.
45
- client_args: client arguments.
46
- """
47
- self.api_key, self.host = None, None
48
- if connection_str:
49
- connection_parts = connection_str.split("::")
50
- if len(connection_parts) == 1:
51
- self.api_key = connection_parts[0]
52
- elif len(connection_parts) == 2:
53
- self.api_key, self.host = connection_parts
54
- else:
55
- raise ValueError(
56
- "Invalid connection string. "
57
- "Must be either AZURE_OPENAI_KEY or "
58
- "AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
59
- )
60
- self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
61
- if self.api_key is None:
62
- raise ValueError(
63
- "AzureOpenAI API key not set. Set AZURE_OPENAI_KEY environment "
64
- "variable or pass through `client_connection`."
65
- )
66
- self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
67
- if self.host is None:
68
- raise ValueError(
69
- "Azure Service URL not set "
70
- "(e.g. https://openai-azure-service.openai.azure.com/)."
71
- " Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
72
- " as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
73
- )
74
- self.host = self.host.rstrip("/")
75
- for key in self.PARAMS:
76
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
77
- if getattr(self, "engine") not in OPENAICHAT_ENGINES:
78
- raise ValueError(
79
- f"Invalid engine {getattr(self, 'engine')}. "
80
- f"Must be {OPENAICHAT_ENGINES}."
81
- )
82
-
83
- def get_generation_url(self) -> str:
84
- """Get generation URL."""
85
- engine = getattr(self, "engine")
86
- deployment_name = AZURE_DEPLOYMENT_NAME_MAPPING.get(engine, engine)
87
- return (
88
- self.host
89
- + "/openai/deployments/"
90
- + deployment_name
91
- + "/chat/completions?api-version=2023-05-15"
92
- )
93
-
94
- def get_generation_header(self) -> Dict[str, str]:
95
- """
96
- Get generation header.
97
-
98
- Returns:
99
- header.
100
- """
101
- return {"api-key": f"{self.api_key}"}
102
-
103
- def get_model_params(self) -> Dict:
104
- """
105
- Get model params.
106
-
107
- By getting model params from the server, we can add to request
108
- and make sure cache keys are unique to model.
109
-
110
- Returns:
111
- model params.
112
- """
113
- # IMPORTANT!!!
114
- # Azure models are the same as openai models. So we want to unify their
115
- # caches. Make sure we return the OpenAI name here.
116
- return {"model_name": OpenAIChatClient.NAME, "engine": getattr(self, "engine")}
duckdb-nsql/manifest/manifest/clients/client.py DELETED
@@ -1,699 +0,0 @@
1
- """Client class."""
2
- import asyncio
3
- import copy
4
- import json
5
- import logging
6
- import math
7
- from abc import ABC, abstractmethod
8
- from typing import Any, Dict, Generator, List, Optional, Tuple, Union, cast
9
-
10
- import aiohttp
11
- import requests
12
- import tqdm.asyncio
13
- from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
14
-
15
- from manifest.request import (
16
- DEFAULT_REQUEST_KEYS,
17
- NOT_CACHE_KEYS,
18
- LMChatRequest,
19
- LMRequest,
20
- LMScoreRequest,
21
- Request,
22
- )
23
- from manifest.response import (
24
- RESPONSE_CONSTRUCTORS,
25
- ArrayModelChoice,
26
- LMModelChoice,
27
- ModelChoices,
28
- Response,
29
- Usage,
30
- Usages,
31
- )
32
-
33
- logger = logging.getLogger(__name__)
34
-
35
- ATTEMPTS_BEFORE_STOP = 4
36
- ATTEMPTS_TIMEOUT = 30
37
- # http_status mainly for azure and e.code mainly for openai usage
38
- # e.http_status == 408 occurs when Azure times out
39
- # e.code == 429 occurs when rate limited
40
- # e.code == 500 or 502 occurs when server error
41
- API_ERROR_CODE = {408, 429, 500, 502, 520, 524}
42
-
43
-
44
- def retry_if_ratelimit(retry_base: RetryCallState) -> bool:
45
- """Return whether to retry if ratelimited."""
46
- try:
47
- if isinstance(retry_base.outcome.exception(), requests.exceptions.HTTPError):
48
- exception = cast(
49
- requests.exceptions.HTTPError, retry_base.outcome.exception()
50
- )
51
- # 500 is a server error, 429 is a rate limit error
52
- if exception.response.status_code in API_ERROR_CODE: # type: ignore
53
- return True
54
- except Exception:
55
- pass
56
- return True
57
-
58
-
59
- def return_error_response(retry_state: RetryCallState) -> dict:
60
- """Return error response if all retries failed."""
61
- request_params = retry_state.args[1]
62
- number_of_prompts = (
63
- len(request_params["prompt"])
64
- if "prompt" in request_params
65
- else len(request_params["messages"])
66
- )
67
- return {
68
- "choices": [],
69
- "usage": {
70
- "total_tokens": 0,
71
- "prompt_tokens": 0,
72
- "completion_tokens": 0,
73
- },
74
- "errors": [str(retry_state.outcome.exception())] * number_of_prompts,
75
- }
76
-
77
-
78
- class Client(ABC):
79
- """Client class."""
80
-
81
- # Must be overridden by child class
82
- PARAMS: Dict[str, Tuple[str, Any]] = {}
83
- REQUEST_CLS = Request
84
- NAME: str = None
85
- IS_CHAT: bool = False
86
-
87
- def __init__(
88
- self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}
89
- ):
90
- """
91
- Initialize client.
92
-
93
- kwargs are passed to client as default parameters.
94
-
95
- For clients like OpenAI that do not require a connection,
96
- the connection_str can be None.
97
-
98
- Args:
99
- connection_str: connection string for client.
100
- client_args: client arguments.
101
- """
102
- self.connect(connection_str, client_args)
103
-
104
- @abstractmethod
105
- def connect(
106
- self, connection_str: Optional[str], client_args: Dict[str, Any]
107
- ) -> None:
108
- """
109
- Connect to client.
110
-
111
- Override in child client class.
112
- Args:
113
- connection_str: connection string.
114
- """
115
- raise NotImplementedError()
116
-
117
- @abstractmethod
118
- def close(self) -> None:
119
- """Close the client.
120
-
121
- Override in child client class.
122
- """
123
- raise NotImplementedError()
124
-
125
- @abstractmethod
126
- def get_generation_url(self) -> str:
127
- """Get generation URL.
128
-
129
- Override in child client class.
130
- """
131
- raise NotImplementedError()
132
-
133
- @abstractmethod
134
- def get_generation_header(self) -> Dict[str, str]:
135
- """
136
- Get generation header.
137
-
138
- Override in child client class.
139
- Returns:
140
- header.
141
- """
142
- raise NotImplementedError()
143
-
144
- @abstractmethod
145
- def supports_batch_inference(self) -> bool:
146
- """Return whether the client supports batch inference.
147
-
148
- Override in child client class.
149
- """
150
- raise NotImplementedError()
151
-
152
- @abstractmethod
153
- def supports_streaming_inference(self) -> bool:
154
- """Return whether the client supports streaming inference.
155
-
156
- Override in child client class.
157
- """
158
- raise NotImplementedError()
159
-
160
- @abstractmethod
161
- def get_model_params(self) -> Dict:
162
- """
163
- Get model params.
164
-
165
- By getting model params from the server, we can add to request
166
- and make sure cache keys are unique to model.
167
-
168
- Override in child client class.
169
- Returns:
170
- model params.
171
- """
172
- raise NotImplementedError()
173
-
174
- def get_tokenizer(self, model: str) -> Tuple[Any, int]:
175
- """Get tokenizer for model.
176
-
177
- Override in child client class. Return None, -1 if not supported
178
- or no prompt truncation required.
179
- Returns:
180
- tokenizer: tokenizer with encoder and decode
181
- max_length: max length of model
182
- """
183
- return None, -1
184
-
185
- def get_model_inputs(self) -> List:
186
- """
187
- Get allowable model inputs.
188
-
189
- Returns:
190
- model inputs.
191
- """
192
- return list(self.PARAMS.keys())
193
-
194
- def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
195
- """Split usage into list of usages for each prompt."""
196
- # TODO: add this in using default tokenizer
197
- return []
198
-
199
- def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
200
- """
201
- Preprocess request params.
202
-
203
- Override in child client class to reformat requests to model.
204
-
205
- Args:
206
- request: request params.
207
-
208
- Returns:
209
- request params.
210
- """
211
- return request
212
-
213
- def postprocess_response(
214
- self, response: Dict[str, Any], request: Dict[str, Any]
215
- ) -> Dict[str, Any]:
216
- """
217
- Postprocess and validate response as dict.
218
-
219
- Override in child client class to reform model responses.
220
-
221
- Args:
222
- response: response
223
- request: request
224
-
225
- Return:
226
- response as dict
227
- """
228
- if "choices" not in response:
229
- raise ValueError(f"Invalid response: {response}")
230
- if "usage" in response:
231
- # Handle splitting the usages for batch requests
232
- if len(response["choices"]) == 1:
233
- if isinstance(response["usage"], list):
234
- response["usage"] = response["usage"][0]
235
- response["usage"] = [response["usage"]]
236
- else:
237
- # Try to split usage
238
- split_usage = self.split_usage(request, response["choices"])
239
- if split_usage:
240
- response["usage"] = split_usage
241
- return response
242
-
243
- def get_request(
244
- self, prompt: Union[str, List[str]], request_args: Dict[str, Any]
245
- ) -> Request:
246
- """
247
- Parse model kwargs to request.
248
-
249
- Args:
250
- prompt: prompt.
251
- request_args: request arguments.
252
-
253
- Returns:
254
- request.
255
- """
256
- params = {"prompt": prompt}
257
- # Adds default values from self.PARAMS if not in request_args
258
- for key in self.PARAMS:
259
- params[key] = request_args.pop(key, getattr(self, key))
260
- # Allows for overriding DEFAULT_REQUEST_KEYS even if they are not
261
- # in self.PARAMS. Note that DEFAULT_REQUEST_KEYS match the default
262
- # values in Request.
263
- for key in DEFAULT_REQUEST_KEYS:
264
- if key not in params and key in request_args:
265
- params[key] = request_args.pop(key)
266
- return self.REQUEST_CLS(**params) # type: ignore
267
-
268
- def _get_request_params(self, request: Request) -> Dict[str, Any]:
269
- """Get request params.
270
-
271
- Add default keys that we need for requests such as batch_size.
272
- We drop these before sending to the model.
273
- """
274
- params_to_add = DEFAULT_REQUEST_KEYS.copy()
275
- # This will override DEFAULT_REQUEST_KEYS with those in PARAMS
276
- params_to_add.update(self.PARAMS)
277
- # to_dict will handle parameter renaming but not any
278
- # default value handling - that is done in get_request()
279
- request_params = request.to_dict(params_to_add)
280
- return request_params
281
-
282
- def get_cache_key(self, request: Request) -> Dict[str, Any]:
283
- """Get cache key for request.
284
-
285
- Skip keys that are not cache keys such as batch_size.
286
- """
287
- request_params = self._get_request_params(request)
288
- for key in NOT_CACHE_KEYS:
289
- request_params.pop(key, None)
290
- # Make sure to add model params and request class
291
- request_params.update(self.get_model_params())
292
- request_params["request_cls"] = request.__class__.__name__
293
- return request_params
294
-
295
- def _split_requests(
296
- self, request_params: Dict[str, Any], batch_size: int, key: str = "prompt"
297
- ) -> List[Dict[str, Any]]:
298
- """Split request into batch_sized request.
299
-
300
- Args:
301
- request_params: request params.
302
- batch_size: batch size for requests.
303
- key: key to batch over
304
-
305
- Returns:
306
- list of request params.
307
- """
308
- data = copy.deepcopy(request_params[key])
309
- data_size = len(request_params[key])
310
- request_params_list = []
311
- for i in range(0, data_size, batch_size):
312
- params = copy.deepcopy(request_params)
313
- params[key] = data[i] if batch_size == 1 else data[i : i + batch_size]
314
- request_params_list.append(params)
315
- return request_params_list
316
-
317
- def _get_model_choices(self, response: Dict) -> ModelChoices:
318
- """Format response to ModelChoices."""
319
- # Array or text response
320
- response_type = RESPONSE_CONSTRUCTORS[self.REQUEST_CLS]["response_type"]
321
- if response_type == "array":
322
- choices: List[Union[LMModelChoice, ArrayModelChoice]] = [
323
- ArrayModelChoice(**choice) for choice in response["choices"]
324
- ]
325
- else:
326
- choices = [LMModelChoice(**choice) for choice in response["choices"]]
327
- return ModelChoices(choices=choices)
328
-
329
- def _stitch_responses(self, request: Request, responses: List[Dict]) -> Response:
330
- """Stitch responses together.
331
-
332
- Useful for batch requests.
333
- """
334
- choices = []
335
- usages = []
336
- for res_dict in responses:
337
- choices.extend(res_dict["choices"])
338
- if "usage" in res_dict:
339
- usages.extend(res_dict["usage"])
340
- final_response_dict = {"choices": choices}
341
- final_usages = None
342
- if usages:
343
- final_usages = Usages(usages=[Usage(**usage) for usage in usages])
344
- # TODO: Add usage based on tokenizer
345
- return Response(
346
- self._get_model_choices(final_response_dict),
347
- cached=False,
348
- request=request,
349
- usages=final_usages,
350
- **RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
351
- )
352
-
353
- def _verify_request_lengths(
354
- self, request: Dict[str, Any], model: str, max_tokens: int
355
- ) -> None:
356
- """Verify that the request length is not too long."""
357
- encoder, max_length = self.get_tokenizer(model)
358
- if not encoder or max_length < 0:
359
- return
360
- if isinstance(request["prompt"], str):
361
- prompts = [request["prompt"]]
362
- else:
363
- prompts = request["prompt"]
364
- for i in range(len(prompts)):
365
- prompt = prompts[i]
366
- encoded_prompt = encoder.encode(prompt)
367
- if len(encoded_prompt) + max_tokens > max_length:
368
- logger.warning(
369
- f"Prompt {prompt} is too long for model {model}. "
370
- "Truncating prompt from left."
371
- )
372
- # -20 to be safe
373
- prompt = encoder.decode(
374
- encoded_prompt[-int(max_length - max_tokens - 20) :]
375
- )
376
- prompts[i] = prompt
377
- if isinstance(request["prompt"], str):
378
- request["prompt"] = prompts[0]
379
- else:
380
- request["prompt"] = prompts
381
-
382
- @retry(
383
- reraise=True,
384
- wait=wait_random_exponential(min=1, max=ATTEMPTS_TIMEOUT),
385
- stop=stop_after_attempt(ATTEMPTS_BEFORE_STOP),
386
- )
387
- def _run_completion(
388
- self, request_params: Dict[str, Any], retry_timeout: int
389
- ) -> Dict:
390
- """Execute completion request.
391
-
392
- Args:
393
- request_params: request params.
394
- retry_timeout: retry timeout.
395
-
396
- Returns:
397
- response as dict.
398
- """
399
- request_params = self.preprocess_request_params(request_params)
400
- print(request_params)
401
- post_str = self.get_generation_url()
402
- res = requests.post(
403
- post_str,
404
- headers=self.get_generation_header(),
405
- json=request_params,
406
- timeout=retry_timeout,
407
- )
408
- try:
409
- res.raise_for_status()
410
- except requests.exceptions.HTTPError as e:
411
- logger.warning(
412
- str(e)
413
- )
414
- raise Exception()
415
- return self.postprocess_response(res.json(), request_params)
416
-
417
- @retry(
418
- reraise=True,
419
- retry=retry_if_ratelimit,
420
- wait=wait_random_exponential(min=1, max=ATTEMPTS_TIMEOUT),
421
- stop=stop_after_attempt(ATTEMPTS_BEFORE_STOP),
422
- )
423
- async def _arun_completion(
424
- self, request_params: Dict[str, Any], retry_timeout: int
425
- ) -> Dict:
426
- """Async execute completion request.
427
-
428
- Args:
429
- request_params: request params.
430
- retry_timeout: retry timeout.
431
-
432
- Returns:
433
- response as dict.
434
- """
435
- request_params = self.preprocess_request_params(request_params)
436
- post_str = self.get_generation_url()
437
- async with aiohttp.ClientSession(timeout=retry_timeout) as session:
438
- async with session.post(
439
- post_str,
440
- headers=self.get_generation_header(),
441
- json=request_params,
442
- timeout=retry_timeout,
443
- ) as res:
444
- res.raise_for_status()
445
- res_json = await res.json(content_type=None)
446
- return self.postprocess_response(res_json, request_params)
447
-
448
- @retry(
449
- reraise=True,
450
- retry=retry_if_ratelimit,
451
- wait=wait_random_exponential(min=1, max=ATTEMPTS_TIMEOUT),
452
- stop=stop_after_attempt(ATTEMPTS_BEFORE_STOP),
453
- )
454
- def _run_streaming_completion(
455
- self, request_params: Dict[str, Any], retry_timeout: int
456
- ) -> Generator[Dict, None, None]:
457
- """Execute completion request streaming.
458
-
459
- Args:
460
- request_params: request params.
461
- retry_timeout: retry timeout.
462
-
463
- Returns:
464
- response as dict.
465
- """
466
- request_params = self.preprocess_request_params(request_params)
467
- request_params["stream"] = True
468
- post_str = self.get_generation_url()
469
- res_iter = requests.post(
470
- post_str,
471
- headers=self.get_generation_header(),
472
- json=request_params,
473
- timeout=retry_timeout,
474
- stream=True,
475
- )
476
- for res_token in res_iter.iter_lines():
477
- if res_token:
478
- decoded_res_token = res_token.decode("utf-8")
479
- decoded_res_token = decoded_res_token.replace("data: ", "")
480
- if decoded_res_token == "[DONE]":
481
- break
482
- try:
483
- decoded_res_token_dct = json.loads(decoded_res_token)
484
- postprocess_res_token_dct = self.postprocess_response(
485
- decoded_res_token_dct, request_params
486
- )
487
- # If nothing is returned, skip
488
- if (
489
- not postprocess_res_token_dct
490
- or not postprocess_res_token_dct["choices"]
491
- ):
492
- continue
493
- yield postprocess_res_token_dct
494
- except Exception as e:
495
- raise e
496
-
497
- def run_request(self, request: Request) -> Response:
498
- """
499
- Run request.
500
-
501
- Args:
502
- request: request.
503
-
504
- Returns:
505
- response.
506
- """
507
- # Make everything list for consistency
508
- if isinstance(request.prompt, list):
509
- prompt_list = request.prompt
510
- else:
511
- prompt_list = [request.prompt]
512
-
513
- request_params = self._get_request_params(request)
514
- # Set the params as a list. Do not set the request
515
- # object itself as the cache will then store it as a
516
- # list which is inconsistent with the request input.
517
- request_params["prompt"] = prompt_list
518
-
519
- # If batch_size is not set, set it to 1
520
- batch_size = request_params.pop("batch_size") or 1
521
- if not self.supports_batch_inference() and batch_size != 1:
522
- logger.warning(
523
- f"{self.__class__.__name__} does not support batch inference."
524
- " Setting batch size to 1"
525
- )
526
- batch_size = 1
527
-
528
- # Take the default keys we need and drop the rest as they
529
- # are not part of the model request.
530
- retry_timeout = request_params.pop("client_timeout")
531
- for key in DEFAULT_REQUEST_KEYS:
532
- request_params.pop(key, None)
533
-
534
- # Make sure requests are in the request length
535
- # If no tokenizer is set or not LM request, this
536
- # will do nothing
537
- if isinstance(request, LMRequest):
538
- self._verify_request_lengths(
539
- request_params, model=request.engine, max_tokens=request.max_tokens
540
- )
541
-
542
- # Batch requests
543
- num_batches = len(prompt_list) // batch_size
544
- if len(prompt_list) % batch_size != 0:
545
- batch_size = int(math.ceil(len(prompt_list) / (num_batches + 1)))
546
- request_batches = self._split_requests(request_params, batch_size)
547
-
548
- response_dicts = [
549
- self._run_completion(batch, retry_timeout) for batch in request_batches
550
- ]
551
- # Flatten responses
552
- return self._stitch_responses(request, response_dicts)
553
-
554
- async def arun_batch_request(
555
- self, request: Request, verbose: bool = False
556
- ) -> Response:
557
- """
558
- Run async request.
559
-
560
- Args:
561
- request: request.
562
-
563
- Returns:
564
- response.
565
- """
566
- required_batch_size = None
567
- if not self.supports_batch_inference():
568
- required_batch_size = 1
569
- if not isinstance(request.prompt, list):
570
- raise AssertionError(
571
- "request.prompt must be a list for async batch inference."
572
- )
573
-
574
- request_params = self._get_request_params(request)
575
- # Take the default keys we need and drop the rest as they
576
- # are not part of the model request.
577
- retry_timeout = request_params.pop("client_timeout")
578
- batch_size = request_params.pop("batch_size")
579
- batch_size = required_batch_size or batch_size
580
- for key in DEFAULT_REQUEST_KEYS:
581
- request_params.pop(key, None)
582
-
583
- # Make sure requests are in the request length
584
- # If no tokenizer is set or not LM request, this
585
- # will do nothing
586
- if isinstance(request, LMRequest):
587
- self._verify_request_lengths(
588
- request_params, model=request.engine, max_tokens=request.max_tokens
589
- )
590
-
591
- # Batch requests
592
- num_batches = len(request.prompt) // batch_size
593
- if len(request.prompt) % batch_size != 0:
594
- batch_size = int(math.ceil(len(request.prompt) / (num_batches + 1)))
595
-
596
- request_batches = self._split_requests(request_params, batch_size)
597
- all_tasks = [
598
- asyncio.create_task(self._arun_completion(batch, retry_timeout))
599
- for batch in request_batches
600
- ]
601
- responses = await tqdm.asyncio.tqdm.gather(*all_tasks, disable=not verbose)
602
- # Flatten responses
603
- return self._stitch_responses(request, responses)
604
-
605
- def run_chat_request(
606
- self,
607
- request: LMChatRequest,
608
- ) -> Response:
609
- """
610
- Get the response from chat model.
611
-
612
- Args:
613
- request: request.
614
-
615
- Returns:
616
- response.
617
- """
618
- request_params = self._get_request_params(request)
619
- # Take the default keys we need and drop the rest as they
620
- # are not part of the model request.
621
- retry_timeout = request_params.pop("client_timeout")
622
- for key in DEFAULT_REQUEST_KEYS:
623
- request_params.pop(key, None)
624
-
625
- # Make sure requests are in the request length
626
- # If no tokenizer is set or not LM request, this
627
- # will do nothing
628
- self._verify_request_lengths(
629
- request_params, model=request.engine, max_tokens=request.max_tokens
630
- )
631
-
632
- response_dict = self._run_completion(request_params, retry_timeout)
633
- usages = None
634
- if "usage" in response_dict:
635
- usages = [Usage(**usage) for usage in response_dict["usage"]]
636
-
637
- return Response(
638
- response=self._get_model_choices(response_dict),
639
- cached=False,
640
- request=request,
641
- usages=Usages(usages=usages) if usages else None,
642
- **RESPONSE_CONSTRUCTORS[LMChatRequest], # type: ignore
643
- )
644
-
645
- def run_streaming_request(
646
- self, request: Request
647
- ) -> Generator[Response, None, None]:
648
- """
649
- Run streaming request.
650
-
651
- Args:
652
- request: request.
653
-
654
- Returns:
655
- response.
656
- """
657
- if not isinstance(request.prompt, str):
658
- raise ValueError("Streaming requests must have a single prompt.")
659
- if not self.supports_streaming_inference():
660
- raise ValueError(
661
- f"{self.__class__.__name__} does not support streaming inference."
662
- )
663
- request_params = self._get_request_params(request)
664
-
665
- # Take the default keys we need and drop the rest as they
666
- # are not part of the model request.
667
- retry_timeout = request_params.pop("client_timeout")
668
- for key in DEFAULT_REQUEST_KEYS:
669
- request_params.pop(key, None)
670
-
671
- # Make sure requests are in the request length
672
- # If no tokenizer is set or not LM request, this
673
- # will do nothing
674
- if isinstance(request, LMRequest):
675
- self._verify_request_lengths(
676
- request_params, model=request.engine, max_tokens=request.max_tokens
677
- )
678
-
679
- for token_response in self._run_streaming_completion(
680
- request_params, retry_timeout
681
- ):
682
- yield self._stitch_responses(request, [token_response])
683
-
684
- def run_score_prompt_request(
685
- self,
686
- request: LMScoreRequest,
687
- ) -> Response:
688
- """
689
- Get the logit score of the prompt via a forward pass of the model.
690
-
691
- Args:
692
- request: request.
693
-
694
- Returns:
695
- response.
696
- """
697
- raise NotImplementedError(
698
- f"{self.__class__.__name__} does not support prompt scoring request."
699
- )
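For orientation while reading the remaining deleted clients below: the abstract Client above only asks a backend to say where its completion endpoint lives, which headers it needs, and which user parameters it accepts; batching, retries, prompt truncation, cache keys and response stitching all live in the base class. A minimal sketch of a conforming subclass (the endpoint, class name and defaults here are illustrative assumptions, not code from this repository):

from typing import Any, Dict, Optional

from manifest.clients.client import Client
from manifest.request import LMRequest


class ExampleHTTPClient(Client):
    """Hypothetical client for a plain /completions HTTP endpoint."""

    # User param -> (client param, default value), as in the other clients.
    PARAMS = {
        "temperature": ("temperature", 0.7),
        "max_tokens": ("max_tokens", 16),
    }
    REQUEST_CLS = LMRequest
    NAME = "example"

    def connect(
        self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}
    ) -> None:
        # e.g. "http://localhost:8000"; fall back to a local default.
        self.host = (connection_str or "http://localhost:8000").rstrip("/")
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))

    def close(self) -> None:
        pass

    def get_generation_url(self) -> str:
        return self.host + "/completions"

    def get_generation_header(self) -> Dict[str, str]:
        return {}

    def supports_batch_inference(self) -> bool:
        return False

    def supports_streaming_inference(self) -> bool:
        return False

    def get_model_params(self) -> Dict:
        return {"model_name": self.NAME}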
 
 
 
 
duckdb-nsql/manifest/manifest/clients/cohere.py DELETED
@@ -1,125 +0,0 @@
1
- """Cohere client."""
2
-
3
- import logging
4
- import os
5
- from typing import Any, Dict, Optional
6
-
7
- from manifest.clients.client import Client
8
- from manifest.request import LMRequest
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
- COHERE_MODELS = {"small", "medium", "large", "xlarge"}
13
-
14
-
15
- class CohereClient(Client):
16
- """Cohere client."""
17
-
18
- # Params are defined in https://docs.cohere.ai/generate-reference
19
- PARAMS = {
20
- "engine": ("model", "xlarge"),
21
- "max_tokens": ("max_tokens", 20),
22
- "temperature": ("temperature", 0.75),
23
- "n": ("num_generations", 1),
24
- "top_k": ("k", 0),
25
- "top_p": ("p", 0.75),
26
- "frequency_penalty": ("frequency_penalty", 0.0),
27
- "presence_penalty": ("presence_penalty", 0.0),
28
- "stop_sequences": ("stop_sequences", None),
29
- }
30
- REQUEST_CLS = LMRequest
31
- NAME = "cohere"
32
-
33
- def connect(
34
- self,
35
- connection_str: Optional[str] = None,
36
- client_args: Dict[str, Any] = {},
37
- ) -> None:
38
- """
39
- Connect to the Cohere server.
40
-
41
- connection_str is passed as default COHERE_API_KEY if variable not set.
42
-
43
- Args:
44
- connection_str: connection string.
45
- client_args: client arguments.
46
- """
47
- self.api_key = connection_str or os.environ.get("COHERE_API_KEY")
48
- if self.api_key is None:
49
- raise ValueError(
50
- "Cohere API key not set. Set COHERE_API_KEY environment "
51
- "variable or pass through `client_connection`."
52
- )
53
- self.host = "https://api.cohere.ai"
54
- for key in self.PARAMS:
55
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
56
- if getattr(self, "engine") not in COHERE_MODELS:
57
- raise ValueError(
58
- f"Invalid engine {getattr(self, 'engine')}. Must be {COHERE_MODELS}."
59
- )
60
-
61
- def close(self) -> None:
62
- """Close the client."""
63
-
64
- def get_generation_url(self) -> str:
65
- """Get generation URL."""
66
- return self.host + "/generate"
67
-
68
- def get_generation_header(self) -> Dict[str, str]:
69
- """
70
- Get generation header.
71
-
72
- Returns:
73
- header.
74
- """
75
- return {
76
- "Cohere-Version": "2021-11-08",
77
- "Authorization": f"Bearer {self.api_key}",
78
- }
79
-
80
- def supports_batch_inference(self) -> bool:
81
- """Return whether the client supports batch inference."""
82
- return False
83
-
84
- def supports_streaming_inference(self) -> bool:
85
- """Return whether the client supports streaming inference.
86
-
87
- Override in child client class.
88
- """
89
- return False
90
-
91
- def get_model_params(self) -> Dict:
92
- """
93
- Get model params.
94
-
95
- By getting model params from the server, we can add to request
96
- and make sure cache keys are unique to model.
97
-
98
- Returns:
99
- model params.
100
- """
101
- return {"model_name": self.NAME, "engine": getattr(self, "engine")}
102
-
103
- def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
104
- """
105
- Format response to dict.
106
-
107
- Args:
108
- response: response
109
- request: request
110
-
111
- Return:
112
- response as dict
113
- """
114
- return {
115
- "object": "text_completion",
116
- "model": getattr(self, "engine"),
117
- "choices": [
118
- {
119
- "text": item["text"],
120
- "text_logprob": item.get("likelihood", None),
121
- "token_logprobs": item.get("token_likelihoods", None),
122
- }
123
- for item in response["generations"]
124
- ],
125
- }
 
 
 
 
duckdb-nsql/manifest/manifest/clients/diffuser.py DELETED
@@ -1,112 +0,0 @@
1
- """Diffuser client."""
2
- import logging
3
- from functools import lru_cache
4
- from typing import Any, Dict, Optional
5
-
6
- import numpy as np
7
- import requests
8
-
9
- from manifest.clients.client import Client
10
- from manifest.request import DiffusionRequest
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
-
15
- class DiffuserClient(Client):
16
- """Diffuser client."""
17
-
18
- # User param -> (client param, default value)
19
- PARAMS = {
20
- "num_inference_steps": ("num_inference_steps", 50),
21
- "height": ("height", 512),
22
- "width": ("width", 512),
23
- "n": ("num_images_per_prompt", 1),
24
- "guidance_scale": ("guidance_scale", 7.5),
25
- "eta": ("eta", 0.0),
26
- }
27
- REQUEST_CLS = DiffusionRequest
28
- NAME = "diffuser"
29
-
30
- def connect(
31
- self,
32
- connection_str: Optional[str] = None,
33
- client_args: Dict[str, Any] = {},
34
- ) -> None:
35
- """
36
- Connect to the Diffuser url.
37
-
38
- Args:
39
- connection_str: connection string.
40
- client_args: client arguments.
41
- """
42
- self.host = connection_str.rstrip("/")
43
- for key in self.PARAMS:
44
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
45
- self.model_params = self.get_model_params()
46
-
47
- def to_numpy(self, image: np.ndarray) -> np.ndarray:
48
- """Convert a numpy image to a PIL image.
49
-
50
- Adapted from https://github.com/huggingface/diffusers/blob/src/diffusers/pipelines/pipeline_utils.py#L808 # noqa: E501
51
- """
52
- image = (image * 255).round().astype("uint8")
53
- return image
54
-
55
- def close(self) -> None:
56
- """Close the client."""
57
- pass
58
-
59
- def get_generation_url(self) -> str:
60
- """Get generation URL."""
61
- return self.host + "/completions"
62
-
63
- def get_generation_header(self) -> Dict[str, str]:
64
- """
65
- Get generation header.
66
-
67
- Returns:
68
- header.
69
- """
70
- return {}
71
-
72
- def supports_batch_inference(self) -> bool:
73
- """Return whether the client supports batch inference."""
74
- return True
75
-
76
- def supports_streaming_inference(self) -> bool:
77
- """Return whether the client supports streaming inference.
78
-
79
- Override in child client class.
80
- """
81
- return False
82
-
83
- @lru_cache(maxsize=1)
84
- def get_model_params(self) -> Dict:
85
- """
86
- Get model params.
87
-
88
- By getting model params from the server, we can add to request
89
- and make sure cache keys are unique to model.
90
-
91
- Returns:
92
- model params.
93
- """
94
- res = requests.post(self.host + "/params").json()
95
- res["client_name"] = self.NAME
96
- return res
97
-
98
- def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
99
- """
100
- Format response to dict.
101
-
102
- Args:
103
- response: response
104
- request: request
105
-
106
- Return:
107
- response as dict
108
- """
109
- # Convert array to np.array
110
- for choice in response["choices"]:
111
- choice["array"] = self.to_numpy(np.array(choice["array"]))
112
- return response
 
 
 
 
duckdb-nsql/manifest/manifest/clients/dummy.py DELETED
@@ -1,251 +0,0 @@
1
- """Dummy client."""
2
- import hashlib
3
- import logging
4
- from typing import Any, Dict, List, Optional, Tuple
5
-
6
- import numpy as np
7
- import tiktoken
8
-
9
- from manifest.clients.client import Client
10
- from manifest.request import LMChatRequest, LMRequest, LMScoreRequest, Request
11
- from manifest.response import LMModelChoice, ModelChoices, Response, Usage, Usages
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
-
16
- class DummyClient(Client):
17
- """Dummy client."""
18
-
19
- # User param -> (client param, default value)
20
- PARAMS = {
21
- "engine": ("model", "text-davinci-003"),
22
- "temperature": ("temperature", 0.0),
23
- "max_tokens": ("max_tokens", 10),
24
- "n": ("n", 1),
25
- "top_p": ("top_p", 1.0),
26
- "top_k": ("best_of", 1),
27
- "batch_size": ("batch_size", 20),
28
- }
29
- REQUEST_CLS = LMRequest
30
- NAME = "dummy"
31
-
32
- def connect(
33
- self,
34
- connection_str: Optional[str] = None,
35
- client_args: Dict[str, Any] = {},
36
- ) -> None:
37
- """
38
- Connect to the dummy server.
39
-
40
- This is a dummy client that returns randomly generated mock responses. Used for testing.
41
-
42
- Args:
43
- connection_str: connection string.
44
- client_args: client arguments.
45
- """
46
- # We use tiktoken as it is faster than HF for tokenizing
47
- # Use any model to create the tokenizer
48
- self.encoder = tiktoken.get_encoding("cl100k_base")
49
- for key in self.PARAMS:
50
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
51
-
52
- def close(self) -> None:
53
- """Close the client."""
54
- pass
55
-
56
- def get_generation_url(self) -> str:
57
- """Get generation URL."""
58
- return "dummy"
59
-
60
- def supports_batch_inference(self) -> bool:
61
- """Return whether the client supports batch inference."""
62
- return True
63
-
64
- def supports_streaming_inference(self) -> bool:
65
- """Return whether the client supports streaming inference.
66
-
67
- Override in child client class.
68
- """
69
- return False
70
-
71
- def get_generation_header(self) -> Dict[str, str]:
72
- """
73
- Get generation header.
74
-
75
- Returns:
76
- header.
77
- """
78
- return {}
79
-
80
- def get_model_params(self) -> Dict:
81
- """
82
- Get model params.
83
-
84
- By getting model params from the server, we can add to request
85
- and make sure cache keys are unique to model.
86
-
87
- Returns:
88
- model params.
89
- """
90
- return {"engine": "dummy", "model": getattr(self, "engine")}
91
-
92
- def get_mock_output(
93
- self, output_toks: int, is_completion: bool, seed: Optional[int] = None
94
- ) -> LMModelChoice:
95
- """Return mock model output by generating random tokens."""
96
- np.random.seed(seed)
97
- random_tokens = np.random.randint(
98
- 0, self.encoder.max_token_value + 1, output_toks
99
- )
100
- response = self.encoder.decode(random_tokens) # type: ignore
101
- if is_completion:
102
- np.random.seed(seed)
103
- random_logprobs = np.random.uniform(
104
- low=-2, high=-0.00001, size=output_toks
105
- ).tolist()
106
- else:
107
- # Return all Nones to mimic chat models
108
- # OpenAI chat models do not return logprobs
109
- random_logprobs = [None] * output_toks
110
- return LMModelChoice(
111
- text=response,
112
- token_logprobs=random_logprobs,
113
- tokens=random_tokens.tolist(),
114
- )
115
-
116
- def get_mock_choices(
117
- self,
118
- prompt_list: List[str],
119
- request_params: Dict,
120
- is_completion: bool,
121
- ) -> Tuple[List[LMModelChoice], List[Usage]]:
122
- """Get choices and usages of mock output."""
123
- choices = []
124
- usages = []
125
- for prompt in prompt_list:
126
- num_prompt_tokens = len(self.encoder.encode(prompt))
127
- if request_params["temperature"] == 0:
128
- # Get integer seed from hash of prompt
129
- seed = (
130
- int(hashlib.sha256(prompt.encode("utf-8")).hexdigest(), 16)
131
- % 10**8
132
- )
133
- else:
134
- # Get random seed
135
- seed = None
136
- for _ in range(int(request_params["n"])):
137
- choice = self.get_mock_output(
138
- request_params["max_tokens"], is_completion=is_completion, seed=seed
139
- )
140
- choices.append(choice)
141
- usages.append(
142
- Usage(
143
- prompt_tokens=num_prompt_tokens,
144
- completion_tokens=request_params["max_tokens"],
145
- total_tokens=num_prompt_tokens + request_params["max_tokens"],
146
- )
147
- )
148
- return choices, usages
149
-
150
- def run_request(self, request: Request) -> Response:
151
- """
152
- Get request string function.
153
-
154
- Args:
155
- request: request.
156
-
157
- Returns:
158
- request function that takes no input.
159
- request parameters as dict.
160
- """
161
- if isinstance(request.prompt, list):
162
- prompt_list = request.prompt
163
- else:
164
- prompt_list = [request.prompt]
165
- request_params = request.to_dict(self.PARAMS)
166
-
167
- choices, usages = self.get_mock_choices(
168
- prompt_list, request_params, is_completion=True
169
- )
170
- return Response(
171
- response=ModelChoices(choices=choices), # type: ignore
172
- cached=False,
173
- request=request,
174
- usages=Usages(usages=usages),
175
- response_type="text",
176
- request_type=self.REQUEST_CLS,
177
- )
178
-
179
- async def arun_batch_request(
180
- self, request: Request, verbose: bool = False
181
- ) -> Response:
182
- """
183
- Get async request string function.
184
-
185
- Args:
186
- request: request.
187
-
188
- Returns:
189
- response.
190
- """
191
- return self.run_request(request)
192
-
193
- def run_chat_request(
194
- self,
195
- request: LMChatRequest,
196
- ) -> Response:
197
- """
198
- Get the response from chat model.
199
-
200
- Args:
201
- request: request.
202
-
203
- Returns:
204
- response.
205
- """
206
- prompt_list = ["_".join(pmp["content"] for pmp in request.prompt)]
207
- request_params = request.to_dict(self.PARAMS)
208
-
209
- choices, usages = self.get_mock_choices(
210
- prompt_list, request_params, is_completion=False
211
- )
212
- return Response(
213
- response=ModelChoices(choices=choices), # type: ignore
214
- cached=False,
215
- request=request,
216
- usages=Usages(usages=usages),
217
- response_type="text",
218
- request_type=LMChatRequest,
219
- )
220
-
221
- def run_score_prompt_request(
222
- self,
223
- request: LMScoreRequest,
224
- ) -> Response:
225
- """
226
- Get the logit score of the prompt via a forward pass of the model.
227
-
228
- Args:
229
- request: request.
230
-
231
- Returns:
232
- request function that takes no input.
233
- request parameters as dict.
234
- """
235
- if isinstance(request.prompt, list):
236
- prompt_list = request.prompt
237
- else:
238
- prompt_list = [request.prompt]
239
- request_params = request.to_dict(self.PARAMS)
240
-
241
- choices, usages = self.get_mock_choices(
242
- prompt_list, request_params, is_completion=True
243
- )
244
- return Response(
245
- response=ModelChoices(choices=choices), # type: ignore
246
- cached=False,
247
- request=request,
248
- usages=Usages(usages=usages),
249
- response_type="text",
250
- request_type=LMScoreRequest,
251
- )
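One detail worth noting in DummyClient above: when temperature is 0 it derives the random seed from a hash of the prompt, so identical prompts always produce identical mock completions. A standalone sketch of that trick (the encoding name matches the code above; the helper itself is illustrative):

import hashlib

import numpy as np
import tiktoken

encoder = tiktoken.get_encoding("cl100k_base")


def mock_completion(prompt: str, output_toks: int = 10) -> str:
    # Same prompt -> same seed -> same mock tokens.
    seed = int(hashlib.sha256(prompt.encode("utf-8")).hexdigest(), 16) % 10**8
    np.random.seed(seed)
    random_tokens = np.random.randint(0, encoder.max_token_value + 1, output_toks)
    return encoder.decode(random_tokens.tolist())


assert mock_completion("SELECT 1") == mock_completion("SELECT 1")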
 
 
 
duckdb-nsql/manifest/manifest/clients/google.py DELETED
@@ -1,197 +0,0 @@
1
- """Google client."""
2
- import logging
3
- import os
4
- import subprocess
5
- from typing import Any, Dict, Optional, Type
6
-
7
- from manifest.clients.client import Client
8
- from manifest.request import LMRequest, Request
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
- # https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
13
- GOOGLE_ENGINES = {
14
- "text-bison",
15
- }
16
-
17
-
18
- def get_project_id() -> Optional[str]:
19
- """Get project ID.
20
-
21
- Run
22
- `gcloud config get-value project`
23
- """
24
- try:
25
- project_id = subprocess.run(
26
- ["gcloud", "config", "get-value", "project"],
27
- stdout=subprocess.PIPE,
28
- stderr=subprocess.PIPE,
29
- )
30
- if project_id.stderr.decode("utf-8").strip():
31
- return None
32
- return project_id.stdout.decode("utf-8").strip()
33
- except Exception:
34
- return None
35
-
36
-
37
- class GoogleClient(Client):
38
- """Google client."""
39
-
40
- # User param -> (client param, default value)
41
- PARAMS = {
42
- "engine": ("model", "text-bison"),
43
- "temperature": ("temperature", 1.0),
44
- "max_tokens": ("maxOutputTokens", 10),
45
- "top_p": ("topP", 1.0),
46
- "top_k": ("topK", 1),
47
- "batch_size": ("batch_size", 20),
48
- }
49
- REQUEST_CLS: Type[Request] = LMRequest
50
- NAME = "google"
51
-
52
- def connect(
53
- self,
54
- connection_str: Optional[str] = None,
55
- client_args: Dict[str, Any] = {},
56
- ) -> None:
57
- """
58
- Connect to the GoogleVertex API.
59
-
60
- connection_str is passed as default GOOGLE_API_KEY if variable not set.
61
-
62
- Args:
63
- connection_str: connection string.
64
- client_args: client arguments.
65
- """
66
- connection_parts = connection_str.split("::")
67
- if len(connection_parts) == 1:
68
- self.api_key = connection_parts[0]
69
- self.project_id = None
70
- elif len(connection_parts) == 2:
71
- self.api_key, self.project_id = connection_parts
72
- else:
73
- raise ValueError(
74
- "Invalid connection string. "
75
- "Must be either API_KEY or API_KEY::PROJECT_ID"
76
- )
77
- self.api_key = self.api_key or os.environ.get("GOOGLE_API_KEY")
78
- if self.api_key is None:
79
- raise ValueError(
80
- "GoogleVertex API key not set. Set GOOGLE_API_KEY environment "
81
- "variable or pass through `client_connection`. This can be "
82
- "found by running `gcloud auth print-access-token`"
83
- )
84
- self.project_id = (
85
- self.project_id or os.environ.get("GOOGLE_PROJECT_ID") or get_project_id()
86
- )
87
- if self.project_id is None:
88
- raise ValueError("GoogleVertex project ID not set. Set GOOGLE_PROJECT_ID")
89
- self.host = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/us-central1/publishers/google/models" # noqa: E501
90
-
91
- for key in self.PARAMS:
92
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
93
- if getattr(self, "engine") not in GOOGLE_ENGINES:
94
- raise ValueError(
95
- f"Invalid engine {getattr(self, 'engine')}. Must be {GOOGLE_ENGINES}."
96
- )
97
-
98
- def close(self) -> None:
99
- """Close the client."""
100
- pass
101
-
102
- def get_generation_url(self) -> str:
103
- """Get generation URL."""
104
- model = getattr(self, "engine")
105
- return self.host + f"/{model}:predict"
106
-
107
- def get_generation_header(self) -> Dict[str, str]:
108
- """
109
- Get generation header.
110
-
111
- Returns:
112
- header.
113
- """
114
- return {"Authorization": f"Bearer {self.api_key}"}
115
-
116
- def supports_batch_inference(self) -> bool:
117
- """Return whether the client supports batch inference."""
118
- return True
119
-
120
- def supports_streaming_inference(self) -> bool:
121
- """Return whether the client supports streaming inference.
122
-
123
- Override in child client class.
124
- """
125
- return False
126
-
127
- def get_model_params(self) -> Dict:
128
- """
129
- Get model params.
130
-
131
- By getting model params from the server, we can add to request
132
- and make sure cache keys are unique to model.
133
-
134
- Returns:
135
- model params.
136
- """
137
- return {"model_name": self.NAME, "engine": getattr(self, "engine")}
138
-
139
- def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
140
- """
141
- Preprocess request params.
142
-
143
- Args:
144
- request: request params.
145
-
146
- Returns:
147
- request params.
148
- """
149
- # Refortmat the request params for google
150
- prompt = request.pop("prompt")
151
- if isinstance(prompt, str):
152
- prompt_list = [prompt]
153
- else:
154
- prompt_list = prompt
155
- google_request = {
156
- "instances": [{"prompt": prompt} for prompt in prompt_list],
157
- "parameters": request,
158
- }
159
- return super().preprocess_request_params(google_request)
160
-
161
- def postprocess_response(
162
- self, response: Dict[str, Any], request: Dict[str, Any]
163
- ) -> Dict[str, Any]:
164
- """
165
- Validate response as dict.
166
-
167
- Assumes response is dict
168
- {
169
- "predictions": [
170
- {
171
- "safetyAttributes": {
172
- "categories": ["Violent", "Sexual"],
173
- "blocked": false,
174
- "scores": [0.1, 0.1]
175
- },
176
- "content": "SELECT * FROM "WWW";"
177
- }
178
- ]
179
- }
180
-
181
- Args:
182
- response: response
183
- request: request
184
-
185
- Return:
186
- response as dict
187
- """
188
- google_predictions = response.pop("predictions")
189
- new_response = {
190
- "choices": [
191
- {
192
- "text": prediction["content"],
193
- }
194
- for prediction in google_predictions
195
- ]
196
- }
197
- return super().postprocess_response(new_response, request)
 
 
 
duckdb-nsql/manifest/manifest/clients/google_chat.py DELETED
@@ -1,155 +0,0 @@
1
- """Google client."""
2
- import copy
3
- import logging
4
- import os
5
- from typing import Any, Dict, Optional, Type
6
-
7
- from manifest.clients.google import GoogleClient, get_project_id
8
- from manifest.request import LMRequest, Request
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
- # https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
13
- GOOGLE_ENGINES = {
14
- "chat-bison",
15
- }
16
-
17
-
18
- class GoogleChatClient(GoogleClient):
19
- """GoogleChat client."""
20
-
21
- # User param -> (client param, default value)
22
- PARAMS = {
23
- "engine": ("model", "chat-bison"),
24
- "temperature": ("temperature", 1.0),
25
- "max_tokens": ("maxOutputTokens", 10),
26
- "top_p": ("topP", 1.0),
27
- "top_k": ("topK", 1),
28
- "batch_size": ("batch_size", 20),
29
- }
30
- REQUEST_CLS: Type[Request] = LMRequest
31
- NAME = "googlechat"
32
- IS_CHAT = True
33
-
34
- def connect(
35
- self,
36
- connection_str: Optional[str] = None,
37
- client_args: Dict[str, Any] = {},
38
- ) -> None:
39
- """
40
- Connect to the GoogleVertex API.
41
-
42
- connection_str is passed as default GOOGLE_API_KEY if variable not set.
43
-
44
- Args:
45
- connection_str: connection string.
46
- client_args: client arguments.
47
- """
48
- connection_parts = connection_str.split("::")
49
- if len(connection_parts) == 1:
50
- self.api_key = connection_parts[0]
51
- elif len(connection_parts) == 2:
52
- self.api_key, self.project_id = connection_parts
53
- else:
54
- raise ValueError(
55
- "Invalid connection string. "
56
- "Must be either API_KEY or API_KEY::PROJECT_ID"
57
- )
58
- self.api_key = self.api_key or os.environ.get("GOOGLE_API_KEY")
59
- if self.api_key is None:
60
- raise ValueError(
61
- "GoogleVertex API key not set. Set GOOGLE_API_KEY environment "
62
- "variable or pass through `client_connection`. This can be "
63
- "found by running `gcloud auth print-access-token`"
64
- )
65
- self.project_id = (
66
- self.project_id or os.environ.get("GOOGLE_PROJECT_ID") or get_project_id()
67
- )
68
- if self.project_id is None:
69
- raise ValueError("GoogleVertex project ID not set. Set GOOGLE_PROJECT_ID")
70
- self.host = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/us-central1/publishers/google/models" # noqa: E501
71
-
72
- for key in self.PARAMS:
73
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
74
- if getattr(self, "engine") not in GOOGLE_ENGINES:
75
- raise ValueError(
76
- f"Invalid engine {getattr(self, 'engine')}. Must be {GOOGLE_ENGINES}."
77
- )
78
-
79
- def supports_batch_inference(self) -> bool:
80
- """Return whether the client supports batch inference."""
81
- return False
82
-
83
- def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
84
- """
85
- Preprocess request params.
86
-
87
- Args:
88
- request: request params.
89
-
90
- Returns:
91
- request params.
92
- """
93
- # Format for chat model
94
- request = copy.deepcopy(request)
95
- prompt = request.pop("prompt")
96
- if isinstance(prompt, str):
97
- messages = [{"author": "user", "content": prompt}]
98
- elif isinstance(prompt, list) and isinstance(prompt[0], str):
99
- prompt_list = prompt
100
- messages = [{"author": "user", "content": prompt} for prompt in prompt_list]
101
- elif isinstance(prompt, list) and isinstance(prompt[0], dict):
102
- for pmt_dict in prompt:
103
- if "author" not in pmt_dict or "content" not in pmt_dict:
104
- raise ValueError(
105
- "Prompt must be list of dicts with 'author' and 'content' "
106
- f"keys. Got {prompt}."
107
- )
108
- messages = prompt
109
- else:
110
- raise ValueError(
111
- "Prompt must be string, list of strings, or list of dicts."
112
- f"Got {prompt}"
113
- )
114
- new_request = {
115
- "instances": [{"messages": messages}],
116
- "parameters": request,
117
- }
118
- return super(GoogleClient, self).preprocess_request_params(new_request)
119
-
120
- def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
121
- """
122
- Validate response as dict.
123
-
124
- Assumes response is dict
125
- {
126
- "candidates": [
127
- {
128
- "safetyAttributes": {
129
- "categories": ["Violent", "Sexual"],
130
- "blocked": false,
131
- "scores": [0.1, 0.1]
132
- },
133
- "author": "1",
134
- "content": "SELECT * FROM "WWW";"
135
- }
136
- ]
137
- }
138
-
139
- Args:
140
- response: response
141
- request: request
142
-
143
- Return:
144
- response as dict
145
- """
146
- google_predictions = response.pop("predictions")
147
- new_response = {
148
- "choices": [
149
- {
150
- "text": prediction["candidates"][0]["content"],
151
- }
152
- for prediction in google_predictions
153
- ]
154
- }
155
- return super(GoogleClient, self).postprocess_response(new_response, request)
 
 
 
duckdb-nsql/manifest/manifest/clients/huggingface.py DELETED
@@ -1,137 +0,0 @@
1
- """Hugging Face client."""
2
- import logging
3
- from functools import lru_cache
4
- from typing import Any, Dict, Optional
5
-
6
- import requests
7
-
8
- from manifest.clients.client import Client
9
- from manifest.request import DEFAULT_REQUEST_KEYS, LMRequest, LMScoreRequest
10
- from manifest.response import LMModelChoice, ModelChoices, Response
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
-
15
- class HuggingFaceClient(Client):
16
- """HuggingFace client."""
17
-
18
- # User param -> (client param, default value)
19
- PARAMS = {
20
- "temperature": ("temperature", 0.1),
21
- "max_tokens": ("max_tokens", 10),
22
- "n": ("n", 1),
23
- "top_p": ("top_p", 1.0),
24
- "top_k": ("top_k", 50),
25
- "repetition_penalty": ("repetition_penalty", 1.0),
26
- "do_sample": ("do_sample", True),
27
- }
28
- REQUEST_CLS = LMRequest
29
- NAME = "huggingface"
30
-
31
- def connect(
32
- self,
33
- connection_str: Optional[str] = None,
34
- client_args: Dict[str, Any] = {},
35
- ) -> None:
36
- """
37
- Connect to the HuggingFace url.
38
-
39
- Args:
40
- connection_str: connection string.
41
- client_args: client arguments.
42
- """
43
- if not connection_str:
44
- raise ValueError("Must provide connection string")
45
- self.host = connection_str.rstrip("/")
46
- for key in self.PARAMS:
47
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
48
-
49
- def close(self) -> None:
50
- """Close the client."""
51
- pass
52
-
53
- def get_generation_url(self) -> str:
54
- """Get generation URL."""
55
- return self.host + "/completions"
56
-
57
- def get_generation_header(self) -> Dict[str, str]:
58
- """
59
- Get generation header.
60
-
61
- Returns:
62
- header.
63
- """
64
- return {}
65
-
66
- def supports_batch_inference(self) -> bool:
67
- """Return whether the client supports batch inference."""
68
- return True
69
-
70
- def supports_streaming_inference(self) -> bool:
71
- """Return whether the client supports streaming inference.
72
-
73
- Override in child client class.
74
- """
75
- return False
76
-
77
- @lru_cache(maxsize=1)
78
- def get_model_params(self) -> Dict:
79
- """
80
- Get model params.
81
-
82
- By getting model params from the server, we can add to request
83
- and make sure cache keys are unique to model.
84
-
85
- Returns:
86
- model params.
87
- """
88
- res = requests.post(self.host + "/params").json()
89
- res["client_name"] = self.NAME
90
- return res
91
-
92
- def run_score_prompt_request(
93
- self,
94
- request: LMScoreRequest,
95
- ) -> Response:
96
- """
97
- Get the logit score of the prompt via a forward pass of the model.
98
-
99
- Args:
100
- request: request.
101
-
102
- Returns:
103
- request function that takes no input.
104
- request parameters as dict.
105
- """
106
- request_params = self._get_request_params(request)
107
- retry_timeout = request_params.pop("client_timeout")
108
- for key in DEFAULT_REQUEST_KEYS:
109
- request_params.pop(key, None)
110
- # Do not add params like we do with request as the model isn't sampling
111
- request_params = {"prompt": request.prompt}
112
-
113
- post_str = self.host + "/score_sequence"
114
- try:
115
- res = requests.post(
116
- post_str,
117
- json=request_params,
118
- timeout=retry_timeout,
119
- )
120
- res.raise_for_status()
121
- except requests.Timeout as e:
122
- logger.error("HF request timed out. Increase client_timeout.")
123
- raise e
124
- except requests.exceptions.HTTPError as e:
125
- logger.error(res.text)
126
- raise e
127
- response_dict = res.json()
128
- return Response(
129
- response=ModelChoices(
130
- choices=[LMModelChoice(**choice) for choice in response_dict["choices"]]
131
- ),
132
- cached=False,
133
- request=request,
134
- usages=None,
135
- response_type="text",
136
- request_type=LMScoreRequest,
137
- )
 
 
 
duckdb-nsql/manifest/manifest/clients/huggingface_embedding.py DELETED
@@ -1,98 +0,0 @@
1
- """Hugging Face client."""
2
- import logging
3
- from functools import lru_cache
4
- from typing import Any, Dict, Optional, Tuple
5
-
6
- import numpy as np
7
- import requests
8
-
9
- from manifest.clients.client import Client
10
- from manifest.request import EmbeddingRequest
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
-
15
- class HuggingFaceEmbeddingClient(Client):
16
- """HuggingFaceEmbedding client."""
17
-
18
- # User param -> (client param, default value)
19
- PARAMS: Dict[str, Tuple[str, Any]] = {}
20
- REQUEST_CLS = EmbeddingRequest
21
- NAME = "huggingfaceembedding"
22
-
23
- def connect(
24
- self,
25
- connection_str: Optional[str] = None,
26
- client_args: Dict[str, Any] = {},
27
- ) -> None:
28
- """
29
- Connect to the HuggingFace url.
30
-
31
- Args:
32
- connection_str: connection string.
33
- client_args: client arguments.
34
- """
35
- if not connection_str:
36
- raise ValueError("Must provide connection string")
37
- self.host = connection_str.rstrip("/")
38
- for key in self.PARAMS:
39
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
40
-
41
- def close(self) -> None:
42
- """Close the client."""
43
- pass
44
-
45
- def get_generation_url(self) -> str:
46
- """Get generation URL."""
47
- return self.host + "/embed"
48
-
49
- def get_generation_header(self) -> Dict[str, str]:
50
- """
51
- Get generation header.
52
-
53
- Returns:
54
- header.
55
- """
56
- return {}
57
-
58
- def supports_batch_inference(self) -> bool:
59
- """Return whether the client supports batch inference."""
60
- return True
61
-
62
- def supports_streaming_inference(self) -> bool:
63
- """Return whether the client supports streaming inference.
64
-
65
- Override in child client class.
66
- """
67
- return False
68
-
69
- @lru_cache(maxsize=1)
70
- def get_model_params(self) -> Dict:
71
- """
72
- Get model params.
73
-
74
- By getting model params from the server, we can add to request
75
- and make sure cache keys are unique to model.
76
-
77
- Returns:
78
- model params.
79
- """
80
- res = requests.post(self.host + "/params").json()
81
- res["client_name"] = self.NAME
82
- return res
83
-
84
- def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
85
- """
86
- Format response to dict.
87
-
88
- Args:
89
- response: response
90
- request: request
91
-
92
- Return:
93
- response as dict
94
- """
95
- # Convert array to np.array
96
- for choice in response["choices"]:
97
- choice["array"] = np.array(choice["array"])
98
- return response
 
 
 
duckdb-nsql/manifest/manifest/clients/openai.py DELETED
@@ -1,162 +0,0 @@
1
- """OpenAI client."""
2
- import logging
3
- import os
4
- from typing import Any, Dict, List, Optional, Type
5
-
6
- import tiktoken
7
-
8
- from manifest.clients.client import Client
9
- from manifest.request import LMRequest, Request
10
-
11
- logger = logging.getLogger(__name__)
12
-
13
- OPENAI_ENGINES = {
14
- "gpt-3.5-turbo-instruct",
15
- "text-davinci-003",
16
- "text-davinci-002",
17
- "text-davinci-001",
18
- "davinci",
19
- "curie",
20
- "ada",
21
- "babbage",
22
- "text-curie-001",
23
- "text-babbage-001",
24
- "text-ada-001",
25
- "code-davinci-002",
26
- "code-cushman-001",
27
- }
28
-
29
-
30
- class OpenAIClient(Client):
31
- """OpenAI client."""
32
-
33
- # User param -> (client param, default value)
34
- PARAMS = {
35
- "engine": ("model", "text-davinci-003"),
36
- "temperature": ("temperature", 1.0),
37
- "max_tokens": ("max_tokens", 10),
38
- "n": ("n", 1),
39
- "top_p": ("top_p", 1.0),
40
- "top_k": ("best_of", 1),
41
- "logprobs": ("logprobs", None),
42
- "stop_sequences": ("stop", None), # OpenAI doesn't like empty lists
43
- "presence_penalty": ("presence_penalty", 0.0),
44
- "frequency_penalty": ("frequency_penalty", 0.0),
45
- "batch_size": ("batch_size", 20),
46
- }
47
- REQUEST_CLS: Type[Request] = LMRequest
48
- NAME = "openai"
49
-
50
- def connect(
51
- self,
52
- connection_str: Optional[str] = None,
53
- client_args: Dict[str, Any] = {},
54
- ) -> None:
55
- """
56
- Connect to the OpenAI server.
57
-
58
- connection_str is passed as default OPENAI_API_KEY if variable not set.
59
-
60
- Args:
61
- connection_str: connection string.
62
- client_args: client arguments.
63
- """
64
- self.api_key = connection_str or os.environ.get("OPENAI_API_KEY")
65
- if self.api_key is None:
66
- raise ValueError(
67
- "OpenAI API key not set. Set OPENAI_API_KEY environment "
68
- "variable or pass through `client_connection`."
69
- )
70
- self.host = "https://api.openai.com/v1"
71
- for key in self.PARAMS:
72
- setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
73
- if getattr(self, "engine") not in OPENAI_ENGINES:
74
- raise ValueError(
75
- f"Invalid engine {getattr(self, 'engine')}. Must be {OPENAI_ENGINES}."
76
- )
77
-
78
- def close(self) -> None:
79
- """Close the client."""
80
- pass
81
-
82
- def get_generation_url(self) -> str:
83
- """Get generation URL."""
84
- return self.host + "/completions"
85
-
86
- def get_generation_header(self) -> Dict[str, str]:
87
- """
88
- Get generation header.
89
-
90
- Returns:
91
- header.
92
- """
93
- return {"Authorization": f"Bearer {self.api_key}"}
94
-
95
- def supports_batch_inference(self) -> bool:
96
- """Return whether the client supports batch inference."""
97
- return True
98
-
99
- def supports_streaming_inference(self) -> bool:
100
- """Return whether the client supports streaming inference.
101
-
102
- Override in child client class.
103
- """
104
- return True
105
-
106
- def get_model_params(self) -> Dict:
107
- """
108
- Get model params.
109
-
110
- By getting model params from the server, we can add to request
111
- and make sure cache keys are unique to model.
112
-
113
- Returns:
114
- model params.
115
- """
116
- return {"model_name": self.NAME, "engine": getattr(self, "engine")}
117
-
118
- def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
119
- """
120
- Validate response as dict.
121
-
122
- Args:
123
- response: response
124
- request: request
125
-
126
- Return:
127
- response as dict
128
- """
129
- validated_response = super().postprocess_response(response, request)
130
- # Handle logprobs
131
- for choice in validated_response["choices"]:
132
- if "logprobs" in choice:
133
- logprobs = choice.pop("logprobs")
134
- if logprobs and "token_logprobs" in logprobs:
135
- choice["token_logprobs"] = logprobs["token_logprobs"]
136
- choice["tokens"] = logprobs["tokens"]
137
- return validated_response
138
-
139
- def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
140
- """Split usage into list of usages for each prompt."""
141
- try:
142
- encoding = tiktoken.encoding_for_model(getattr(self, "engine"))
143
- except Exception:
144
- return []
145
- prompt = request["prompt"]
146
- # If n > 1 and prompt is a string, we need to split it into a list
147
- if isinstance(prompt, str):
148
- prompts = [prompt] * len(choices)
149
- else:
150
- prompts = prompt
151
- assert len(prompts) == len(choices)
152
- usages = []
153
- for pmt, chc in zip(prompts, choices):
154
- pmt_tokens = len(encoding.encode(pmt))
155
- chc_tokens = len(encoding.encode(chc["text"])) # type: ignore
156
- usage = {
157
- "prompt_tokens": pmt_tokens,
158
- "completion_tokens": chc_tokens,
159
- "total_tokens": pmt_tokens + chc_tokens,
160
- }
161
- usages.append(usage)
162
- return usages
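All of the clients deleted above were reached through the same front door: Manifest resolves a client by its NAME and forwards default generation parameters to it. Roughly, usage looked like the following sketch (the engine choice and environment-variable handling are illustrative; only the client name comes from the file above):

from manifest import Manifest

# Assumes OPENAI_API_KEY is set in the environment.
manifest = Manifest(client_name="openai", engine="text-davinci-003")
completion = manifest.run("Write a DuckDB query that counts rows in t", max_tokens=32)
print(completion)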