nxphi47 commited on
Commit
8889bbb
·
verified ·
1 Parent(s): 465454c

Upload 40 files

Browse files
Files changed (41) hide show
  1. .gitattributes +1 -0
  2. LICENSE +201 -0
  3. app.py +73 -1729
  4. assets/.DS_Store +0 -0
  5. assets/attention_all_you_need.pdf +0 -0
  6. assets/attention_short.pdf +0 -0
  7. assets/dog_monalisa.jpeg +0 -0
  8. assets/upload_chat.json +10 -0
  9. assets/upload_few_shot.json +10 -0
  10. llama_cpp_requirements.txt +1 -0
  11. mlx_requirements.txt +2 -0
  12. multipurpose_chatbot/.DS_Store +0 -0
  13. multipurpose_chatbot/__init__.py +0 -0
  14. multipurpose_chatbot/configs.py +140 -0
  15. multipurpose_chatbot/demos/.DS_Store +0 -0
  16. multipurpose_chatbot/demos/__init__.py +9 -0
  17. multipurpose_chatbot/demos/base_demo.py +105 -0
  18. multipurpose_chatbot/demos/batch_inference.py +0 -0
  19. multipurpose_chatbot/demos/chat_interface.py +692 -0
  20. multipurpose_chatbot/demos/multimodal_chat_interface.py +1295 -0
  21. multipurpose_chatbot/demos/multimodal_preference_interface.py +794 -0
  22. multipurpose_chatbot/demos/rag_chat_interface.py +638 -0
  23. multipurpose_chatbot/demos/text_completion.py +199 -0
  24. multipurpose_chatbot/engines/.DS_Store +0 -0
  25. multipurpose_chatbot/engines/__init__.py +53 -0
  26. multipurpose_chatbot/engines/base_engine.py +42 -0
  27. multipurpose_chatbot/engines/debug_engine.py +49 -0
  28. multipurpose_chatbot/engines/llama_cpp_engine.py +131 -0
  29. multipurpose_chatbot/engines/llava_llama_cpp_engine.py +280 -0
  30. multipurpose_chatbot/engines/mlx_engine.py +202 -0
  31. multipurpose_chatbot/engines/modeling_sealmm.py +1091 -0
  32. multipurpose_chatbot/engines/sealmmm_engine.py +269 -0
  33. multipurpose_chatbot/engines/transformers_engine.py +454 -0
  34. multipurpose_chatbot/engines/vllm_engine.py +233 -0
  35. multipurpose_chatbot/globals.py +33 -0
  36. pyproject.toml +0 -0
  37. requirements.txt +11 -13
  38. seallm_app.py +1787 -0
  39. seammm_2.png +3 -0
  40. transformers_requirements.txt +1 -0
  41. vllm_requirements.txt +2 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ seammm_2.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py CHANGED
@@ -3,14 +3,15 @@
3
 
4
  # Description:
5
  """
6
- VLLM-based demo script to launch Language chat model for Southeast Asian Languages
7
  """
8
 
9
 
10
  import os
 
11
  import numpy as np
12
  import argparse
13
- import torch
14
  import gradio as gr
15
  from typing import Any, Iterator
16
  from typing import Iterator, List, Optional, Tuple
@@ -29,1759 +30,102 @@ from gradio_client.documentation import document, set_documentation_group
29
  from typing import List, Optional, Union, Dict, Tuple
30
  from tqdm.auto import tqdm
31
  from huggingface_hub import snapshot_download
32
-
33
-
34
- # @@ environments ================
35
-
36
- DEBUG = bool(int(os.environ.get("DEBUG", "1")))
37
-
38
- # List of languages to block
39
- BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
40
- BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
41
-
42
- # for lang block, wether to block in history too
43
- LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
44
- TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
45
- DTYPE = os.environ.get("DTYPE", "bfloat16")
46
-
47
- # ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
48
- DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
49
- LOG_RESPONSE = bool(int(os.environ.get("LOG_RESPONSE", "0")))
50
- # ! show model path in the demo page, only for internal
51
- DISPLAY_MODEL_PATH = bool(int(os.environ.get("DISPLAY_MODEL_PATH", "1")))
52
-
53
- # ! uploaded model path, will be downloaded to MODEL_PATH
54
- HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
55
- # ! if model is private, need HF_TOKEN to access the model
56
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
57
- # ! path where the model is downloaded, either on ./ or persistent disc
58
- MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
59
-
60
- # ! log path
61
- LOG_PATH = os.environ.get("LOG_PATH", "").strip()
62
- LOG_FILE = None
63
- SAVE_LOGS = LOG_PATH is not None and LOG_PATH != ''
64
- if SAVE_LOGS:
65
- if os.path.exists(LOG_PATH):
66
- print(f'LOG_PATH exist: {LOG_PATH}')
67
- else:
68
- LOG_DIR = os.path.dirname(LOG_PATH)
69
- os.makedirs(LOG_DIR, exist_ok=True)
70
-
71
- # ! get LOG_PATH as aggregated outputs in log
72
- GET_LOG_CMD = os.environ.get("GET_LOG_CMD", "").strip()
73
-
74
- print(f'SAVE_LOGS: {SAVE_LOGS} | {LOG_PATH}')
75
- # print(f'GET_LOG_CMD: {GET_LOG_CMD}')
76
-
77
- # ! !! Whether to delete the folder, ONLY SET THIS IF YOU WANT TO DELETE SAVED MODEL ON PERSISTENT DISC
78
- DELETE_FOLDER = os.environ.get("DELETE_FOLDER", "")
79
- IS_DELETE_FOLDER = DELETE_FOLDER is not None and os.path.exists(DELETE_FOLDER)
80
- print(f'DELETE_FOLDER: {DELETE_FOLDER} | {DOWNLOAD_SNAPSHOT=}')
81
-
82
- # ! list of keywords to disabled as security measures to comply with local regulation
83
- KEYWORDS = os.environ.get("KEYWORDS", "").strip()
84
- KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
85
- KEYWORDS = [x.lower() for x in KEYWORDS]
86
-
87
- # bypass
88
- BYPASS_USERS = os.environ.get("BYPASS_USERS", "").strip()
89
- BYPASS_USERS = BYPASS_USERS.split(";") if len(BYPASS_USERS) > 0 else []
90
-
91
- # gradio config
92
- PORT = int(os.environ.get("PORT", "7860"))
93
- # how many iterations to yield response
94
- STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
95
- # how many iterations to perform safety check on response
96
- STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
97
-
98
- # whether to enable to popup accept user
99
- ENABLE_AGREE_POPUP = bool(int(os.environ.get("ENABLE_AGREE_POPUP", "0")))
100
-
101
- # self explanatory
102
- MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
103
- TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
104
- FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.1"))
105
- PRESENCE_PENALTY = float(os.environ.get("PRESENCE_PENALTY", "0.0"))
106
- gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9"))
107
-
108
- # whether to enable quantization, currently not in use
109
- QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
110
-
111
-
112
- # Batch inference file upload
113
- ENABLE_BATCH_INFER = bool(int(os.environ.get("ENABLE_BATCH_INFER", "1")))
114
- BATCH_INFER_MAX_ITEMS = int(os.environ.get("BATCH_INFER_MAX_ITEMS", "100"))
115
- BATCH_INFER_MAX_FILE_SIZE = int(os.environ.get("BATCH_INFER_MAX_FILE_SIZE", "500"))
116
- BATCH_INFER_MAX_PROMPT_TOKENS = int(os.environ.get("BATCH_INFER_MAX_PROMPT_TOKENS", "4000"))
117
- BATCH_INFER_SAVE_TMP_FILE = os.environ.get("BATCH_INFER_SAVE_TMP_FILE", "./tmp/pred.json")
118
-
119
- #
120
- DATA_SET_REPO_PATH = str(os.environ.get("DATA_SET_REPO_PATH", ""))
121
- DATA_SET_REPO = None
122
-
123
- """
124
- Internal instructions of how to configure the DEMO
125
-
126
- 1. Upload SFT model as a model to huggingface: hugginface/models/seal_13b_a
127
- 2. If the model weights is private, set HF_TOKEN=<your private hf token> in https://huggingface.co/spaces/????/?????/settings
128
- 3. space config env: `HF_MODEL_NAME=SeaLLMs/seal-13b-chat-a` or the underlining model
129
- 4. If enable persistent storage: set
130
- HF_HOME=/data/.huggingface
131
- MODEL_PATH=/data/.huggingface/seal-13b-chat-a
132
- if not:
133
- MODEL_PATH=./seal-13b-chat-a
134
-
135
-
136
- HF_HOME=/data/.huggingface
137
- MODEL_PATH=/data/ckpt/seal-13b-chat-a
138
- DELETE_FOLDER=/data/
139
-
140
- """
141
-
142
- # ==============================
143
- print(f'DEBUG mode: {DEBUG}')
144
- print(f'Torch version: {torch.__version__}')
145
- try:
146
- print(f'Torch CUDA version: {torch.version.cuda}')
147
- except Exception as e:
148
- print(f'Failed to print cuda version: {e}')
149
-
150
- try:
151
- compute_capability = torch.cuda.get_device_capability()
152
- print(f'Torch CUDA compute_capability: {compute_capability}')
153
- except Exception as e:
154
- print(f'Failed to print compute_capability version: {e}')
155
-
156
-
157
- # @@ constants ================
158
-
159
- DTYPES = {
160
- 'float16': torch.float16,
161
- 'bfloat16': torch.bfloat16
162
- }
163
-
164
- llm = None
165
- demo = None
166
-
167
-
168
- BOS_TOKEN = '<s>'
169
- EOS_TOKEN = '</s>'
170
-
171
-
172
- SYSTEM_PROMPT_1 = """You are a helpful, respectful, honest and safe AI assistant built by Alibaba Group."""
173
-
174
-
175
-
176
- # ######### RAG PREPARE
177
- RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
178
-
179
- # RAG_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
180
- RAG_EMBED_MODEL_NAME = "sentence-transformers/LaBSE"
181
-
182
-
183
- def load_embeddings():
184
- global RAG_EMBED
185
- if RAG_EMBED is None:
186
- from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
187
- print(f'LOading embeddings: {RAG_EMBED_MODEL_NAME}')
188
- RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code':True, "device": "cpu"})
189
- else:
190
- print(f'RAG_EMBED ALREADY EXIST: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
191
- return RAG_EMBED
192
-
193
-
194
- def get_rag_embeddings():
195
- return load_embeddings()
196
-
197
- _ = get_rag_embeddings()
198
-
199
- RAG_CURRENT_VECTORSTORE = None
200
-
201
- def load_document_split_vectorstore(file_path):
202
- global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
203
- from langchain.text_splitter import RecursiveCharacterTextSplitter
204
- from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
205
- from langchain_community.vectorstores import Chroma, FAISS
206
- from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
207
- # assert RAG_EMBED is not None
208
- splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
209
- if file_path.endswith('.pdf'):
210
- loader = PyPDFLoader(file_path)
211
- elif file_path.endswith('.docx'):
212
- loader = Docx2txtLoader(file_path)
213
- elif file_path.endswith('.txt'):
214
- loader = TextLoader(file_path)
215
- splits = loader.load_and_split(splitter)
216
- RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
217
- return RAG_CURRENT_VECTORSTORE
218
-
219
-
220
- def docs_to_rag_context(docs: List[str]):
221
- contexts = "\n".join([d.page_content for d in docs])
222
- context = f"""Answer the following query exclusively based on the information provided in the document above. \
223
- If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
224
- ###
225
- {contexts}
226
- ###
227
-
228
-
229
- """
230
- return context
231
-
232
- def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
233
- global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
234
- doc_context = None
235
- if file_input is not None:
236
- assert os.path.exists(file_input), f"not found: {file_input}"
237
- if file_input == RAG_CURRENT_FILE:
238
- # reuse
239
- vectorstore = RAG_CURRENT_VECTORSTORE
240
- print(f'Reuse vectorstore: {file_input}')
241
- else:
242
- vectorstore = load_document_split_vectorstore(file_input)
243
- print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
244
- RAG_CURRENT_FILE = file_input
245
- docs = vectorstore.similarity_search(message, k=rag_num_docs)
246
- doc_context = docs_to_rag_context(docs)
247
- return doc_context
248
-
249
- # ######### RAG PREPARE
250
-
251
-
252
- # ============ CONSTANT ============
253
- # https://github.com/gradio-app/gradio/issues/884
254
- MODEL_NAME = "SeaLLM-7B"
255
- MODEL_NAME = str(os.environ.get("MODEL_NAME", "SeaLLM-7B"))
256
-
257
- MODEL_TITLE = """
258
- <div class="container" style="
259
- align-items: center;
260
- justify-content: center;
261
- display: flex;
262
- ">
263
- <div class="image" >
264
- <img src="file/seal_logo.png" style="
265
- max-width: 10em;
266
- max-height: 5%;
267
- height: 3em;
268
- width: 3em;
269
- float: left;
270
- margin-left: auto;
271
- ">
272
- </div>
273
- <div class="text" style="
274
- padding-left: 20px;
275
- padding-top: 1%;
276
- float: left;
277
- ">
278
- <h1 style="font-size: xx-large">SeaLLMs - Large Language Models for Southeast Asia</h1>
279
- </div>
280
- </div>
281
- """
282
-
283
- MODEL_TITLE = """
284
- <img src="file/seal_logo.png" style="
285
- max-width: 10em;
286
- max-height: 5%;
287
- height: 3em;
288
- width: 3em;
289
- ">
290
- <div class="text" style="
291
- loat: left;
292
- padding-bottom: 2%;
293
- ">
294
- SeaLLMs - Large Language Models for Southeast Asia
295
- </div>
296
- """
297
-
298
- """
299
- Somehow cannot add image here
300
- <div class="image" >
301
- <img src="file/seal_logo.png" style="
302
- max-width: 10em;
303
- max-height: 5%;
304
- height: 3em;
305
- width: 3em;
306
- float: left;
307
- margin-left: auto;
308
- ">
309
- </div>
310
- """
311
-
312
- MODEL_DESC = f"""
313
- <div style='display:flex; gap: 0.25rem; '>
314
- <a href='https://github.com/damo-nlp-sg/seallms'><img src='https://img.shields.io/badge/Github-Code-success'></a>
315
- <a href='https://huggingface.co/spaces/SeaLLMs/SeaLLM-7B'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
316
- <a href='https://huggingface.co/SeaLLMs/SeaLLM-7B-v2'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
317
- <a href='https://arxiv.org/pdf/2312.00738.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
318
- </div>
319
- <span style="font-size: larger">
320
- <a href="https://huggingface.co/SeaLLMs/SeaLLM-7B-v2" target="_blank">{MODEL_NAME}-v2</a> - a helpful assistant for Southeast Asian Languages 🇬🇧 🇻🇳 🇮🇩 🇹🇭 🇲🇾 🇰🇭 🇱🇦 🇵🇭 🇲🇲.
321
- Explore <a href="https://huggingface.co/SeaLLMs/SeaLLM-7B-v2" target="_blank">our article</a> for more.
322
- </span>
323
- <br>
324
- <span>
325
- <span style="color: red">NOTE: The chatbot may produce false and harmful content and does not have up-to-date knowledge.</span>
326
- By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">Terms Of Use</a>, which includes
327
- not to use our service to generate any harmful, inappropriate or illegal content.
328
- The service collects user dialogue data for testing and improvement under
329
- <a href="https://creativecommons.org/licenses/by/4.0/">(CC-BY)</a> or similar license. So do not enter any personal information!
330
- </span>
331
- """.strip()
332
-
333
-
334
- cite_markdown = """
335
- ## Citation
336
- If you find our project useful, hope you can star our repo and cite our paper as follows:
337
- ```
338
- @article{damonlpsg2023seallm,
339
- author = {Xuan-Phi Nguyen*, Wenxuan Zhang*, Xin Li*, Mahani Aljunied*, Zhiqiang Hu, Chenhui Shen^, Yew Ken Chia^, Xingxuan Li, Jianyu Wang, Qingyu Tan, Liying Cheng, Guanzheng Chen, Yue Deng, Sen Yang, Chaoqun Liu, Hang Zhang, Lidong Bing},
340
- title = {SeaLLMs - Large Language Models for Southeast Asia},
341
- year = 2023,
342
- }
343
- ```
344
- """
345
-
346
- path_markdown = """
347
- #### Model path:
348
- {model_path}
349
- """
350
-
351
-
352
-
353
- # ! ==================================================================
354
-
355
- set_documentation_group("component")
356
-
357
-
358
- RES_PRINTED = False
359
-
360
-
361
- @document()
362
- class ChatBot(gr.Chatbot):
363
- def _postprocess_chat_messages(
364
- self, chat_message
365
- ):
366
- x = super()._postprocess_chat_messages(chat_message)
367
- # if isinstance(x, str):
368
- # x = x.strip().replace("\n", "<br>")
369
- return x
370
-
371
-
372
- from gradio.components import Button
373
  from gradio.events import Dependency, EventListenerMethod
374
 
375
- # replace events so that submit button is disabled during generation, if stop_btn not found
376
- # this prevent weird behavior
377
- def _setup_stop_events(
378
- self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
379
- ) -> None:
380
- from gradio.components import State
381
- event_triggers = event_triggers if isinstance(event_triggers, (list, tuple)) else [event_triggers]
382
- if self.stop_btn and self.is_generator:
383
- if self.submit_btn:
384
- for event_trigger in event_triggers:
385
- event_trigger(
386
- lambda: (
387
- Button(visible=False),
388
- Button(visible=True),
389
- ),
390
- None,
391
- [self.submit_btn, self.stop_btn],
392
- api_name=False,
393
- queue=False,
394
- )
395
- event_to_cancel.then(
396
- lambda: (Button(visible=True), Button(visible=False)),
397
- None,
398
- [self.submit_btn, self.stop_btn],
399
- api_name=False,
400
- queue=False,
401
- )
402
- else:
403
- for event_trigger in event_triggers:
404
- event_trigger(
405
- lambda: Button(visible=True),
406
- None,
407
- [self.stop_btn],
408
- api_name=False,
409
- queue=False,
410
- )
411
- event_to_cancel.then(
412
- lambda: Button(visible=False),
413
- None,
414
- [self.stop_btn],
415
- api_name=False,
416
- queue=False,
417
- )
418
- self.stop_btn.click(
419
- None,
420
- None,
421
- None,
422
- cancels=event_to_cancel,
423
- api_name=False,
424
- )
425
- else:
426
- if self.submit_btn:
427
- for event_trigger in event_triggers:
428
- event_trigger(
429
- lambda: Button(interactive=False),
430
- None,
431
- [self.submit_btn],
432
- api_name=False,
433
- queue=False,
434
- )
435
- event_to_cancel.then(
436
- lambda: Button(interactive=True),
437
- None,
438
- [self.submit_btn],
439
- api_name=False,
440
- queue=False,
441
- )
442
- # upon clear, cancel the submit event as well
443
- if self.clear_btn:
444
- self.clear_btn.click(
445
- lambda: ([], [], None, Button(interactive=True)),
446
- None,
447
- [self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
448
- queue=False,
449
- api_name=False,
450
- cancels=event_to_cancel,
451
- )
452
-
453
- # TODO: reconfigure clear button as stop and clear button
454
- def _setup_events(self) -> None:
455
- from gradio.components import State
456
- has_on = False
457
- try:
458
- from gradio.events import Dependency, EventListenerMethod, on
459
- has_on = True
460
- except ImportError as ie:
461
- has_on = False
462
- submit_fn = self._stream_fn if self.is_generator else self._submit_fn
463
-
464
- def update_time(c_time, chatbot_state):
465
- # if chatbot_state is empty, register a new conversaion with the current timestamp
466
- # assert len(chatbot_state) > 0, f'empty chatbot state'
467
- if len(chatbot_state) <= 1:
468
- return gr.Number(value=time.time(), label='current_time', visible=False), chatbot_state
469
- # elif len(chatbot_state) == 1:
470
- # # assert chatbot_state[-1][-1] is None, f'invalid [[message, None]] , got {chatbot_state}'
471
- # return gr.Number(value=time.time(), label='current_time', visible=False), chatbot_state
472
- else:
473
- return c_time, chatbot_state
474
-
475
- if has_on:
476
- # new version
477
- submit_triggers = (
478
- [self.textbox.submit, self.submit_btn.click]
479
- if self.submit_btn
480
- else [self.textbox.submit]
481
- )
482
- submit_event = (
483
- on(
484
- submit_triggers,
485
- self._clear_and_save_textbox,
486
- [self.textbox],
487
- [self.textbox, self.saved_input],
488
- api_name=False,
489
- queue=False,
490
- )
491
- .then(
492
- self._display_input,
493
- [self.saved_input, self.chatbot_state],
494
- [self.chatbot, self.chatbot_state],
495
- api_name=False,
496
- queue=False,
497
- )
498
- .then(
499
- update_time,
500
- [self.additional_inputs[-1], self.chatbot_state],
501
- [self.additional_inputs[-1], self.chatbot_state],
502
- api_name=False,
503
- queue=False,
504
- )
505
- .then(
506
- submit_fn,
507
- [self.saved_input, self.chatbot_state] + self.additional_inputs,
508
- [self.chatbot, self.chatbot_state],
509
- api_name=False,
510
- )
511
- )
512
- self._setup_stop_events(submit_triggers, submit_event)
513
- else:
514
- raise ValueError(f'Better install new gradio version than 3.44.0')
515
-
516
- if self.retry_btn:
517
- retry_event = (
518
- self.retry_btn.click(
519
- self._delete_prev_fn,
520
- [self.chatbot_state],
521
- [self.chatbot, self.saved_input, self.chatbot_state],
522
- api_name=False,
523
- queue=False,
524
- )
525
- .then(
526
- self._display_input,
527
- [self.saved_input, self.chatbot_state],
528
- [self.chatbot, self.chatbot_state],
529
- api_name=False,
530
- queue=False,
531
- )
532
- .then(
533
- submit_fn,
534
- [self.saved_input, self.chatbot_state] + self.additional_inputs,
535
- [self.chatbot, self.chatbot_state],
536
- api_name=False,
537
- )
538
- )
539
- self._setup_stop_events([self.retry_btn.click], retry_event)
540
-
541
- if self.undo_btn:
542
- self.undo_btn.click(
543
- self._delete_prev_fn,
544
- [self.chatbot_state],
545
- [self.chatbot, self.saved_input, self.chatbot_state],
546
- api_name=False,
547
- queue=False,
548
- ).then(
549
- lambda x: x,
550
- [self.saved_input],
551
- [self.textbox],
552
- api_name=False,
553
- queue=False,
554
- )
555
-
556
- # Reconfigure clear_btn to stop and clear text box
557
-
558
-
559
- def _display_input(
560
- self, message: str, history: List[List[Union[str, None]]]
561
- ) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
562
- if message is not None and message.strip() != "":
563
- history.append([message, None])
564
- return history, history
565
-
566
-
567
- async def _stream_fn(
568
- self,
569
- message: str,
570
- history_with_input,
571
- request: Request,
572
- *args,
573
- ) -> AsyncGenerator:
574
- history = history_with_input[:-1]
575
- inputs, _, _ = special_args(
576
- self.fn, inputs=[message, history, *args], request=request
577
- )
578
-
579
- if self.is_async:
580
- generator = self.fn(*inputs)
581
- else:
582
- generator = await anyio.to_thread.run_sync(
583
- self.fn, *inputs, limiter=self.limiter
584
- )
585
- generator = SyncToAsyncIterator(generator, self.limiter)
586
- try:
587
- first_response = await async_iteration(generator)
588
- update = history + [[message, first_response]]
589
- yield update, update
590
- except StopIteration:
591
- update = history + [[message, None]]
592
- yield update, update
593
- except Exception as e:
594
- yield history, history
595
- raise e
596
-
597
- try:
598
- async for response in generator:
599
- update = history + [[message, response]]
600
- yield update, update
601
- except Exception as e:
602
- # if "invalid" in str(e):
603
- # yield history, history
604
- # raise e
605
- # else:
606
- # raise e
607
- yield history, history
608
- raise e
609
-
610
-
611
-
612
-
613
- # replace
614
- gr.ChatInterface._setup_stop_events = _setup_stop_events
615
- gr.ChatInterface._setup_events = _setup_events
616
- gr.ChatInterface._display_input = _display_input
617
- gr.ChatInterface._stream_fn = _stream_fn
618
-
619
-
620
- @document()
621
- class CustomTabbedInterface(gr.Blocks):
622
- def __init__(
623
- self,
624
- interface_list: list[gr.Interface],
625
- tab_names: Optional[list[str]] = None,
626
- title: Optional[str] = None,
627
- description: Optional[str] = None,
628
- theme: Optional[gr.Theme] = None,
629
- analytics_enabled: Optional[bool] = None,
630
- css: Optional[str] = None,
631
- ):
632
- """
633
- Parameters:
634
- interface_list: a list of interfaces to be rendered in tabs.
635
- tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
636
- title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
637
- analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
638
- css: custom css or path to custom css file to apply to entire Blocks
639
- Returns:
640
- a Gradio Tabbed Interface for the given interfaces
641
- """
642
- super().__init__(
643
- title=title or "Gradio",
644
- theme=theme,
645
- analytics_enabled=analytics_enabled,
646
- mode="tabbed_interface",
647
- css=css,
648
- )
649
- self.description = description
650
- if tab_names is None:
651
- tab_names = [f"Tab {i}" for i in range(len(interface_list))]
652
- with self:
653
- if title:
654
- gr.Markdown(
655
- f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
656
- )
657
- if description:
658
- gr.Markdown(description)
659
- with gr.Tabs():
660
- for interface, tab_name in zip(interface_list, tab_names):
661
- with gr.Tab(label=tab_name):
662
- interface.render()
663
-
664
-
665
- def vllm_abort(self):
666
- sh = self.llm_engine.scheduler
667
- for g in (sh.waiting + sh.running + sh.swapped):
668
- sh.abort_seq_group(g.request_id)
669
- from vllm.sequence import SequenceStatus
670
- scheduler = self.llm_engine.scheduler
671
- for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
672
- for seq_group in state_queue:
673
- # if seq_group.request_id == request_id:
674
- # Remove the sequence group from the state queue.
675
- state_queue.remove(seq_group)
676
- for seq in seq_group.seqs:
677
- if seq.is_finished():
678
- continue
679
- scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
680
-
681
-
682
- def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
683
- from vllm.outputs import RequestOutput
684
- # Initialize tqdm.
685
- if use_tqdm:
686
- num_requests = self.llm_engine.get_num_unfinished_requests()
687
- pbar = tqdm(total=num_requests, desc="Processed prompts")
688
- # Run the engine.
689
- outputs: Dict[str, RequestOutput] = {}
690
- while self.llm_engine.has_unfinished_requests():
691
- step_outputs = self.llm_engine.step()
692
- for output in step_outputs:
693
- outputs[output.request_id] = output
694
- if len(outputs) > 0:
695
- yield outputs
696
-
697
-
698
-
699
- def vllm_generate_stream(
700
- self: Any,
701
- prompts: Optional[Union[str, List[str]]] = None,
702
- sampling_params: Optional[Any] = None,
703
- prompt_token_ids: Optional[List[List[int]]] = None,
704
- use_tqdm: bool = False,
705
- ) -> Dict[str, Any]:
706
- """Generates the completions for the input prompts.
707
-
708
- NOTE: This class automatically batches the given prompts, considering
709
- the memory constraint. For the best performance, put all of your prompts
710
- into a single list and pass it to this method.
711
-
712
- Args:
713
- prompts: A list of prompts to generate completions for.
714
- sampling_params: The sampling parameters for text generation. If
715
- None, we use the default sampling parameters.
716
- prompt_token_ids: A list of token IDs for the prompts. If None, we
717
- use the tokenizer to convert the prompts to token IDs.
718
- use_tqdm: Whether to use tqdm to display the progress bar.
719
-
720
- Returns:
721
- A list of `RequestOutput` objects containing the generated
722
- completions in the same order as the input prompts.
723
- """
724
- from vllm import LLM, SamplingParams
725
- if prompts is None and prompt_token_ids is None:
726
- raise ValueError("Either prompts or prompt_token_ids must be "
727
- "provided.")
728
- if isinstance(prompts, str):
729
- # Convert a single prompt to a list.
730
- prompts = [prompts]
731
- if prompts is not None and prompt_token_ids is not None:
732
- if len(prompts) != len(prompt_token_ids):
733
- raise ValueError("The lengths of prompts and prompt_token_ids "
734
- "must be the same.")
735
- if sampling_params is None:
736
- # Use default sampling params.
737
- sampling_params = SamplingParams()
738
-
739
- # Add requests to the engine.
740
- if prompts is not None:
741
- num_requests = len(prompts)
742
- else:
743
- num_requests = len(prompt_token_ids)
744
- for i in range(num_requests):
745
- prompt = prompts[i] if prompts is not None else None
746
- if prompt_token_ids is None:
747
- token_ids = None
748
- else:
749
- token_ids = prompt_token_ids[i]
750
- self._add_request(prompt, sampling_params, token_ids)
751
- # return self._run_engine(use_tqdm)
752
- yield from _vllm_run_engine(self, use_tqdm)
753
-
754
-
755
-
756
- # ! avoid saying
757
- # LANG_BLOCK_MESSAGE = """Sorry, the language you have asked is currently not supported. If you have questions in other supported languages, I'll be glad to help. \
758
- # Please also consider clearing the chat box for a better experience."""
759
-
760
- # KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated question, I'll be glad to help."
761
-
762
- LANG_BLOCK_MESSAGE = """Unsupported language."""
763
-
764
- KEYWORD_BLOCK_MESSAGE = "Invalid request."
765
-
766
-
767
- def _detect_lang(text):
768
- # Disable language that may have safety risk
769
- from langdetect import detect as detect_lang
770
- dlang = None
771
- try:
772
- dlang = detect_lang(text)
773
- except Exception as e:
774
- if "No features in text." in str(e):
775
- return "en"
776
- else:
777
- return "zh"
778
- return dlang
779
-
780
-
781
- def block_lang(
782
- message: str,
783
- history: List[Tuple[str, str]] = None,
784
- ) -> str:
785
- # relieve history base block
786
- if len(BLOCK_LANGS) == 0:
787
- return False
788
-
789
- if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
790
- return True
791
- else:
792
- _lang = _detect_lang(message)
793
- if _lang in BLOCK_LANGS:
794
- print(f'Detect blocked {_lang}: {message}')
795
- return True
796
- else:
797
- return False
798
-
799
-
800
- def safety_check(text, history=None, ) -> Optional[str]:
801
- """
802
- Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
803
- This provides an additional security measure to enhance safety and compliance with local regulations.
804
- """
805
- if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
806
- return KEYWORD_BLOCK_MESSAGE
807
-
808
- if len(BLOCK_LANGS) > 0:
809
- if block_lang(text, history):
810
- return LANG_BLOCK_MESSAGE
811
-
812
- return None
813
-
814
-
815
-
816
- TURN_TEMPLATE = "<|im_start|>{role}\n{content}</s>"
817
- TURN_PREFIX = "<|im_start|>{role}\n"
818
-
819
-
820
- def chatml_chat_convo_format(conversations, add_assistant_prefix: bool, default_system=SYSTEM_PROMPT_1):
821
- if conversations[0]['role'] != 'system':
822
- conversations = [{"role": "system", "content": default_system}] + conversations
823
- text = ''
824
- for turn_id, turn in enumerate(conversations):
825
- prompt = TURN_TEMPLATE.format(role=turn['role'], content=turn['content'])
826
- text += prompt
827
- if add_assistant_prefix:
828
- prompt = TURN_PREFIX.format(role='assistant')
829
- text += prompt
830
- return text
831
-
832
-
833
- def chatml_format(message, history=None, system_prompt=None):
834
- conversations = []
835
- system_prompt = system_prompt or "You are a helpful assistant."
836
- if history is not None and len(history) > 0:
837
- for i, (prompt, res) in enumerate(history):
838
- conversations.append({"role": "user", "content": prompt.strip()})
839
- conversations.append({"role": "assistant", "content": res.strip()})
840
- conversations.append({"role": "user", "content": message.strip()})
841
- return chatml_chat_convo_format(conversations, True, default_system=system_prompt)
842
-
843
-
844
- def debug_chat_response_stream_multiturn(message, history):
845
- message_safety = safety_check(message, history=history)
846
- if message_safety is not None:
847
- # yield message_safety
848
- raise gr.Error(message_safety)
849
-
850
- message = "This is a debugging message"
851
- for i in range(len(message)):
852
- time.sleep(0.05)
853
- yield message[:i]
854
-
855
-
856
-
857
- def chat_response_stream_multiturn(
858
- message: str,
859
- history: List[Tuple[str, str]],
860
- temperature: float,
861
- max_tokens: int,
862
- frequency_penalty: float,
863
- presence_penalty: float,
864
- system_prompt: Optional[str] = SYSTEM_PROMPT_1,
865
- current_time: Optional[float] = None,
866
- # profile: Optional[gr.OAuthProfile] = None,
867
- ) -> str:
868
- """
869
- gr.Number(value=temperature, label='Temperature (higher -> more random)'),
870
- gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
871
- gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
872
- gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
873
- gr.Textbox(value=sys_prompt, label='System prompt', lines=8, interactive=False),
874
- gr.Number(value=0, label='current_time', visible=False),
875
- """
876
- global LOG_FILE, LOG_PATH
877
- if DEBUG:
878
- yield from debug_chat_response_stream_multiturn(message, history)
879
- return
880
- from vllm import LLM, SamplingParams
881
- """Build multi turn
882
-
883
- message is incoming prompt
884
- history don't have the current messauge
885
- """
886
- global llm, RES_PRINTED
887
- assert llm is not None
888
- assert system_prompt.strip() != '', f'system prompt is empty'
889
- # is_by_pass = False if profile is None else profile.username in BYPASS_USERS
890
- is_by_pass = False
891
-
892
- tokenizer = llm.get_tokenizer()
893
- # force removing all
894
- vllm_abort(llm)
895
-
896
- temperature = float(temperature)
897
- frequency_penalty = float(frequency_penalty)
898
- max_tokens = int(max_tokens)
899
-
900
- message = message.strip()
901
 
902
- if GET_LOG_CMD != "" and message.strip() == GET_LOG_CMD:
903
- print_log_file()
904
- yield "Finish printed log. Please clear the chatbox now."
905
- return
 
 
 
 
 
 
 
 
 
 
906
 
907
- if len(message) == 0:
908
- raise gr.Error("The message cannot be empty!")
909
 
910
- message_safety = safety_check(message, history=history)
911
- if message_safety is not None and not is_by_pass:
912
- # yield message_safety
913
- raise gr.Error(message_safety)
914
-
915
- # history will be appended with message later on
916
-
917
- full_prompt = chatml_format(message.strip(), history=history, system_prompt=system_prompt)
918
- print(full_prompt)
919
-
920
- if len(tokenizer.encode(full_prompt)) >= 4050:
921
- raise gr.Error(f"Conversation or prompt is too long, please clear the chatbox or try shorter input.")
922
-
923
- sampling_params = SamplingParams(
924
- temperature=temperature,
925
- max_tokens=max_tokens,
926
- frequency_penalty=frequency_penalty,
927
- presence_penalty=presence_penalty,
928
- # stop=['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'],
929
- stop=['<s>', '</s>', '<|im_start|>', '<|im_end|>'],
930
- )
931
- cur_out = None
932
-
933
- for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
934
- if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
935
- # cur_out = cur_out.replace("\\n", "\n")
936
-
937
- # optionally check safety, and respond
938
- if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
939
- message_safety = safety_check(cur_out, history=None)
940
- if message_safety is not None and not is_by_pass:
941
- # yield message_safety
942
- raise gr.Error(message_safety)
943
- # return
944
-
945
- yield cur_out
946
- assert len(gen) == 1, f'{gen}'
947
- item = next(iter(gen.values()))
948
- cur_out = item.outputs[0].text
949
- #cur_out = "Our system is under maintenance, will be back soon!"
950
- if j >= max_tokens - 2:
951
- gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
952
-
953
- # TODO: use current_time to register conversations, accoriding history and cur_out
954
- history_str = format_conversation(history + [[message, cur_out]])
955
- print(f'@@@@@@@@@@\n{history_str}\n##########\n')
956
-
957
- maybe_log_conv_file(current_time, history, message, cur_out, temperature=temperature, frequency_penalty=frequency_penalty)
958
-
959
- if cur_out is not None and "\\n" in cur_out:
960
- print(f'double slash-n in cur_out:\n{cur_out}')
961
- cur_out = cur_out.replace("\\n", "\n")
962
-
963
- if cur_out is not None:
964
- yield cur_out
965
-
966
- message_safety = safety_check(cur_out, history=None)
967
- if message_safety is not None and not is_by_pass:
968
- # yield message_safety
969
- raise gr.Error(message_safety)
970
- # return
971
-
972
-
973
-
974
- def chat_response_stream_rag_multiturn(
975
- message: str,
976
- history: List[Tuple[str, str]],
977
- file_input: str,
978
- temperature: float,
979
- max_tokens: int,
980
- # frequency_penalty: float,
981
- # presence_penalty: float,
982
- system_prompt: Optional[str] = SYSTEM_PROMPT_1,
983
- current_time: Optional[float] = None,
984
- rag_num_docs: Optional[int] = 3,
985
- ):
986
- message = message.strip()
987
- frequency_penalty = FREQUENCE_PENALTY
988
- presence_penalty = PRESENCE_PENALTY
989
- if len(message) == 0:
990
- raise gr.Error("The message cannot be empty!")
991
- doc_context = maybe_get_doc_context(message, file_input, rag_num_docs=rag_num_docs)
992
- if doc_context is not None:
993
- message = f"{doc_context}\n\n{message}"
994
- yield from chat_response_stream_multiturn(
995
- message, history, temperature, max_tokens, frequency_penalty,
996
- presence_penalty, system_prompt, current_time
997
- )
998
-
999
-
1000
- def debug_generate_free_form_stream(message):
1001
- output = " This is a debugging message...."
1002
- for i in range(len(output)):
1003
- time.sleep(0.05)
1004
- yield message + output[:i]
1005
-
1006
-
1007
- def generate_free_form_stream(
1008
- message: str,
1009
- temperature: float,
1010
- max_tokens: int,
1011
- frequency_penalty: float,
1012
- presence_penalty: float,
1013
- stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
1014
- current_time: Optional[float] = None,
1015
- ) -> str:
1016
- global LOG_FILE, LOG_PATH
1017
- if DEBUG:
1018
- yield from debug_generate_free_form_stream(message)
1019
- return
1020
- from vllm import LLM, SamplingParams
1021
- """Build multi turn
1022
- """
1023
- global llm, RES_PRINTED
1024
- assert llm is not None
1025
- tokenizer = llm.get_tokenizer()
1026
- # force removing all
1027
- vllm_abort(llm)
1028
-
1029
- temperature = float(temperature)
1030
- frequency_penalty = float(frequency_penalty)
1031
- max_tokens = int(max_tokens)
1032
-
1033
- stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
1034
- stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>']))
1035
-
1036
- sampling_params = SamplingParams(
1037
- temperature=temperature,
1038
- max_tokens=max_tokens,
1039
- frequency_penalty=frequency_penalty,
1040
- presence_penalty=presence_penalty,
1041
- stop=stop_strings,
1042
- # ignore_eos=True,
1043
- )
1044
-
1045
- # full_prompt = message
1046
- if len(message) == 0:
1047
- raise gr.Error("The message cannot be empty!")
1048
-
1049
- message_safety = safety_check(message)
1050
- if message_safety is not None:
1051
- raise gr.Error(message_safety)
1052
-
1053
- if len(tokenizer.encode(message)) >= 4050:
1054
- raise gr.Error(f"Prompt is too long!")
1055
-
1056
- cur_out = None
1057
- for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
1058
- if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
1059
- # optionally check safety, and respond
1060
- if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
1061
- message_safety = safety_check(cur_out, history=None)
1062
- if message_safety is not None:
1063
- raise gr.Error(message_safety)
1064
- yield message + cur_out
1065
- assert len(gen) == 1, f'{gen}'
1066
- item = next(iter(gen.values()))
1067
- cur_out = item.outputs[0].text
1068
- #cur_out = "Our system is under maintenance, will be back soon!"
1069
- if j >= max_tokens - 2:
1070
- gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
1071
-
1072
- if cur_out is not None:
1073
- yield message + cur_out
1074
-
1075
- message_safety = safety_check(message + cur_out, history=None)
1076
- if message_safety is not None:
1077
- raise gr.Error(message_safety)
1078
-
1079
-
1080
-
1081
-
1082
- def maybe_log_conv_file(current_time, history, message, response, **kwargs):
1083
- global LOG_FILE
1084
- if LOG_FILE is not None:
1085
- my_history = history + [[message, response]]
1086
- obj = {
1087
- 'key': str(current_time),
1088
- 'history': my_history
1089
- }
1090
- for k, v in kwargs.items():
1091
- obj[k] = v
1092
- log_ = json.dumps(obj, ensure_ascii=False)
1093
- LOG_FILE.write(log_ + "\n")
1094
- LOG_FILE.flush()
1095
- print(f'Wrote {obj["key"]} to {LOG_PATH}')
1096
-
1097
-
1098
- def format_conversation(history):
1099
- _str = '\n'.join([
1100
- (
1101
- f'<<<User>>> {h[0]}\n'
1102
- f'<<<Asst>>> {h[1]}'
1103
- )
1104
- for h in history
1105
- ])
1106
- return _str
1107
-
1108
-
1109
- def aggregate_convos():
1110
- from datetime import datetime
1111
- global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1112
- assert os.path.exists(LOG_PATH), f'{LOG_PATH} not found'
1113
- convos = None
1114
- irregular_count = 1
1115
- with open(LOG_PATH, 'r', encoding='utf-8') as f:
1116
- convos = {}
1117
- for i, l in enumerate(f):
1118
- if l:
1119
- item = json.loads(l)
1120
- key = item['key']
1121
- try:
1122
- key = float(key)
1123
- except Exception as e:
1124
- key = -1
1125
- if key > 0.0:
1126
- item_key = datetime.fromtimestamp(key).strftime("%Y-%m-%d %H:%M:%S")
1127
- else:
1128
- key = item_key = f'e{irregular_count}'
1129
- irregular_count += 1
1130
- item['key'] = item_key
1131
- convos[key] = item
1132
- return convos
1133
-
1134
- def maybe_upload_to_dataset():
1135
- from datetime import datetime
1136
- global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1137
- if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH != "":
1138
- convos = aggregate_convos()
1139
- AGG_LOG_PATH = LOG_PATH + ".agg.json"
1140
- with open(AGG_LOG_PATH, 'w', encoding='utf-8') as fo:
1141
- json.dump(convos, fo, indent=4, ensure_ascii=False)
1142
- print(f'Saved aggregated json to {AGG_LOG_PATH}')
1143
- try:
1144
- from huggingface_hub import upload_file
1145
- print(f'upload {AGG_LOG_PATH} to {DATA_SET_REPO_PATH}')
1146
- upload_file(
1147
- path_or_fileobj=AGG_LOG_PATH,
1148
- path_in_repo=os.path.basename(AGG_LOG_PATH),
1149
- repo_id=DATA_SET_REPO_PATH,
1150
- token=HF_TOKEN,
1151
- repo_type="dataset",
1152
- create_pr=True
1153
- )
1154
- except Exception as e:
1155
- print(f'Failed to save to repo: {DATA_SET_REPO_PATH}|{str(e)}')
1156
-
1157
-
1158
- def print_log_file():
1159
- global LOG_FILE, LOG_PATH
1160
- if SAVE_LOGS and os.path.exists(LOG_PATH):
1161
- with open(LOG_PATH, 'r', encoding='utf-8') as f:
1162
- convos = aggregate_convos()
1163
- print(f'Printing log from {LOG_PATH}')
1164
- items = list(convos.items())
1165
- for k, v in items[-10:]:
1166
- history = v.pop('history')
1167
- print(f'######--{v}--#####')
1168
- _str = format_conversation(history)
1169
- print(_str)
1170
- maybe_upload_to_dataset()
1171
-
1172
-
1173
- def debug_chat_response_echo(
1174
- message: str,
1175
- history: List[Tuple[str, str]],
1176
- temperature: float = 0.0,
1177
- max_tokens: int = 4096,
1178
- frequency_penalty: float = 0.4,
1179
- presence_penalty: float = 0.0,
1180
- current_time: Optional[float] = None,
1181
- system_prompt: str = SYSTEM_PROMPT_1,
1182
- ) -> str:
1183
- global LOG_FILE
1184
- import time
1185
- time.sleep(0.5)
1186
-
1187
- if message.strip() == GET_LOG_CMD:
1188
- print_log_file()
1189
- yield "Finish printed log."
1190
- return
1191
-
1192
- for i in range(len(message)):
1193
- yield f"repeat: {current_time} {message[:i + 1]}"
1194
-
1195
- cur_out = f"repeat: {current_time} {message}"
1196
- maybe_log_conv_file(current_time, history, message, cur_out, temperature=temperature, frequency_penalty=frequency_penalty)
1197
-
1198
-
1199
- def check_model_path(model_path) -> str:
1200
- assert os.path.exists(model_path), f'{model_path} not found'
1201
- ckpt_info = "None"
1202
- if os.path.isdir(model_path):
1203
- if os.path.exists(f'{model_path}/info.txt'):
1204
- with open(f'{model_path}/info.txt', 'r') as f:
1205
- ckpt_info = f.read()
1206
- print(f'Checkpoint info:\n{ckpt_info}\n-----')
1207
- else:
1208
- print(f'info.txt not found in {model_path}')
1209
- print(f'model path dir: {list(os.listdir(model_path))}')
1210
-
1211
- return ckpt_info
1212
-
1213
-
1214
- def maybe_delete_folder():
1215
- if IS_DELETE_FOLDER and DOWNLOAD_SNAPSHOT:
1216
- import shutil
1217
- print(f'DELETE ALL FILES IN {DELETE_FOLDER}')
1218
- for filename in os.listdir(DELETE_FOLDER):
1219
- file_path = os.path.join(DELETE_FOLDER, filename)
1220
- try:
1221
- if os.path.isfile(file_path) or os.path.islink(file_path):
1222
- os.unlink(file_path)
1223
- elif os.path.isdir(file_path):
1224
- shutil.rmtree(file_path)
1225
- except Exception as e:
1226
- print('Failed to delete %s. Reason: %s' % (file_path, e))
1227
-
1228
-
1229
- AGREE_POP_SCRIPTS = """
1230
- async () => {
1231
- alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
1232
- }
1233
- """
1234
-
1235
- def debug_file_function(
1236
- files: Union[str, List[str]],
1237
- prompt_mode: str,
1238
- temperature: float,
1239
- max_tokens: int,
1240
- frequency_penalty: float,
1241
- presence_penalty: float,
1242
- stop_strings: str = "[STOP],<s>,</s>",
1243
- current_time: Optional[float] = None,
1244
- ):
1245
- """This is only for debug purpose"""
1246
- files = files if isinstance(files, list) else [files]
1247
- print(files)
1248
- filenames = [f.name for f in files]
1249
- all_items = []
1250
- for fname in filenames:
1251
- print(f'Reading {fname}')
1252
- with open(fname, 'r', encoding='utf-8') as f:
1253
- items = json.load(f)
1254
- assert isinstance(items, list), f'invalid items from {fname} not list'
1255
- all_items.extend(items)
1256
- print(all_items)
1257
- print(f'{prompt_mode} / {temperature} / {max_tokens}, {frequency_penalty}, {presence_penalty}')
1258
- save_path = "./test.json"
1259
- with open(save_path, 'w', encoding='utf-8') as f:
1260
- json.dump(all_items, f, indent=4, ensure_ascii=False)
1261
-
1262
- for x in all_items:
1263
- x['response'] = "Return response"
1264
-
1265
- print_items = all_items[:1]
1266
- # print_json = json.dumps(print_items, indent=4, ensure_ascii=False)
1267
- return save_path, print_items
1268
-
1269
-
1270
- def validate_file_item(filename, index, item: Dict[str, str]):
1271
- """
1272
- check safety for items in files
1273
- """
1274
- message = item['prompt'].strip()
1275
-
1276
- if len(message) == 0:
1277
- raise gr.Error(f'Prompt {index} empty')
1278
-
1279
- message_safety = safety_check(message, history=None)
1280
- if message_safety is not None:
1281
- raise gr.Error(f'Prompt {index} invalid: {message_safety}')
1282
-
1283
- tokenizer = llm.get_tokenizer() if llm is not None else None
1284
- if tokenizer is None or len(tokenizer.encode(message)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
1285
- raise gr.Error(f"Prompt {index} too long, should be less than {BATCH_INFER_MAX_PROMPT_TOKENS} tokens")
1286
-
1287
-
1288
- def read_validate_json_files(files: Union[str, List[str]]):
1289
- files = files if isinstance(files, list) else [files]
1290
- filenames = [f.name for f in files]
1291
- all_items = []
1292
- for fname in filenames:
1293
- # check each files
1294
- print(f'Reading {fname}')
1295
- with open(fname, 'r', encoding='utf-8') as f:
1296
- items = json.load(f)
1297
- assert isinstance(items, list), f'Data {fname} not list'
1298
- assert all(isinstance(x, dict) for x in items), f'item in input file not list'
1299
- assert all("prompt" in x for x in items), f'key prompt should be in dict item of input file'
1300
-
1301
- for i, x in enumerate(items):
1302
- validate_file_item(fname, i, x)
1303
 
1304
- all_items.extend(items)
1305
 
1306
- if len(all_items) > BATCH_INFER_MAX_ITEMS:
1307
- raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
1308
-
1309
- return all_items, filenames
1310
 
1311
-
1312
- def remove_gradio_cache(exclude_names=None):
1313
- """remove gradio cache to avoid flooding"""
1314
  import shutil
1315
- for root, dirs, files in os.walk('/tmp/gradio/'):
1316
- for f in files:
1317
- # if not any(f in ef for ef in except_files):
1318
- if exclude_names is None or not any(ef in f for ef in exclude_names):
1319
- print(f'Remove: {f}')
1320
- os.unlink(os.path.join(root, f))
1321
- # for d in dirs:
1322
- # # if not any(d in ef for ef in except_files):
1323
- # if exclude_names is None or not any(ef in d for ef in exclude_names):
1324
- # print(f'Remove d: {d}')
1325
- # shutil.rmtree(os.path.join(root, d))
1326
-
1327
-
1328
- def maybe_upload_batch_set(pred_json_path):
1329
- global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1330
-
1331
- if SAVE_LOGS and DATA_SET_REPO_PATH != "":
1332
  try:
1333
- from huggingface_hub import upload_file
1334
- path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f'.{time.time()}.json')
1335
- print(f'upload {pred_json_path} to {DATA_SET_REPO_PATH}//{path_in_repo}')
1336
- upload_file(
1337
- path_or_fileobj=pred_json_path,
1338
- path_in_repo=path_in_repo,
1339
- repo_id=DATA_SET_REPO_PATH,
1340
- token=HF_TOKEN,
1341
- repo_type="dataset",
1342
- create_pr=True
1343
- )
1344
  except Exception as e:
1345
- print(f'Failed to save to repo: {DATA_SET_REPO_PATH}|{str(e)}')
1346
-
1347
-
1348
- def free_form_prompt(prompt, history=None, system_prompt=None):
1349
- return prompt
1350
-
1351
- def batch_inference(
1352
- files: Union[str, List[str]],
1353
- prompt_mode: str,
1354
- temperature: float,
1355
- max_tokens: int,
1356
- frequency_penalty: float,
1357
- presence_penalty: float,
1358
- stop_strings: str = "[STOP],<s>,</s>,<|im_start|>",
1359
- current_time: Optional[float] = None,
1360
- system_prompt: Optional[str] = SYSTEM_PROMPT_1
1361
- ):
1362
- """
1363
- Handle file upload batch inference
1364
-
1365
- """
1366
- global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
1367
- if DEBUG:
1368
- return debug_file_function(
1369
- files, prompt_mode, temperature, max_tokens,
1370
- presence_penalty, stop_strings, current_time)
1371
-
1372
- from vllm import LLM, SamplingParams
1373
- assert llm is not None
1374
- # assert system_prompt.strip() != '', f'system prompt is empty'
1375
-
1376
- stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
1377
- tokenizer = llm.get_tokenizer()
1378
- # force removing all
1379
- # NOTE: need to make sure all cached items are removed!!!!!!!!!
1380
- vllm_abort(llm)
1381
-
1382
- temperature = float(temperature)
1383
- frequency_penalty = float(frequency_penalty)
1384
- max_tokens = int(max_tokens)
1385
-
1386
- all_items, filenames = read_validate_json_files(files)
1387
-
1388
- # remove all items in /tmp/gradio/
1389
- remove_gradio_cache(exclude_names=['upload_chat.json', 'upload_few_shot.json'])
1390
-
1391
- if prompt_mode == 'chat':
1392
- prompt_format_fn = chatml_format
1393
- elif prompt_mode == 'few-shot':
1394
- from functools import partial
1395
- # prompt_format_fn = partial(
1396
- # chatml_format, include_end_instruct=False
1397
- # )
1398
- prompt_format_fn = free_form_prompt
1399
- else:
1400
- raise gr.Error(f'Wrong mode {prompt_mode}')
1401
-
1402
- full_prompts = [
1403
- prompt_format_fn(
1404
- x['prompt'], [], sys_prompt=system_prompt
1405
- )
1406
- for i, x in enumerate(all_items)
1407
- ]
1408
- print(f'{full_prompts[0]}\n')
1409
-
1410
- if any(len(tokenizer.encode(x)) >= 4090 for x in full_prompts):
1411
- raise gr.Error(f"Some prompt is too long!")
1412
-
1413
- stop_seq = list(set(['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'] + stop_strings))
1414
- sampling_params = SamplingParams(
1415
- temperature=temperature,
1416
- max_tokens=max_tokens,
1417
- frequency_penalty=frequency_penalty,
1418
- presence_penalty=presence_penalty,
1419
- stop=stop_seq
1420
- )
1421
-
1422
- generated = llm.generate(full_prompts, sampling_params, use_tqdm=False)
1423
- responses = [g.outputs[0].text for g in generated]
1424
- #responses = ["Our system is under maintenance, will be back soon!" for g in generated]
1425
- if len(responses) != len(all_items):
1426
- raise gr.Error(f'inconsistent lengths {len(responses)} != {len(all_items)}')
1427
-
1428
- for res, item in zip(responses, all_items):
1429
- item['response'] = res
1430
-
1431
- save_path = BATCH_INFER_SAVE_TMP_FILE
1432
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
1433
- with open(save_path, 'w', encoding='utf-8') as f:
1434
- json.dump(all_items, f, indent=4, ensure_ascii=False)
1435
-
1436
- # You need to upload save_path as a new timestamp file.
1437
- maybe_upload_batch_set(save_path)
1438
-
1439
- print_items = all_items[:2]
1440
- # print_json = json.dumps(print_items, indent=4, ensure_ascii=False)
1441
- return save_path, print_items
1442
-
1443
-
1444
- # BATCH_INFER_MAX_ITEMS
1445
- FILE_UPLOAD_DESCRIPTION = f"""Upload JSON file as list of dict with < {BATCH_INFER_MAX_ITEMS} items, \
1446
- each item has `prompt` key. We put guardrails to enhance safety, so do not input any harmful content or personal information! Re-upload the file after every submit. See the examples below.
1447
- ```
1448
- [ {{"id": 0, "prompt": "Hello world"}} , {{"id": 1, "prompt": "Hi there?"}}]
1449
- ```
1450
- """
1451
-
1452
- CHAT_EXAMPLES = [
1453
- ["Hãy giải thích thuyết tương đối rộng."],
1454
- ["Tolong bantu saya menulis email ke lembaga pemerintah untuk mencari dukungan finansial untuk penelitian AI."],
1455
- ["แนะนำ 10 จุดหมายปลายทางในกรุงเทพฯ"],
1456
- ]
1457
-
1458
-
1459
- # performance items
1460
-
1461
- def create_free_form_generation_demo():
1462
- global short_model_path
1463
- max_tokens = MAX_TOKENS
1464
- temperature = TEMPERATURE
1465
- frequence_penalty = FREQUENCE_PENALTY
1466
- presence_penalty = PRESENCE_PENALTY
1467
-
1468
- introduction = """
1469
- ### Free-form | Put any context string (like few-shot prompts)
1470
- """
1471
-
1472
- with gr.Blocks() as demo_free_form:
1473
- gr.Markdown(introduction)
1474
-
1475
- with gr.Row():
1476
- txt = gr.Textbox(
1477
- scale=4,
1478
- lines=16,
1479
- show_label=False,
1480
- placeholder="Enter any free form text and submit",
1481
- container=False,
1482
- )
1483
- with gr.Row():
1484
- free_submit_button = gr.Button('Submit')
1485
- with gr.Row():
1486
- temp = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
1487
- length = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
1488
- freq_pen = gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens')
1489
- pres_pen = gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens')
1490
- stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
1491
-
1492
- free_submit_button.click(
1493
- generate_free_form_stream,
1494
- [txt, temp, length, freq_pen, pres_pen, stop_strings],
1495
- txt
1496
- )
1497
- return demo_free_form
1498
-
1499
-
1500
-
1501
- def create_file_upload_demo():
1502
- temperature = TEMPERATURE
1503
- frequence_penalty = FREQUENCE_PENALTY
1504
- presence_penalty = PRESENCE_PENALTY
1505
- max_tokens = MAX_TOKENS
1506
- demo_file_upload = gr.Interface(
1507
- batch_inference,
1508
- inputs=[
1509
- gr.File(file_count='single', file_types=['json']),
1510
- gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
1511
- gr.Number(value=temperature, label='Temperature', info="Higher -> more random"),
1512
- gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
1513
- gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
1514
- gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
1515
- gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
1516
- gr.Number(value=0, label='current_time', visible=False),
1517
- ],
1518
- outputs=[
1519
- # "file",
1520
- gr.File(label="Generated file"),
1521
- # "json"
1522
- gr.JSON(label='Example outputs (display 2 samples)')
1523
- ],
1524
- description=FILE_UPLOAD_DESCRIPTION,
1525
- allow_flagging=False,
1526
- examples=[
1527
- ["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "<s>,</s>,<|im_start|>"],
1528
- ["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "<s>,</s>,<|im_start|>,\\n"]
1529
- ],
1530
- cache_examples=False,
1531
- )
1532
- return demo_file_upload
1533
-
1534
-
1535
- def create_chat_demo(title=None, description=None):
1536
- sys_prompt = SYSTEM_PROMPT_1
1537
- max_tokens = MAX_TOKENS
1538
- temperature = TEMPERATURE
1539
- frequence_penalty = FREQUENCE_PENALTY
1540
- presence_penalty = PRESENCE_PENALTY
1541
-
1542
- demo_chat = gr.ChatInterface(
1543
- chat_response_stream_multiturn,
1544
- chatbot=ChatBot(
1545
- label=MODEL_NAME,
1546
- bubble_full_width=False,
1547
- latex_delimiters=[
1548
- { "left": "$", "right": "$", "display": False},
1549
- { "left": "$$", "right": "$$", "display": True},
1550
- ],
1551
- show_copy_button=True,
1552
- ),
1553
- textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
1554
- submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1555
- # ! consider preventing the stop button
1556
- # stop_btn=None,
1557
- title=title,
1558
- description=description,
1559
- additional_inputs=[
1560
- gr.Number(value=temperature, label='Temperature (higher -> more random)'),
1561
- gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
1562
- gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
1563
- gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
1564
- gr.Textbox(value=sys_prompt, label='System prompt', lines=4, interactive=False),
1565
- gr.Number(value=0, label='current_time', visible=False),
1566
- # ! Remove the system prompt textbox to avoid jailbreaking
1567
- ],
1568
- examples=CHAT_EXAMPLES,
1569
- cache_examples=False
1570
- )
1571
- return demo_chat
1572
-
1573
-
1574
- def upload_file(file):
1575
- # file_paths = [file.name for file in files]
1576
- # return file_paths
1577
- return file.name
1578
-
1579
-
1580
- RAG_DESCRIPTION = """
1581
- * Upload a doc below to answer question about it (RAG).
1582
- * Every question must be explicit and self-contained! Because each prompt will invoke a new RAG retrieval without considering previous conversations.
1583
- (E.g: Dont prompt "Answer my previous question in details.")
1584
- """
1585
-
1586
- def create_chat_demo_rag(title=None, description=None):
1587
- sys_prompt = SYSTEM_PROMPT_1
1588
- max_tokens = MAX_TOKENS
1589
- temperature = TEMPERATURE
1590
- frequence_penalty = FREQUENCE_PENALTY
1591
- presence_penalty = PRESENCE_PENALTY
1592
- description = description or RAG_DESCRIPTION
1593
-
1594
- # with gr.Blocks(title="RAG") as rag_demo:
1595
- additional_inputs = [
1596
- gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt', 'json']),
1597
- # gr.Textbox(value=None, label='Document path', lines=1, interactive=False),
1598
- gr.Number(value=temperature, label='Temperature (higher -> more random)'),
1599
- gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
1600
- # gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
1601
- # gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
1602
- gr.Textbox(value=sys_prompt, label='System prompt', lines=1, interactive=False),
1603
- gr.Number(value=0, label='current_time', visible=False),
1604
- ]
1605
-
1606
- demo_rag_chat = gr.ChatInterface(
1607
- chat_response_stream_rag_multiturn,
1608
- chatbot=gr.Chatbot(
1609
- label=MODEL_NAME + "-RAG",
1610
- bubble_full_width=False,
1611
- latex_delimiters=[
1612
- { "left": "$", "right": "$", "display": False},
1613
- { "left": "$$", "right": "$$", "display": True},
1614
- ],
1615
- show_copy_button=True,
1616
- ),
1617
- textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
1618
- submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1619
- # ! consider preventing the stop button
1620
- # stop_btn=None,
1621
- title=title,
1622
- description=description,
1623
- additional_inputs=additional_inputs,
1624
- additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
1625
- # examples=CHAT_EXAMPLES,
1626
- cache_examples=False
1627
- )
1628
- # with demo_rag_chat:
1629
- # upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt', 'json'], file_count="single")
1630
- # upload_button.upload(upload_file, upload_button, additional_inputs[0])
1631
-
1632
- # return demo_chat
1633
- return demo_rag_chat
1634
-
1635
 
1636
 
1637
  def launch_demo():
1638
- global demo, llm, DEBUG, LOG_FILE
1639
  model_desc = MODEL_DESC
1640
  model_path = MODEL_PATH
1641
- model_title = MODEL_TITLE
1642
- hf_model_name = HF_MODEL_NAME
1643
- tensor_parallel = TENSOR_PARALLEL
1644
- assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
1645
- dtype = DTYPE
1646
- sys_prompt = SYSTEM_PROMPT_1
1647
- max_tokens = MAX_TOKENS
1648
- temperature = TEMPERATURE
1649
- frequence_penalty = FREQUENCE_PENALTY
1650
- presence_penalty = PRESENCE_PENALTY
1651
- ckpt_info = "None"
1652
-
1653
- print(
1654
- f'Launch config: '
1655
- f'\n| model_title=`{model_title}` '
1656
- f'\n| max_tokens={max_tokens} '
1657
- f'\n| dtype={dtype} '
1658
- f'\n| tensor_parallel={tensor_parallel} '
1659
- f'\n| IS_DELETE_FOLDER={IS_DELETE_FOLDER} '
1660
- f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
1661
- f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} '
1662
- f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} '
1663
- f'\n| LANG_BLOCK_HISTORY={LANG_BLOCK_HISTORY} '
1664
- f'\n| frequence_penalty={frequence_penalty} '
1665
- f'\n| presence_penalty={presence_penalty} '
1666
- f'\n| temperature={temperature} '
1667
- # f'\n| hf_model_name={hf_model_name} '
1668
- f'\n| model_path={model_path} '
1669
- f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
1670
- f'\n| gpu_memory_utilization={gpu_memory_utilization} '
1671
- f'\n| LOG_PATH={LOG_PATH} | SAVE_LOGS={SAVE_LOGS} '
1672
- f'\n| Desc={model_desc}'
1673
- )
1674
-
1675
- if DEBUG:
1676
- model_desc += "\n<br>!!!!! This is in debug mode, responses will copy original"
1677
- # response_fn = debug_chat_response_echo
1678
- response_fn = chat_response_stream_multiturn
1679
- print(f'Creating in DEBUG MODE')
1680
- if SAVE_LOGS:
1681
- LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
1682
- else:
1683
- # ! load the model
1684
- maybe_delete_folder()
1685
-
1686
- if DOWNLOAD_SNAPSHOT:
1687
- print(f'Downloading from HF_MODEL_NAME={hf_model_name} -> {model_path}')
1688
- if HF_TOKEN is not None:
1689
- print(f'Load with HF_TOKEN: {HF_TOKEN}')
1690
- snapshot_download(hf_model_name, local_dir=model_path, use_auth_token=True, token=HF_TOKEN)
1691
- else:
1692
- snapshot_download(hf_model_name, local_dir=model_path)
1693
-
1694
- import vllm
1695
- from vllm import LLM
1696
-
1697
- print(F'VLLM: {vllm.__version__}')
1698
- ckpt_info = check_model_path(model_path)
1699
-
1700
- print(f'Load path: {model_path} | {ckpt_info}')
1701
 
1702
- if QUANTIZATION == 'awq':
1703
- print(F'Load model in int4 quantization')
1704
- llm = LLM(model=model_path, dtype="float16", tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization, quantization="awq", max_model_len=8192)
1705
- else:
1706
- llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization, max_model_len=8192)
1707
-
1708
- try:
1709
- print(llm.llm_engine.workers[0].model)
1710
- except Exception as e:
1711
- print(f'Cannot print model worker: {e}')
1712
-
1713
- try:
1714
- llm.llm_engine.scheduler_config.max_model_len = 8192
1715
- llm.llm_engine.scheduler_config.max_num_batched_tokens = 8192
1716
- # llm.llm_engine.tokenizer.add_special_tokens = False
1717
- except Exception as e:
1718
- print(f'Cannot set parameters: {e}')
1719
-
1720
- print(f'Use system prompt:\n{sys_prompt}')
1721
-
1722
- response_fn = chat_response_stream_multiturn
1723
- print(F'respond: {response_fn}')
1724
-
1725
- if SAVE_LOGS:
1726
- LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
1727
-
1728
- if ENABLE_BATCH_INFER:
1729
-
1730
- # demo_file_upload = create_file_upload_demo()
1731
-
1732
- demo_free_form = create_free_form_generation_demo()
1733
-
1734
- demo_chat = create_chat_demo()
1735
- demo_chat_rag = create_chat_demo_rag(description=RAG_DESCRIPTION)
1736
- descriptions = model_desc
1737
- if DISPLAY_MODEL_PATH:
1738
- descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
1739
-
1740
- demo = CustomTabbedInterface(
1741
- interface_list=[
1742
- demo_chat,
1743
- demo_chat_rag,
1744
- demo_free_form,
1745
- # demo_file_upload,
1746
- ],
1747
- tab_names=[
1748
- "Chat Interface",
1749
- "RAG Chat Interface",
1750
- "Text completion",
1751
- # "Batch Inference",
1752
- ],
1753
- title=f"{model_title}",
1754
- description=descriptions,
1755
  )
1756
- else:
1757
- descriptions = model_desc
1758
- if DISPLAY_MODEL_PATH:
1759
- descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
1760
 
1761
- demo = create_chat_demo(title=f"{model_title}", description=descriptions)
 
 
 
 
 
 
1762
  demo.title = MODEL_NAME
1763
 
1764
  with demo:
1765
- if DATA_SET_REPO_PATH != "":
1766
- try:
1767
- from performance_plot import attach_plot_to_demo
1768
- attach_plot_to_demo(demo)
1769
- except Exception as e:
1770
- print(f'Fail to load DEMO plot: {str(e)}')
1771
-
1772
- gr.Markdown(cite_markdown)
1773
- if DISPLAY_MODEL_PATH:
1774
- gr.Markdown(path_markdown.format(model_path=model_path))
1775
 
1776
- if ENABLE_AGREE_POPUP:
1777
- demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
1778
-
1779
- # login_btn = gr.LoginButton()
1780
-
1781
  demo.queue(api_open=False)
1782
  return demo
1783
 
1784
 
 
1785
  if __name__ == "__main__":
1786
  demo = launch_demo()
1787
- demo.launch(show_api=False, allowed_paths=["seal_logo.png"])
 
 
 
 
 
 
 
 
3
 
4
  # Description:
5
  """
6
+ Demo script to launch Language chat model
7
  """
8
 
9
 
10
  import os
11
+ from gradio.themes import ThemeClass as Theme
12
  import numpy as np
13
  import argparse
14
+ # import torch
15
  import gradio as gr
16
  from typing import Any, Iterator
17
  from typing import Iterator, List, Optional, Tuple
 
30
  from typing import List, Optional, Union, Dict, Tuple
31
  from tqdm.auto import tqdm
32
  from huggingface_hub import snapshot_download
33
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
34
+ from gradio.components import Button, Component
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  from gradio.events import Dependency, EventListenerMethod
36
 
37
+ from multipurpose_chatbot.demos.base_demo import CustomTabbedInterface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ from multipurpose_chatbot.configs import (
40
+ MODEL_TITLE,
41
+ MODEL_DESC,
42
+ MODEL_INFO,
43
+ CITE_MARKDOWN,
44
+ ALLOWED_PATHS,
45
+ PROXY,
46
+ PORT,
47
+ MODEL_PATH,
48
+ MODEL_NAME,
49
+ BACKEND,
50
+ DEMOS,
51
+ DELETE_FOLDER,
52
+ )
53
 
 
 
54
 
55
+ demo = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
 
57
 
 
 
 
 
58
 
59
+ if DELETE_FOLDER is not None and os.path.exists(DELETE_FOLDER):
60
+ print(F'WARNING deleting folder: {DELETE_FOLDER}')
 
61
  import shutil
62
+ print(f'DELETE ALL FILES IN {DELETE_FOLDER}')
63
+ for filename in os.listdir(DELETE_FOLDER):
64
+ file_path = os.path.join(DELETE_FOLDER, filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  try:
66
+ if os.path.isfile(file_path) or os.path.islink(file_path):
67
+ os.unlink(file_path)
68
+ elif os.path.isdir(file_path):
69
+ shutil.rmtree(file_path)
70
+ print(f'deleted: {file_path}')
 
 
 
 
 
 
71
  except Exception as e:
72
+ print('Failed to delete %s. Reason: %s' % (file_path, e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
 
75
  def launch_demo():
76
+ global demo, MODEL_ENGINE
77
  model_desc = MODEL_DESC
78
  model_path = MODEL_PATH
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ print(f'Begin importing models')
81
+ from multipurpose_chatbot.demos import get_demo_class
82
+
83
+ # demos = {
84
+ # k: get_demo_class(k)().create_demo()
85
+ # for k in demo_and_tab_names.keys()
86
+ # }
87
+ print(f'{DEMOS=}')
88
+ demo_class_objects = {
89
+ k: get_demo_class(k)()
90
+ for k in DEMOS
91
+ }
92
+ demos = {
93
+ k: get_demo_class(k)().create_demo()
94
+ for k in DEMOS
95
+ }
96
+ demos_names = [x.tab_name for x in demo_class_objects.values()]
97
+
98
+ descriptions = model_desc
99
+ if MODEL_INFO is not None and MODEL_INFO != "":
100
+ descriptions += (
101
+ f"<br>" +
102
+ MODEL_INFO.format(model_path=model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
 
 
 
 
104
 
105
+ demo = CustomTabbedInterface(
106
+ interface_list=list(demos.values()),
107
+ tab_names=demos_names,
108
+ title=f"{MODEL_TITLE}",
109
+ description=descriptions,
110
+ )
111
+
112
  demo.title = MODEL_NAME
113
 
114
  with demo:
115
+ gr.Markdown(CITE_MARKDOWN)
 
 
 
 
 
 
 
 
 
116
 
 
 
 
 
 
117
  demo.queue(api_open=False)
118
  return demo
119
 
120
 
121
+
122
  if __name__ == "__main__":
123
  demo = launch_demo()
124
+ if PROXY is not None and PROXY != "":
125
+ print(f'{PROXY=} {PORT=}')
126
+ print(f"{ALLOWED_PATHS=}")
127
+ demo.launch(server_port=PORT, root_path=PROXY, show_api=False, allowed_paths=ALLOWED_PATHS)
128
+ else:
129
+ demo.launch(server_port=PORT, show_api=False, allowed_paths=ALLOWED_PATHS)
130
+
131
+
assets/.DS_Store ADDED
Binary file (6.15 kB). View file
 
assets/attention_all_you_need.pdf ADDED
Binary file (858 kB). View file
 
assets/attention_short.pdf ADDED
Binary file (236 kB). View file
 
assets/dog_monalisa.jpeg ADDED
assets/upload_chat.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": "1",
4
+ "prompt": "Tell me something about AI?"
5
+ },
6
+ {
7
+ "id": "2",
8
+ "prompt": "Who are you?"
9
+ }
10
+ ]
assets/upload_few_shot.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": "0",
4
+ "prompt": "Translate Indonesian to English.\nIndonesian: \"Mereka melakukan hal ini dengan cara memancarkan sebuah partikel kecil cahaya kecil yang biasa disebut \"foton\".\"\nEnglish: They do this by emitting a tiny particle of light called a \"photon\".\n\nTranslate Indonesian to English.\nIndonesian: Kami melewati waktu seperti rangkaian peristiwa yang berlalu dari masa depan hingga masa kini lalu ke masa lalu.\nEnglish: We experience time as a series of events passing from the future through the present to the past.\n\nTranslate Indonesian to English.\nIndonesian: Canyoning (atau: canyoneering) adalah segala aktivitas yang terjadi di dasar ngarai, yang kering atau penuh air.\nEnglish: Canyoning (or: canyoneering) is about going in a bottom of a canyon, which is either dry or full of water.\n\nTranslate Indonesian to English.\nIndonesian: Mohon diingat bahwa intinya Anda sedang berkunjung ke situs kuburan massal, serta situs yang maknanya tak terhitung bagi sejumlah populasi dunia yang signifikan.\nEnglish:"
5
+ },
6
+ {
7
+ "id": "1",
8
+ "prompt": "Translate Indonesian to English.\nIndonesian: \"Mereka melakukan hal ini dengan cara memancarkan sebuah partikel kecil cahaya kecil yang biasa disebut \"foton\".\"\nEnglish: They do this by emitting a tiny particle of light called a \"photon\".\n\nTranslate Indonesian to English.\nIndonesian: Kami melewati waktu seperti rangkaian peristiwa yang berlalu dari masa depan hingga masa kini lalu ke masa lalu.\nEnglish: We experience time as a series of events passing from the future through the present to the past.\n\nTranslate Indonesian to English.\nIndonesian: Canyoning (atau: canyoneering) adalah segala aktivitas yang terjadi di dasar ngarai, yang kering atau penuh air.\nEnglish: Canyoning (or: canyoneering) is about going in a bottom of a canyon, which is either dry or full of water.\n\nTranslate Indonesian to English.\nIndonesian: Serangga adalah hewan pertama yang menjelajah angkasa. Kemampuan terbangnya membantu mereka menghindari musuh dengan lebih mudah dan mencari makanan dan pasangan dengan lebih efisien.\nEnglish:"
9
+ }
10
+ ]
llama_cpp_requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ llama-cpp-python
mlx_requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ mlx
2
+ mlx-lm
multipurpose_chatbot/.DS_Store ADDED
Binary file (6.15 kB). View file
 
multipurpose_chatbot/__init__.py ADDED
File without changes
multipurpose_chatbot/configs.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+
4
+ # ! UI Markdown information
5
+
6
+ MODEL_TITLE = """
7
+ <img src="file/seammm_2.png" style="
8
+ max-width: 10em;
9
+ max-height: 5%;
10
+ height: 3em;
11
+ width: 3em;
12
+ ">
13
+ <div class="text" style="
14
+ loat: left;
15
+ padding-bottom: 2%;
16
+ ">
17
+ SeaLMMM - Large Multilingual Multimodal Models for Southeast Asia
18
+ </div>
19
+ """
20
+
21
+ # <a href='https://huggingface.co/spaces/SeaLLMs/SeaLMMM-7b'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
22
+ # <a href='https://huggingface.co/SeaLLMs/SeaLLM-7B-v2'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
23
+ #
24
+ MODEL_DESC = f"""
25
+ <div style='display:flex; gap: 0.25rem; '>
26
+ <a href='https://github.com/damo-nlp-sg/seallms'><img src='https://img.shields.io/badge/Github-Code-success'></a>
27
+ <a href='https://huggingface.co/spaces/SeaLLMs/SeaLLM-7B'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
28
+ <a href='https://huggingface.co/SeaLLMs/SeaLMMM-7B-early'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
29
+ </div>
30
+ <span style="font-size: larger">
31
+ <a href="https://huggingface.co/SeaLLMs/SeaLMMM-7B-early" target="_blank">SeaLMMM-7B-early</a> - multilingual multimodal assistant for Southeast Asia. It handles <b>both</b> text-only (<a href="https://huggingface.co/SeaLLMs/SeaLLM-7B-v2" target="_blank">LLMs</a> and vision instructions (LVMs). <span style="color: red">SeaLMMM-7B has not finished training.</span>
32
+ </span>
33
+ <br>
34
+ <span>
35
+ <span style="color: red">The chatbot may produce false and harmful content!</span>
36
+ By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">Terms Of Use</a>
37
+ </span>
38
+
39
+ """.strip()
40
+
41
+ """
42
+ By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">Terms Of Use</a>, which includes
43
+ not to use our service to generate any harmful, inappropriate or illegal content.
44
+ The service collects user dialogue data for testing and improvement under
45
+ <a href="https://creativecommons.org/licenses/by/4.0/">(CC-BY)</a> or similar license. So do not enter any personal information!
46
+
47
+ """
48
+
49
+
50
+ # MODEL_INFO = """
51
+ # <h4 style="display: hidden;">Model Name: {model_path}</h4>
52
+ # """
53
+ MODEL_INFO = ""
54
+
55
+ CITE_MARKDOWN = """
56
+ ## Citation
57
+ If you find our project useful, hope you can star our repo and cite our paper as follows:
58
+ ```
59
+ @article{damonlpsg2023seallm,
60
+ author = {Xuan-Phi Nguyen*, Wenxuan Zhang*, Xin Li*, Mahani Aljunied*, Zhiqiang Hu, Chenhui Shen^, Yew Ken Chia^, Xingxuan Li, Jianyu Wang, Qingyu Tan, Liying Cheng, Guanzheng Chen, Yue Deng, Sen Yang, Chaoqun Liu, Hang Zhang, Lidong Bing},
61
+ title = {SeaLLMs - Large Language Models for Southeast Asia},
62
+ year = 2023,
63
+ }
64
+ ```
65
+
66
+ """
67
+ USE_PANEL = bool(int(os.environ.get("USE_PANEL", "1")))
68
+ CHATBOT_HEIGHT = int(os.environ.get("CHATBOT_HEIGHT", "500"))
69
+
70
+ ALLOWED_PATHS = ["seammm_2.png"]
71
+
72
+
73
+ DEMOS = os.environ.get("DEMOS", "")
74
+
75
+ DEMOS = DEMOS.split(",") if DEMOS.strip() != "" else [
76
+ "DocChatInterfaceDemo",
77
+ "ChatInterfaceDemo",
78
+ "TextCompletionDemo",
79
+ # "RagChatInterfaceDemo",
80
+ # "VisionChatInterfaceDemo",
81
+ # "VisionDocChatInterfaceDemo",
82
+ ]
83
+
84
+ # DEMOS=DocChatInterfaceDemo,ChatInterfaceDemo,RagChatInterfaceDemo,TextCompletionDemo
85
+
86
+
87
+
88
+ # ! server info
89
+
90
+ DELETE_FOLDER = os.environ.get("DELETE_FOLDER", "")
91
+ PORT = int(os.environ.get("PORT", "7860"))
92
+ PROXY = os.environ.get("PROXY", "").strip()
93
+
94
+ # ! backend info
95
+
96
+ BACKEND = os.environ.get("BACKEND", "debug")
97
+
98
+ # ! model information
99
+ # for RAG
100
+ RAG_EMBED_MODEL_NAME = os.environ.get("RAG_EMBED_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
101
+ CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1024"))
102
+ CHUNK_OVERLAP = int(os.environ.get("CHUNK_SIZE", "50"))
103
+
104
+
105
+ SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", """You are a helpful, respectful, honest and safe AI assistant.""")
106
+
107
+ MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
108
+ TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
109
+ # ! these values currently not used
110
+ FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.0"))
111
+ PRESENCE_PENALTY = float(os.environ.get("PRESENCE_PENALTY", "0.0"))
112
+
113
+
114
+ # Transformers or vllm
115
+ MODEL_PATH = os.environ.get("MODEL_PATH", "mistralai/Mistral-7B-Instruct-v0.2")
116
+ MODEL_NAME = os.environ.get("MODEL_NAME", "Cool-Chatbot")
117
+ DTYPE = os.environ.get("DTYPE", "bfloat16")
118
+ DEVICE = os.environ.get("DEVICE", "cuda")
119
+
120
+ # VLLM
121
+ GPU_MEMORY_UTILIZATION = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9"))
122
+ TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
123
+ QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
124
+ STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
125
+ # how many iterations to perform safety check on response
126
+ STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
127
+
128
+ # llama.cpp
129
+ DEFAULT_CHAT_TEMPLATE = os.environ.get("DEFAULT_CHAT_TEMPLATE", "chatml")
130
+ N_CTX = int(os.environ.get("N_CTX", "4096"))
131
+ N_GPU_LAYERS = int(os.environ.get("N_GPU_LAYERS", "-1"))
132
+
133
+ # llava.llama.cpp
134
+
135
+
136
+ # Multimodal
137
+ IMAGE_TOKEN = os.environ.get("IMAGE_TOKEN", "[IMAGE]<|image|>[/IMAGE]")
138
+ IMAGE_TOKEN_INTERACTIVE = bool(int(os.environ.get("IMAGE_TOKEN_INTERACTIVE", "0")))
139
+ IMAGE_TOKEN_LENGTH = int(os.environ.get("IMAGE_TOKEN_LENGTH", "576"))
140
+ MAX_PACHES = int(os.environ.get("MAX_PACHES", "1"))
multipurpose_chatbot/demos/.DS_Store ADDED
Binary file (6.15 kB). View file
 
multipurpose_chatbot/demos/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from .base_demo import *
3
+
4
+ from .chat_interface import ChatInterfaceDemo
5
+ from .rag_chat_interface import RagChatInterfaceDemo
6
+ from .multimodal_chat_interface import *
7
+ from .text_completion import *
8
+ from .batch_inference import *
9
+ from .multimodal_preference_interface import *
multipurpose_chatbot/demos/base_demo.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+
25
+
26
+ def create_class_func_registry():
27
+ registry = {}
28
+ def register_registry(cls, exist_ok=False):
29
+ assert exist_ok or cls.__name__ not in registry, f'{cls} already in registry: {registry}'
30
+ registry[cls.__name__] = cls
31
+ return cls
32
+
33
+ def get_registry(name):
34
+ assert name in registry, f'{name} not in registry: {registry}'
35
+ return registry[name]
36
+
37
+ return registry, register_registry, get_registry
38
+
39
+ DEMOS, register_demo, get_demo_class = create_class_func_registry()
40
+
41
+
42
+ class BaseDemo(object):
43
+ """
44
+ All demo should be created from BaseDemo and registered with @register_demo
45
+ """
46
+ def __init__(self) -> None:
47
+ pass
48
+
49
+ @property
50
+ def tab_name(self):
51
+ return "Demo"
52
+
53
+ def create_demo(
54
+ self,
55
+ title: Optional[str] = None,
56
+ description: Optional[str] = None,
57
+ **kwargs,
58
+ ) -> gr.Blocks:
59
+ pass
60
+
61
+
62
+ @document()
63
+ class CustomTabbedInterface(gr.Blocks):
64
+ def __init__(
65
+ self,
66
+ interface_list: list[gr.Interface],
67
+ tab_names: Optional[list[str]] = None,
68
+ title: Optional[str] = None,
69
+ description: Optional[str] = None,
70
+ theme: Optional[gr.Theme] = None,
71
+ analytics_enabled: Optional[bool] = None,
72
+ css: Optional[str] = None,
73
+ ):
74
+ """
75
+ Parameters:
76
+ interface_list: a list of interfaces to be rendered in tabs.
77
+ tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
78
+ title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
79
+ analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
80
+ css: custom css or path to custom css file to apply to entire Blocks
81
+ Returns:
82
+ a Gradio Tabbed Interface for the given interfaces
83
+ """
84
+ super().__init__(
85
+ title=title or "Gradio",
86
+ theme=theme,
87
+ analytics_enabled=analytics_enabled,
88
+ mode="tabbed_interface",
89
+ css=css,
90
+ )
91
+ self.description = description
92
+ if tab_names is None:
93
+ tab_names = [f"Tab {i}" for i in range(len(interface_list))]
94
+ with self:
95
+ if title:
96
+ gr.Markdown(
97
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
98
+ )
99
+ if description:
100
+ gr.Markdown(description)
101
+ with gr.Tabs():
102
+ for interface, tab_name in zip(interface_list, tab_names):
103
+ with gr.Tab(label=tab_name):
104
+ interface.render()
105
+
multipurpose_chatbot/demos/batch_inference.py ADDED
File without changes
multipurpose_chatbot/demos/chat_interface.py ADDED
@@ -0,0 +1,692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+
25
+
26
+ import inspect
27
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
28
+
29
+ import anyio
30
+ from gradio_client import utils as client_utils
31
+ from gradio_client.documentation import document
32
+
33
+ from gradio.blocks import Blocks
34
+ from gradio.components import (
35
+ Button,
36
+ Chatbot,
37
+ Component,
38
+ Markdown,
39
+ State,
40
+ Textbox,
41
+ get_component_instance,
42
+ )
43
+ from gradio.events import Dependency, on
44
+ from gradio.helpers import create_examples as Examples # noqa: N812
45
+ from gradio.helpers import special_args
46
+ from gradio.layouts import Accordion, Group, Row
47
+ from gradio.routes import Request
48
+ from gradio.themes import ThemeClass as Theme
49
+ from gradio.utils import SyncToAsyncIterator, async_iteration
50
+
51
+
52
+ from .base_demo import register_demo, get_demo_class, BaseDemo
53
+ from ..configs import (
54
+ SYSTEM_PROMPT,
55
+ MODEL_NAME,
56
+ MAX_TOKENS,
57
+ TEMPERATURE,
58
+ )
59
+
60
+ from ..globals import MODEL_ENGINE
61
+
62
+ CHAT_EXAMPLES = [
63
+ ["Explain general relativity."],
64
+ ]
65
+ DATETIME_FORMAT = "Current date time: {cur_datetime}."
66
+
67
+
68
+ def gradio_history_to_openai_conversations(message=None, history=None, system_prompt=None):
69
+ conversations = []
70
+ system_prompt = system_prompt or SYSTEM_PROMPT
71
+ if history is not None and len(history) > 0:
72
+ for i, (prompt, res) in enumerate(history):
73
+ if prompt is not None:
74
+ conversations.append({"role": "user", "content": prompt.strip()})
75
+ if res is not None:
76
+ conversations.append({"role": "assistant", "content": res.strip()})
77
+ if message is not None:
78
+ if len(message.strip()) == 0:
79
+ raise gr.Error("The message cannot be empty!")
80
+ conversations.append({"role": "user", "content": message.strip()})
81
+ if conversations[0]['role'] != 'system':
82
+ conversations = [{"role": "system", "content": system_prompt}] + conversations
83
+ return conversations
84
+
85
+
86
+ def gradio_history_to_conversation_prompt(message=None, history=None, system_prompt=None):
87
+ global MODEL_ENGINE
88
+ full_prompt = MODEL_ENGINE.apply_chat_template(
89
+ gradio_history_to_openai_conversations(
90
+ message, history=history, system_prompt=system_prompt),
91
+ add_generation_prompt=True
92
+ )
93
+ return full_prompt
94
+
95
+
96
+
97
+ def get_datetime_string():
98
+ from datetime import datetime
99
+ now = datetime.now()
100
+ # dd/mm/YY H:M:S
101
+ dt_string = now.strftime("%B %d, %Y, %H:%M:%S")
102
+ return dt_string
103
+
104
+
105
+ def format_conversation(history, system_prompt=None):
106
+ _str = '\n'.join([
107
+ (
108
+ f'<<<User>>> {h[0]}\n'
109
+ f'<<<Asst>>> {h[1]}'
110
+ )
111
+ for h in history
112
+ ])
113
+ if system_prompt is not None:
114
+ _str = f"<<<System>>> {system_prompt}\n" + _str
115
+ return _str
116
+
117
+
118
+ def chat_response_stream_multiturn_engine(
119
+ message: str,
120
+ history: List[Tuple[str, str]],
121
+ temperature: float,
122
+ max_tokens: int,
123
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
124
+ ):
125
+ global MODEL_ENGINE
126
+ temperature = float(temperature)
127
+ # ! remove frequency_penalty
128
+ # frequency_penalty = float(frequency_penalty)
129
+ max_tokens = int(max_tokens)
130
+ message = message.strip()
131
+ if len(message) == 0:
132
+ raise gr.Error("The message cannot be empty!")
133
+ # ! skip safety
134
+ if DATETIME_FORMAT in system_prompt:
135
+ # ! This sometime works sometimes dont
136
+ system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
137
+ full_prompt = gradio_history_to_conversation_prompt(message.strip(), history=history, system_prompt=system_prompt)
138
+ # ! length checked
139
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
140
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
141
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
142
+ print(full_prompt)
143
+ outputs = None
144
+ response = None
145
+ num_tokens = -1
146
+ for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
147
+ prompt=full_prompt,
148
+ temperature=temperature,
149
+ max_tokens=max_tokens,
150
+ )):
151
+ if isinstance(outputs, tuple):
152
+ response, num_tokens = outputs
153
+ else:
154
+ response, num_tokens = outputs, -1
155
+ yield response, num_tokens
156
+
157
+ if response is not None:
158
+ yield response, num_tokens
159
+
160
+
161
+ class CustomizedChatInterface(gr.ChatInterface):
162
+ """
163
+ Fixing some issue with chatinterace
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ fn: Callable,
169
+ *,
170
+ chatbot: Chatbot | None = None,
171
+ textbox: Textbox | None = None,
172
+ additional_inputs: str | Component | list[str | Component] | None = None,
173
+ additional_inputs_accordion_name: str | None = None,
174
+ additional_inputs_accordion: str | Accordion | None = None,
175
+ examples: list[str] | None = None,
176
+ cache_examples: bool | None = None,
177
+ title: str | None = None,
178
+ description: str | None = None,
179
+ theme: Theme | str | None = None,
180
+ css: str | None = None,
181
+ js: str | None = None,
182
+ head: str | None = None,
183
+ analytics_enabled: bool | None = None,
184
+ submit_btn: str | None | Button = "Submit",
185
+ stop_btn: str | None | Button = "Stop",
186
+ retry_btn: str | None | Button = "🔄 Retry",
187
+ undo_btn: str | None | Button = "↩️ Undo",
188
+ clear_btn: str | None | Button = "🗑️ Clear",
189
+ autofocus: bool = True,
190
+ concurrency_limit: int | None | Literal["default"] = "default",
191
+ fill_height: bool = True,
192
+ ):
193
+ """
194
+ Parameters:
195
+ fn: The function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
196
+ chatbot: An instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
197
+ textbox: An instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
198
+ additional_inputs: An instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
199
+ additional_inputs_accordion_name: Deprecated. Will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead.
200
+ additional_inputs_accordion: If a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
201
+ examples: Sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
202
+ cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
203
+ title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
204
+ description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
205
+ theme: Theme to use, loaded from gradio.themes.
206
+ css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
207
+ js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
208
+ head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
209
+ analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
210
+ submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
211
+ stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
212
+ retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
213
+ undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
214
+ clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
215
+ autofocus: If True, autofocuses to the textbox when the page loads.
216
+ concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
217
+ fill_height: If True, the chat interface will expand to the height of window.
218
+ """
219
+ try:
220
+ super(gr.ChatInterface, self).__init__(
221
+ analytics_enabled=analytics_enabled,
222
+ mode="chat_interface",
223
+ css=css,
224
+ title=title or "Gradio",
225
+ theme=theme,
226
+ js=js,
227
+ head=head,
228
+ fill_height=fill_height,
229
+ )
230
+ except Exception as e:
231
+ # Handling some old gradio version with out fill_height
232
+ super(gr.ChatInterface, self).__init__(
233
+ analytics_enabled=analytics_enabled,
234
+ mode="chat_interface",
235
+ css=css,
236
+ title=title or "Gradio",
237
+ theme=theme,
238
+ js=js,
239
+ head=head,
240
+ # fill_height=fill_height,
241
+ )
242
+ self.concurrency_limit = concurrency_limit
243
+ self.fn = fn
244
+ self.is_async = inspect.iscoroutinefunction(
245
+ self.fn
246
+ ) or inspect.isasyncgenfunction(self.fn)
247
+ self.is_generator = inspect.isgeneratorfunction(
248
+ self.fn
249
+ ) or inspect.isasyncgenfunction(self.fn)
250
+ self.examples = examples
251
+ if self.space_id and cache_examples is None:
252
+ self.cache_examples = True
253
+ else:
254
+ self.cache_examples = cache_examples or False
255
+ self.buttons: list[Button | None] = []
256
+
257
+ if additional_inputs:
258
+ if not isinstance(additional_inputs, list):
259
+ additional_inputs = [additional_inputs]
260
+ self.additional_inputs = [
261
+ get_component_instance(i)
262
+ for i in additional_inputs # type: ignore
263
+ ]
264
+ else:
265
+ self.additional_inputs = []
266
+ if additional_inputs_accordion_name is not None:
267
+ print(
268
+ "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
269
+ )
270
+ self.additional_inputs_accordion_params = {
271
+ "label": additional_inputs_accordion_name
272
+ }
273
+ if additional_inputs_accordion is None:
274
+ self.additional_inputs_accordion_params = {
275
+ "label": "Additional Inputs",
276
+ "open": False,
277
+ }
278
+ elif isinstance(additional_inputs_accordion, str):
279
+ self.additional_inputs_accordion_params = {
280
+ "label": additional_inputs_accordion
281
+ }
282
+ elif isinstance(additional_inputs_accordion, Accordion):
283
+ self.additional_inputs_accordion_params = (
284
+ additional_inputs_accordion.recover_kwargs(
285
+ additional_inputs_accordion.get_config()
286
+ )
287
+ )
288
+ else:
289
+ raise ValueError(
290
+ f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
291
+ )
292
+
293
+ with self:
294
+ if title:
295
+ Markdown(
296
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
297
+ )
298
+ if description:
299
+ Markdown(description)
300
+
301
+ if chatbot:
302
+ self.chatbot = chatbot.render()
303
+ else:
304
+ self.chatbot = Chatbot(
305
+ label="Chatbot", scale=1, height=200 if fill_height else None
306
+ )
307
+
308
+ with Row():
309
+ for btn in [retry_btn, undo_btn, clear_btn]:
310
+ if btn is not None:
311
+ if isinstance(btn, Button):
312
+ btn.render()
313
+ elif isinstance(btn, str):
314
+ btn = Button(btn, variant="secondary", size="sm")
315
+ else:
316
+ raise ValueError(
317
+ f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
318
+ )
319
+ self.buttons.append(btn) # type: ignore
320
+
321
+ with Group():
322
+ with Row():
323
+ if textbox:
324
+ textbox.container = False
325
+ textbox.show_label = False
326
+ textbox_ = textbox.render()
327
+ assert isinstance(textbox_, Textbox)
328
+ self.textbox = textbox_
329
+ else:
330
+ self.textbox = Textbox(
331
+ container=False,
332
+ show_label=False,
333
+ label="Message",
334
+ placeholder="Type a message...",
335
+ scale=7,
336
+ autofocus=autofocus,
337
+ )
338
+ if submit_btn is not None:
339
+ if isinstance(submit_btn, Button):
340
+ submit_btn.render()
341
+ elif isinstance(submit_btn, str):
342
+ submit_btn = Button(
343
+ submit_btn,
344
+ variant="primary",
345
+ scale=2,
346
+ min_width=150,
347
+ )
348
+ else:
349
+ raise ValueError(
350
+ f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
351
+ )
352
+ if stop_btn is not None:
353
+ if isinstance(stop_btn, Button):
354
+ stop_btn.visible = False
355
+ stop_btn.render()
356
+ elif isinstance(stop_btn, str):
357
+ stop_btn = Button(
358
+ stop_btn,
359
+ variant="stop",
360
+ visible=False,
361
+ scale=2,
362
+ min_width=150,
363
+ )
364
+ else:
365
+ raise ValueError(
366
+ f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
367
+ )
368
+ self.num_tokens = Textbox(
369
+ container=False,
370
+ show_label=False,
371
+ label="num_tokens",
372
+ placeholder="0 tokens",
373
+ scale=1,
374
+ interactive=False,
375
+ # autofocus=autofocus,
376
+ min_width=10
377
+ )
378
+ self.buttons.extend([submit_btn, stop_btn]) # type: ignore
379
+
380
+ self.fake_api_btn = Button("Fake API", visible=False)
381
+ self.fake_response_textbox = Textbox(label="Response", visible=False)
382
+ (
383
+ self.retry_btn,
384
+ self.undo_btn,
385
+ self.clear_btn,
386
+ self.submit_btn,
387
+ self.stop_btn,
388
+ ) = self.buttons
389
+
390
+ if examples:
391
+ if self.is_generator:
392
+ examples_fn = self._examples_stream_fn
393
+ else:
394
+ examples_fn = self._examples_fn
395
+
396
+ self.examples_handler = Examples(
397
+ examples=examples,
398
+ inputs=[self.textbox] + self.additional_inputs,
399
+ outputs=self.chatbot,
400
+ fn=examples_fn,
401
+ )
402
+
403
+ any_unrendered_inputs = any(
404
+ not inp.is_rendered for inp in self.additional_inputs
405
+ )
406
+ if self.additional_inputs and any_unrendered_inputs:
407
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
408
+ for input_component in self.additional_inputs:
409
+ if not input_component.is_rendered:
410
+ input_component.render()
411
+
412
+ # The example caching must happen after the input components have rendered
413
+ if cache_examples:
414
+ client_utils.synchronize_async(self.examples_handler.cache)
415
+
416
+ self.saved_input = State()
417
+ self.chatbot_state = (
418
+ State(self.chatbot.value) if self.chatbot.value else State([])
419
+ )
420
+
421
+ self._setup_events()
422
+ self._setup_api()
423
+
424
+ # replace events so that submit button is disabled during generation, if stop_btn not found
425
+ # this prevent weird behavior
426
+ def _setup_stop_events(
427
+ self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
428
+ ) -> None:
429
+ from gradio.components import State
430
+ event_triggers = event_triggers if isinstance(event_triggers, (list, tuple)) else [event_triggers]
431
+ if self.stop_btn and self.is_generator:
432
+ if self.submit_btn:
433
+ for event_trigger in event_triggers:
434
+ event_trigger(
435
+ lambda: (
436
+ Button(visible=False),
437
+ Button(visible=True),
438
+ ),
439
+ None,
440
+ [self.submit_btn, self.stop_btn],
441
+ api_name=False,
442
+ queue=False,
443
+ )
444
+ event_to_cancel.then(
445
+ lambda: (Button(visible=True), Button(visible=False)),
446
+ None,
447
+ [self.submit_btn, self.stop_btn],
448
+ api_name=False,
449
+ queue=False,
450
+ )
451
+ else:
452
+ for event_trigger in event_triggers:
453
+ event_trigger(
454
+ lambda: Button(visible=True),
455
+ None,
456
+ [self.stop_btn],
457
+ api_name=False,
458
+ queue=False,
459
+ )
460
+ event_to_cancel.then(
461
+ lambda: Button(visible=False),
462
+ None,
463
+ [self.stop_btn],
464
+ api_name=False,
465
+ queue=False,
466
+ )
467
+ self.stop_btn.click(
468
+ None,
469
+ None,
470
+ None,
471
+ cancels=event_to_cancel,
472
+ api_name=False,
473
+ )
474
+ else:
475
+ if self.submit_btn:
476
+ for event_trigger in event_triggers:
477
+ event_trigger(
478
+ lambda: Button(interactive=False),
479
+ None,
480
+ [self.submit_btn],
481
+ api_name=False,
482
+ queue=False,
483
+ )
484
+ event_to_cancel.then(
485
+ lambda: Button(interactive=True),
486
+ None,
487
+ [self.submit_btn],
488
+ api_name=False,
489
+ queue=False,
490
+ )
491
+ # upon clear, cancel the submit event as well
492
+ if self.clear_btn:
493
+ self.clear_btn.click(
494
+ lambda: ([], [], None, Button(interactive=True)),
495
+ None,
496
+ [self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
497
+ queue=False,
498
+ api_name=False,
499
+ cancels=event_to_cancel,
500
+ )
501
+
502
+ def _setup_events(self) -> None:
503
+ from gradio.components import State
504
+ has_on = False
505
+ try:
506
+ from gradio.events import Dependency, EventListenerMethod, on
507
+ has_on = True
508
+ except ImportError as ie:
509
+ has_on = False
510
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
511
+ if not self.is_generator:
512
+ raise NotImplementedError(f'should use generator')
513
+
514
+ if has_on:
515
+ # new version
516
+ submit_triggers = (
517
+ [self.textbox.submit, self.submit_btn.click]
518
+ if self.submit_btn
519
+ else [self.textbox.submit]
520
+ )
521
+ submit_event = (
522
+ on(
523
+ submit_triggers,
524
+ self._clear_and_save_textbox,
525
+ [self.textbox],
526
+ [self.textbox, self.saved_input],
527
+ api_name=False,
528
+ queue=False,
529
+ )
530
+ .then(
531
+ self._display_input,
532
+ [self.saved_input, self.chatbot_state],
533
+ [self.chatbot, self.chatbot_state],
534
+ api_name=False,
535
+ queue=False,
536
+ )
537
+ .then(
538
+ submit_fn,
539
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
540
+ [self.chatbot, self.chatbot_state, self.num_tokens],
541
+ api_name=False,
542
+ )
543
+ )
544
+ self._setup_stop_events(submit_triggers, submit_event)
545
+ else:
546
+ raise ValueError(f'Better install new gradio version than 3.44.0')
547
+
548
+ if self.retry_btn:
549
+ retry_event = (
550
+ self.retry_btn.click(
551
+ self._delete_prev_fn,
552
+ [self.chatbot_state],
553
+ [self.chatbot, self.saved_input, self.chatbot_state],
554
+ api_name=False,
555
+ queue=False,
556
+ )
557
+ .then(
558
+ self._display_input,
559
+ [self.saved_input, self.chatbot_state],
560
+ [self.chatbot, self.chatbot_state],
561
+ api_name=False,
562
+ queue=False,
563
+ )
564
+ .then(
565
+ submit_fn,
566
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
567
+ [self.chatbot, self.chatbot_state, self.num_tokens],
568
+ api_name=False,
569
+ )
570
+ )
571
+ self._setup_stop_events([self.retry_btn.click], retry_event)
572
+
573
+ if self.undo_btn:
574
+ self.undo_btn.click(
575
+ self._delete_prev_fn,
576
+ [self.chatbot_state],
577
+ [self.chatbot, self.saved_input, self.chatbot_state],
578
+ api_name=False,
579
+ queue=False,
580
+ ).then(
581
+ lambda x: x,
582
+ [self.saved_input],
583
+ [self.textbox],
584
+ api_name=False,
585
+ queue=False,
586
+ )
587
+ # Reconfigure clear_btn to stop and clear text box
588
+
589
+ def _clear_and_save_textbox(self, message: str) -> tuple[str, str]:
590
+ return "", message
591
+
592
+ def _display_input(
593
+ self, message: str, history: List[List[Union[str, None]]]
594
+ ) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
595
+ if message is not None and message.strip() != "":
596
+ history.append([message, None])
597
+ return history, history
598
+
599
+ async def _stream_fn(
600
+ self,
601
+ message: str,
602
+ history_with_input,
603
+ request: Request,
604
+ *args,
605
+ ) -> AsyncGenerator:
606
+ history = history_with_input[:-1]
607
+ inputs, _, _ = special_args(
608
+ self.fn, inputs=[message, history, *args], request=request
609
+ )
610
+
611
+ if self.is_async:
612
+ generator = self.fn(*inputs)
613
+ else:
614
+ generator = await anyio.to_thread.run_sync(
615
+ self.fn, *inputs, limiter=self.limiter
616
+ )
617
+ generator = SyncToAsyncIterator(generator, self.limiter)
618
+
619
+ # ! In case of error, yield the previous history & undo any generation before raising error
620
+ try:
621
+ first_response_pack = await async_iteration(generator)
622
+ if isinstance(first_response_pack, (tuple, list)):
623
+ first_response, num_tokens = first_response_pack
624
+ else:
625
+ first_response, num_tokens = first_response_pack, -1
626
+ update = history + [[message, first_response]]
627
+ yield update, update, f"{num_tokens} toks"
628
+ except StopIteration:
629
+ update = history + [[message, None]]
630
+ yield update, update, "NaN toks"
631
+ except Exception as e:
632
+ yield history, history, "NaN toks"
633
+ raise e
634
+
635
+ try:
636
+ async for response_pack in generator:
637
+ if isinstance(response_pack, (tuple, list)):
638
+ response, num_tokens = response_pack
639
+ else:
640
+ response, num_tokens = response_pack, "NaN toks"
641
+ update = history + [[message, response]]
642
+ yield update, update, f"{num_tokens} toks"
643
+ except Exception as e:
644
+ yield history, history, "NaN toks"
645
+ raise e
646
+
647
+ @register_demo
648
+ class ChatInterfaceDemo(BaseDemo):
649
+ @property
650
+ def tab_name(self):
651
+ return "Chat"
652
+
653
+ def create_demo(
654
+ self,
655
+ title: str | None = None,
656
+ description: str | None = None,
657
+ **kwargs
658
+ ) -> gr.Blocks:
659
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
660
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
661
+ temperature = kwargs.get("temperature", TEMPERATURE)
662
+ model_name = kwargs.get("model_name", MODEL_NAME)
663
+ # frequence_penalty = FREQUENCE_PENALTY
664
+ # presence_penalty = PRESENCE_PENALTY
665
+
666
+ demo_chat = CustomizedChatInterface(
667
+ chat_response_stream_multiturn_engine,
668
+ chatbot=gr.Chatbot(
669
+ label=model_name,
670
+ bubble_full_width=False,
671
+ latex_delimiters=[
672
+ { "left": "$", "right": "$", "display": False},
673
+ { "left": "$$", "right": "$$", "display": True},
674
+ ],
675
+ show_copy_button=True,
676
+ ),
677
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
678
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
679
+ title=title,
680
+ description=description,
681
+ additional_inputs=[
682
+ gr.Number(value=temperature, label='Temperature (higher -> more random)'),
683
+ gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
684
+ # gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
685
+ # gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
686
+ gr.Textbox(value=system_prompt, label='System prompt', lines=4)
687
+ ],
688
+ examples=CHAT_EXAMPLES,
689
+ cache_examples=False
690
+ )
691
+ return demo_chat
692
+
multipurpose_chatbot/demos/multimodal_chat_interface.py ADDED
@@ -0,0 +1,1295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+ from gradio.components.base import Component
25
+
26
+ from .base_demo import register_demo, get_demo_class, BaseDemo
27
+
28
+
29
+ from .chat_interface import (
30
+ SYSTEM_PROMPT,
31
+ MODEL_NAME,
32
+ MAX_TOKENS,
33
+ TEMPERATURE,
34
+ CHAT_EXAMPLES,
35
+ gradio_history_to_openai_conversations,
36
+ gradio_history_to_conversation_prompt,
37
+ DATETIME_FORMAT,
38
+ get_datetime_string,
39
+ chat_response_stream_multiturn_engine,
40
+ ChatInterfaceDemo,
41
+ CustomizedChatInterface,
42
+ )
43
+
44
+ from gradio.events import Events
45
+
46
+ import inspect
47
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
48
+
49
+ import anyio
50
+ from gradio_client import utils as client_utils
51
+ from gradio_client.documentation import document
52
+
53
+ from gradio.blocks import Blocks
54
+ from gradio.components import (
55
+ Button,
56
+ Chatbot,
57
+ Component,
58
+ Markdown,
59
+ State,
60
+ Textbox,
61
+ get_component_instance,
62
+ )
63
+ from gradio.events import Dependency, on
64
+ from gradio.helpers import create_examples as Examples # noqa: N812
65
+ from gradio.helpers import special_args
66
+ from gradio.layouts import Accordion, Group, Row
67
+ from gradio.routes import Request
68
+ from gradio.themes import ThemeClass as Theme
69
+ from gradio.utils import SyncToAsyncIterator, async_iteration
70
+
71
+ from ..globals import MODEL_ENGINE
72
+
73
+ from ..configs import (
74
+ USE_PANEL,
75
+ IMAGE_TOKEN,
76
+ IMAGE_TOKEN_INTERACTIVE,
77
+ CHATBOT_HEIGHT,
78
+ )
79
+
80
+
81
+
82
+ CSS = """
83
+ .message-fit {
84
+ min-width: 20em;
85
+ width: fit-content !important;
86
+ }
87
+
88
+ .message.svelte-1lcyrx4.svelte-1lcyrx4.svelte-1lcyrx4 {
89
+ padding-top: 1em;
90
+ padding-bottom: 1em;
91
+ }
92
+ """
93
+
94
+
95
+ DOC_TEMPLATE = """###
96
+ {content}
97
+ ###
98
+
99
+ """
100
+
101
+ DOC_INSTRUCTION = """Answer the following query exclusively based on the information provided in the document above. \
102
+ If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
103
+ """
104
+
105
+
106
+ def undo_history(history):
107
+ if len(history) == 0:
108
+ return history
109
+ if history[-1][-1] is not None:
110
+ if history[-1][0] is not None:
111
+ history[-1][-1] = None
112
+ else:
113
+ history = history[:-1]
114
+ else:
115
+ history = history[:-1]
116
+ return history
117
+
118
+
119
+ def undo_history_until_last_assistant_turn(history):
120
+ history = undo_history(history)
121
+ while len(history) > 0 and history[-1][-1] is None:
122
+ history = undo_history(history)
123
+ return history, history
124
+
125
+
126
+ class MultiModalChatInterface(CustomizedChatInterface):
127
+ def __init__(
128
+ self,
129
+ fn: Callable,
130
+ *,
131
+ chatbot: Chatbot | None = None,
132
+ textbox: Textbox | None = None,
133
+ additional_inputs: str | Component | list[str | Component] | None = None,
134
+ additional_inputs_accordion_name: str | None = None,
135
+ additional_inputs_accordion: str | Accordion | None = None,
136
+ add_multimodal_fn: Callable | None = None,
137
+ render_additional_inputs_fn: Callable | None = None,
138
+ examples: list[str] | None = None,
139
+ cache_examples: bool | None = None,
140
+ title: str | None = None,
141
+ description: str | None = None,
142
+ theme: Theme | str | None = None,
143
+ css: str | None = None,
144
+ js: str | None = None,
145
+ head: str | None = None,
146
+ analytics_enabled: bool | None = None,
147
+ submit_btn: str | None | Button = "Submit",
148
+ stop_btn: str | None | Button = "Stop",
149
+ retry_btn: str | None | Button = "🔄 Retry",
150
+ undo_btn: str | None | Button = "↩️ Undo",
151
+ clear_btn: str | None | Button = "🗑️ Clear",
152
+ autofocus: bool = True,
153
+ concurrency_limit: int | None | Literal["default"] = "default",
154
+ fill_height: bool = True,
155
+ ):
156
+ """
157
+ Parameters:
158
+ fn: The function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
159
+ chatbot: An instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
160
+ textbox: An instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
161
+ additional_inputs: An instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
162
+ additional_inputs_accordion_name: Deprecated. Will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead.
163
+ additional_inputs_accordion: If a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
164
+ examples: Sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
165
+ cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
166
+ title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
167
+ description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
168
+ theme: Theme to use, loaded from gradio.themes.
169
+ css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
170
+ js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
171
+ head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
172
+ analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
173
+ submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
174
+ stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
175
+ retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
176
+ undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
177
+ clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
178
+ autofocus: If True, autofocuses to the textbox when the page loads.
179
+ concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
180
+ fill_height: If True, the chat interface will expand to the height of window.
181
+ """
182
+ try:
183
+ super(gr.ChatInterface, self).__init__(
184
+ analytics_enabled=analytics_enabled,
185
+ mode="chat_interface",
186
+ css=css,
187
+ title=title or "Gradio",
188
+ theme=theme,
189
+ js=js,
190
+ head=head,
191
+ fill_height=fill_height,
192
+ )
193
+ except Exception as e:
194
+ # Handle old gradio versions without fill_height
195
+ super(gr.ChatInterface, self).__init__(
196
+ analytics_enabled=analytics_enabled,
197
+ mode="chat_interface",
198
+ css=css,
199
+ title=title or "Gradio",
200
+ theme=theme,
201
+ js=js,
202
+ head=head,
203
+ # fill_height=fill_height,
204
+ )
205
+
206
+ self.concurrency_limit = concurrency_limit
207
+ self.fn = fn
208
+ self.add_multimodal_fn = add_multimodal_fn
209
+ self.render_additional_inputs_fn = render_additional_inputs_fn
210
+ self.multimodal_inputs = []
211
+ self.is_async = inspect.iscoroutinefunction(
212
+ self.fn
213
+ ) or inspect.isasyncgenfunction(self.fn)
214
+ self.is_generator = inspect.isgeneratorfunction(
215
+ self.fn
216
+ ) or inspect.isasyncgenfunction(self.fn)
217
+ self.examples = examples
218
+ if self.space_id and cache_examples is None:
219
+ self.cache_examples = True
220
+ else:
221
+ self.cache_examples = cache_examples or False
222
+ self.buttons: list[Button | None] = []
223
+
224
+ if additional_inputs:
225
+ if not isinstance(additional_inputs, list):
226
+ additional_inputs = [additional_inputs]
227
+ self.additional_inputs = [
228
+ get_component_instance(i)
229
+ for i in additional_inputs # type: ignore
230
+ ]
231
+ else:
232
+ self.additional_inputs = []
233
+ if additional_inputs_accordion_name is not None:
234
+ print(
235
+ "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
236
+ )
237
+ self.additional_inputs_accordion_params = {
238
+ "label": additional_inputs_accordion_name
239
+ }
240
+ if additional_inputs_accordion is None:
241
+ self.additional_inputs_accordion_params = {
242
+ "label": "Additional Inputs",
243
+ "open": False,
244
+ }
245
+ elif isinstance(additional_inputs_accordion, str):
246
+ self.additional_inputs_accordion_params = {
247
+ "label": additional_inputs_accordion
248
+ }
249
+ elif isinstance(additional_inputs_accordion, Accordion):
250
+ self.additional_inputs_accordion_params = (
251
+ additional_inputs_accordion.recover_kwargs(
252
+ additional_inputs_accordion.get_config()
253
+ )
254
+ )
255
+ else:
256
+ raise ValueError(
257
+ f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
258
+ )
259
+
260
+ with self:
261
+ if title:
262
+ Markdown(
263
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
264
+ )
265
+ if description:
266
+ Markdown(description)
267
+
268
+ if chatbot:
269
+ self.chatbot = chatbot.render()
270
+ else:
271
+ self.chatbot = Chatbot(
272
+ label="Chatbot", scale=1, height=200 if fill_height else None
273
+ )
274
+
275
+ with Row():
276
+ for btn in [retry_btn, undo_btn, clear_btn]:
277
+ if btn is not None:
278
+ if isinstance(btn, Button):
279
+ btn.render()
280
+ elif isinstance(btn, str):
281
+ btn = Button(btn, variant="secondary", size="sm")
282
+ else:
283
+ raise ValueError(
284
+ f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
285
+ )
286
+ self.buttons.append(btn) # type: ignore
287
+
288
+ with Group():
289
+ with Row():
290
+ if textbox:
291
+ textbox.container = False
292
+ textbox.show_label = False
293
+ textbox_ = textbox.render()
294
+ assert isinstance(textbox_, Textbox)
295
+ self.textbox = textbox_
296
+ else:
297
+ self.textbox = Textbox(
298
+ container=False,
299
+ show_label=False,
300
+ label="Message",
301
+ placeholder="Type a message...",
302
+ scale=7,
303
+ autofocus=autofocus,
304
+ )
305
+ if submit_btn is not None:
306
+ if isinstance(submit_btn, Button):
307
+ submit_btn.render()
308
+ elif isinstance(submit_btn, str):
309
+ submit_btn = Button(
310
+ submit_btn,
311
+ variant="primary",
312
+ scale=2,
313
+ min_width=150,
314
+ )
315
+ else:
316
+ raise ValueError(
317
+ f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
318
+ )
319
+ if stop_btn is not None:
320
+ if isinstance(stop_btn, Button):
321
+ stop_btn.visible = False
322
+ stop_btn.render()
323
+ elif isinstance(stop_btn, str):
324
+ stop_btn = Button(
325
+ stop_btn,
326
+ variant="stop",
327
+ visible=False,
328
+ scale=2,
329
+ min_width=150,
330
+ )
331
+ else:
332
+ raise ValueError(
333
+ f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
334
+ )
335
+ self.num_tokens = Textbox(
336
+ container=False,
337
+ show_label=False,
338
+ label="num_tokens",
339
+ placeholder="0 tokens",
340
+ scale=1,
341
+ interactive=False,
342
+ # autofocus=autofocus,
343
+ min_width=10
344
+ )
345
+ self.buttons.extend([submit_btn, stop_btn]) # type: ignore
346
+
347
+ self.fake_api_btn = Button("Fake API", visible=False)
348
+ self.fake_response_textbox = Textbox(label="Response", visible=False)
349
+ (
350
+ self.retry_btn,
351
+ self.undo_btn,
352
+ self.clear_btn,
353
+ self.submit_btn,
354
+ self.stop_btn,
355
+ ) = self.buttons
356
+
357
+
358
+ any_unrendered_inputs = any(
359
+ not inp.is_rendered for inp in self.additional_inputs
360
+ )
361
+ if self.add_multimodal_fn is not None:
362
+ with Row():
363
+ self.multimodal_inputs = self.add_multimodal_fn()
364
+ if self.additional_inputs and any_unrendered_inputs:
365
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
366
+ if self.render_additional_inputs_fn is not None:
367
+ self.render_additional_inputs_fn()
368
+ else:
369
+ for input_component in self.additional_inputs:
370
+ if not input_component.is_rendered:
371
+ input_component.render()
372
+ else:
373
+ if self.additional_inputs and any_unrendered_inputs:
374
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
375
+ if self.render_additional_inputs_fn is not None:
376
+ self.render_additional_inputs_fn()
377
+ else:
378
+ for input_component in self.additional_inputs:
379
+ if not input_component.is_rendered:
380
+ input_component.render()
381
+
382
+ if examples:
383
+ if self.is_generator:
384
+ examples_fn = self._examples_stream_fn
385
+ else:
386
+ # examples_fn = self._examples_fn
387
+ raise NotImplementedError(f'Not streaming not impl')
388
+
389
+ self.examples_handler = Examples(
390
+ examples=examples,
391
+ inputs=[self.textbox] + self.multimodal_inputs + self.additional_inputs,
392
+ outputs=self.chatbot,
393
+ fn=examples_fn,
394
+ )
395
+
396
+ # The example caching must happen after the input components have rendered
397
+ if cache_examples:
398
+ client_utils.synchronize_async(self.examples_handler.cache)
399
+
400
+ self.saved_input = State()
401
+ self.chatbot_state = (
402
+ State(self.chatbot.value) if self.chatbot.value else State([])
403
+ )
404
+
405
+ self._setup_events()
406
+ self._setup_api()
407
+
408
+ def _clear_and_save_textbox(self, message: str, *multimodal_inputs) -> tuple[str, str]:
409
+ saved_input = [message] + list(multimodal_inputs)
410
+ outputs = [''] + [None] * len(multimodal_inputs)
411
+ return outputs + [saved_input]
412
+
413
+ def _add_inputs_to_history(self, history: List[List[Union[str, None]]], *args):
414
+ message = args[0]
415
+ multimodal_inputs = args[1:1 + len(self.multimodal_inputs)] if len(args) > 1 else None
416
+ if multimodal_inputs is not None:
417
+ is_file_exists = [(x is not None and os.path.exists(x)) for x in multimodal_inputs]
418
+ if any(is_file_exists):
419
+ file_exists = [f for f, ise in zip(multimodal_inputs, is_file_exists) if ise]
420
+ if len(file_exists) > 1:
421
+ raise gr.Error(f"Cannot have more than 1 multimodal input at a time.")
422
+ fname = file_exists[0]
423
+ history.append([(fname,), None])
424
+ if message is not None and message.strip() != "":
425
+ history.append([message, None])
426
+ return history
427
+
428
+
429
+ def _display_input(
430
+ self, saved_input: List[str], history: List[List[Union[str, None]]]
431
+ ) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
432
+ # message = saved_input[0]
433
+ # multimodal_inputs = saved_input[1:] if len(saved_input) > 1 else None
434
+ # # ! If things wrong, return original history and give warning
435
+ # if multimodal_inputs is not None:
436
+ # is_file_exists = [(x is not None and os.path.exists(x)) for x in multimodal_inputs]
437
+ # if any(is_file_exists):
438
+ # file_exists = [f for f, ise in zip(multimodal_inputs, is_file_exists) if ise]
439
+ # if len(file_exists) > 1:
440
+ # raise gr.Error(f"Cannot have more than 1 multimodal input at a time.")
441
+ # fname = file_exists[0]
442
+ # history.append([(fname,), None])
443
+ # if message is not None and message.strip() != "":
444
+ # history.append([message, None])
445
+ history = self._add_inputs_to_history(history, *saved_input)
446
+ return history, history
447
+
448
+ def _delete_prev_fn(
449
+ self, history: list[list[str | None]]
450
+ ) -> tuple[list[list[str | None]], str, list[list[str | None]]]:
451
+ try:
452
+ message, _ = history.pop()
453
+ except IndexError:
454
+ message = ""
455
+ saved_input = [message or ""] + [None] * len(self.multimodal_inputs)
456
+ return history, saved_input, history
457
+
458
+ def _setup_events(self) -> None:
459
+ from gradio.components import State
460
+ has_on = False
461
+ try:
462
+ from gradio.events import Dependency, EventListenerMethod, on
463
+ has_on = True
464
+ except ImportError as ie:
465
+ has_on = False
466
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
467
+ if not self.is_generator:
468
+ raise NotImplementedError(f'should use generator')
469
+
470
+ if has_on:
471
+ # new version
472
+ submit_triggers = (
473
+ [self.textbox.submit, self.submit_btn.click]
474
+ if self.submit_btn
475
+ else [self.textbox.submit]
476
+ )
477
+ submit_event = (
478
+ on(
479
+ submit_triggers,
480
+ self._clear_and_save_textbox,
481
+ [self.textbox] + self.multimodal_inputs,
482
+ [self.textbox] + self.multimodal_inputs + [self.saved_input],
483
+ api_name=False,
484
+ queue=False,
485
+ )
486
+ .then(
487
+ self._display_input,
488
+ [self.saved_input, self.chatbot_state],
489
+ [self.chatbot, self.chatbot_state],
490
+ api_name=False,
491
+ queue=False,
492
+ )
493
+ .success(
494
+ submit_fn,
495
+ [self.chatbot_state] + self.additional_inputs,
496
+ [self.chatbot, self.chatbot_state, self.num_tokens],
497
+ api_name=False,
498
+ )
499
+ )
500
+ self._setup_stop_events(submit_triggers, submit_event)
501
+ else:
502
+ raise ValueError(f'Better install new gradio version than 3.44.0')
503
+
504
+ if self.retry_btn:
505
+ retry_event = (
506
+ self.retry_btn.click(
507
+ self._delete_prev_fn,
508
+ [self.chatbot_state],
509
+ [self.chatbot, self.saved_input, self.chatbot_state],
510
+ api_name=False,
511
+ queue=False,
512
+ )
513
+ .then(
514
+ self._display_input,
515
+ [self.saved_input, self.chatbot_state],
516
+ [self.chatbot, self.chatbot_state],
517
+ api_name=False,
518
+ queue=False,
519
+ )
520
+ .success(
521
+ submit_fn,
522
+ [self.chatbot_state] + self.additional_inputs,
523
+ [self.chatbot, self.chatbot_state, self.num_tokens],
524
+ api_name=False,
525
+ )
526
+ )
527
+ self._setup_stop_events([self.retry_btn.click], retry_event)
528
+
529
+ if self.undo_btn:
530
+ self.undo_btn.click(
531
+ # self._delete_prev_fn,
532
+ # [self.chatbot_state],
533
+ # [self.chatbot, self.saved_input, self.chatbot_state],
534
+ undo_history_until_last_assistant_turn,
535
+ [self.chatbot_state],
536
+ [self.chatbot, self.chatbot_state],
537
+ api_name=False,
538
+ queue=False,
539
+ )
540
+ # .then(
541
+ # lambda x: x,
542
+ # [self.saved_input],
543
+ # [self.textbox],
544
+ # api_name=False,
545
+ # queue=False,
546
+ # )
547
+
548
+ async def _stream_fn(
549
+ self,
550
+ # message: str,
551
+ history_with_input,
552
+ request: Request,
553
+ *args,
554
+ ) -> AsyncGenerator:
555
+ history = history_with_input[:-1]
556
+ message = history_with_input[-1][0]
557
+ inputs, _, _ = special_args(
558
+ self.fn, inputs=[history_with_input, *args], request=request
559
+ )
560
+
561
+ if self.is_async:
562
+ generator = self.fn(*inputs)
563
+ else:
564
+ generator = await anyio.to_thread.run_sync(
565
+ self.fn, *inputs, limiter=self.limiter
566
+ )
567
+ generator = SyncToAsyncIterator(generator, self.limiter)
568
+
569
+ # ! In case of error, yield the previous history & undo any generation before raising error
570
+ try:
571
+ first_response_pack = await async_iteration(generator)
572
+ if isinstance(first_response_pack, (tuple, list)):
573
+ first_response, num_tokens = first_response_pack
574
+ else:
575
+ first_response, num_tokens = first_response_pack, -1
576
+ update = history + [[message, first_response]]
577
+ yield update, update, f"{num_tokens} toks"
578
+ except StopIteration:
579
+ update = history + [[message, None]]
580
+ yield update, update, "NaN toks"
581
+ except Exception as e:
582
+ yield history, history, "NaN toks"
583
+ raise e
584
+
585
+ try:
586
+ async for response_pack in generator:
587
+ if isinstance(response_pack, (tuple, list)):
588
+ response, num_tokens = response_pack
589
+ else:
590
+ response, num_tokens = response_pack, "NaN toks"
591
+ update = history + [[message, response]]
592
+ yield update, update, f"{num_tokens} toks"
593
+ except Exception as e:
594
+ yield history, history, "NaN toks"
595
+ raise e
596
+
597
+ async def _examples_stream_fn(
598
+ self,
599
+ # message: str,
600
+ *args,
601
+ ) -> AsyncGenerator:
602
+ history = []
603
+ input_len = 1 + len(self.multimodal_inputs)
604
+ saved_input = args[:input_len]
605
+ message = saved_input[0]
606
+ additional_inputs = [] if len(args) <= input_len else args[input_len:]
607
+ history = self._add_inputs_to_history(history, *saved_input)
608
+ inputs, _, _ = special_args(self.fn, inputs=[history, *additional_inputs], request=None)
609
+
610
+ if self.is_async:
611
+ generator = self.fn(*inputs)
612
+ else:
613
+ generator = await anyio.to_thread.run_sync(
614
+ self.fn, *inputs, limiter=self.limiter
615
+ )
616
+ generator = SyncToAsyncIterator(generator, self.limiter)
617
+ # async for response in generator:
618
+ # yield [[message, response]]
619
+
620
+ try:
621
+ async for response_pack in generator:
622
+ if isinstance(response_pack, (tuple, list)):
623
+ response, num_tokens = response_pack
624
+ else:
625
+ response, num_tokens = response_pack, "NaN toks"
626
+ update = history + [[message, response]]
627
+ yield update, update, f"{num_tokens} toks"
628
+ except Exception as e:
629
+ yield history, history, "NaN toks"
630
+ raise e
631
+
632
+ async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
633
+ raise NotImplementedError
634
+ inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
635
+
636
+ if self.is_async:
637
+ response = await self.fn(*inputs)
638
+ else:
639
+ response = await anyio.to_thread.run_sync(
640
+ self.fn, *inputs, limiter=self.limiter
641
+ )
642
+ return [[message, response]]
643
+
644
+
645
+
646
+ def gradio_history_to_openai_conversations(message=None, history=None, system_prompt=None):
647
+ conversations = []
648
+ system_prompt = system_prompt or SYSTEM_PROMPT
649
+ if history is not None and len(history) > 0:
650
+ for i, (prompt, res) in enumerate(history):
651
+ if prompt is not None:
652
+ conversations.append({"role": "user", "content": prompt.strip()})
653
+ if res is not None:
654
+ conversations.append({"role": "assistant", "content": res.strip()})
655
+ if message is not None:
656
+ if len(message.strip()) == 0:
657
+ raise gr.Error("The message cannot be empty!")
658
+ conversations.append({"role": "user", "content": message.strip()})
659
+ if conversations[0]['role'] != 'system':
660
+ conversations = [{"role": "system", "content": system_prompt}] + conversations
661
+ return conversations
662
+
663
+
664
+ def gradio_history_to_conversation_prompt(message=None, history=None, system_prompt=None):
665
+ global MODEL_ENGINE
666
+ full_prompt = MODEL_ENGINE.apply_chat_template(
667
+ gradio_history_to_openai_conversations(
668
+ message, history=history, system_prompt=system_prompt),
669
+ add_generation_prompt=True
670
+ )
671
+ return full_prompt
672
+
673
+
674
+ def gradio_history_to_vision_conversations_paths(
675
+ history, system_prompt=None, image_token=None
676
+ ):
677
+ image_token = image_token or IMAGE_TOKEN
678
+ conversations = []
679
+ image_paths = []
680
+ for i, his in enumerate(history):
681
+ prompt, response = his
682
+ last_turn = conversations[-1] if len(conversations) > 0 else None
683
+ if prompt is not None:
684
+ if isinstance(prompt, tuple):
685
+ image_path = prompt[0]
686
+ if last_turn is not None and last_turn['role'] == 'user':
687
+ last_turn['content'] += f" {image_token}"
688
+ else:
689
+ # last_turn None or last_turn['role'] == 'assistant'
690
+ conversations.append({
691
+ "role": "user",
692
+ "content": f"{image_token}"
693
+ })
694
+ image_paths.append(image_path)
695
+ else:
696
+ assert prompt is not None and isinstance(prompt, str)
697
+ if last_turn is not None and last_turn['role'] == 'user':
698
+ last_turn['content'] += f"\n{prompt}"
699
+ else:
700
+ conversations.append({
701
+ "role": "user",
702
+ "content": prompt,
703
+ })
704
+ if response is not None:
705
+ assert isinstance(response, str)
706
+ conversations.append({
707
+ "role": "assistant",
708
+ "content": response,
709
+ })
710
+
711
+ if conversations[0]['role'] != 'system':
712
+ system_prompt = system_prompt or SYSTEM_PROMPT
713
+ conversations = [{"role": "system", "content": system_prompt}] + conversations
714
+ return conversations, image_paths
715
+
716
+
717
+
718
+ def gradio_history_to_vision_conversation_prompt_paths(
719
+ history, system_prompt=None, image_token=None
720
+ ):
721
+ """
722
+ Aggregate gradio history into openai conversations
723
+ history = [
724
+ ["Hello", "Response"],
725
+ [(file,), None],
726
+ ]
727
+ --->
728
+ [
729
+ {"role": "user", "content": ...}
730
+ ]
731
+ """
732
+ global MODEL_ENGINE
733
+
734
+ conversations, image_paths = gradio_history_to_vision_conversations_paths(
735
+ history, system_prompt, image_token
736
+ )
737
+ # print(f'convo: {json.dumps(conversations, indent=4, ensure_ascii=False)}\n{image_paths=}')
738
+ full_prompt = MODEL_ENGINE.apply_chat_template(
739
+ conversations,
740
+ add_generation_prompt=True
741
+ )
742
+ return full_prompt, image_paths, conversations
743
+
744
+
745
+ def is_doc(file_path):
746
+ is_doc_allowed = file_path.endswith((".pdf", ".docx", ".txt"))
747
+ return is_doc_allowed
748
+
749
+
750
+ def read_doc(file_path):
751
+ from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
752
+ if file_path.endswith('.pdf'):
753
+ loader = PyPDFLoader(file_path)
754
+ elif file_path.endswith('.docx'):
755
+ loader = Docx2txtLoader(file_path)
756
+ elif file_path.endswith('.txt'):
757
+ loader = TextLoader(file_path)
758
+ texts = loader.load()
759
+ text = "\n\n".join([t.page_content for t in texts])
760
+ return text
761
+
762
+
763
+ def doc_file_to_instruct_content(file_path, doc_instruction=None):
764
+ doc_instruction = doc_instruction or DOC_INSTRUCTION
765
+ content = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=read_doc(file_path))
766
+ return content
767
+
768
+
769
+ def gradio_history_to_doc_conversation_prompt(
770
+ history, system_prompt=None, doc_instruction=None,
771
+ ):
772
+ """
773
+ Aggregate gradio history into openai conversations
774
+ history = [
775
+ ["Hello", "Response"],
776
+ [(file,), None],
777
+ ]
778
+ --->
779
+ [
780
+ {"role": "user", "content": ...}
781
+ ]
782
+ """
783
+ global MODEL_ENGINE
784
+ # image_token = image_token or IMAGE_TOKEN
785
+ doc_instruction = doc_instruction or DOC_INSTRUCTION
786
+ conversations = []
787
+ image_paths = []
788
+ for i, his in enumerate(history):
789
+ prompt, response = his
790
+ last_turn = conversations[-1] if len(conversations) > 0 else None
791
+ if prompt is not None:
792
+ if isinstance(prompt, tuple):
793
+ file_path = prompt[0]
794
+ if not is_doc(file_path):
795
+ raise gr.Error(f'file not doc {file_path}')
796
+ content = doc_file_to_instruct_content(file_path, doc_instruction)
797
+ if last_turn is not None and last_turn['role'] == 'user':
798
+ last_turn['content'] += f"{content}"
799
+ else:
800
+ # last_turn None or last_turn['role'] == 'assistant'
801
+ conversations.append({
802
+ "role": "user",
803
+ "content": f"{content}"
804
+ })
805
+ else:
806
+ assert prompt is not None and isinstance(prompt, str)
807
+ if last_turn is not None and last_turn['role'] == 'user':
808
+ last_turn['content'] += f"\n{prompt}"
809
+ else:
810
+ conversations.append({
811
+ "role": "user",
812
+ "content": prompt,
813
+ })
814
+ if response is not None:
815
+ assert isinstance(response, str)
816
+ conversations.append({
817
+ "role": "assistant",
818
+ "content": response,
819
+ })
820
+
821
+ if conversations[0]['role'] != 'system':
822
+ system_prompt = system_prompt or SYSTEM_PROMPT
823
+ conversations = [{"role": "system", "content": system_prompt}] + conversations
824
+
825
+ full_prompt = MODEL_ENGINE.apply_chat_template(
826
+ conversations,
827
+ add_generation_prompt=True
828
+ )
829
+ return full_prompt, conversations
830
+
831
+
832
+ def gradio_history_to_vision_doc_conversation_prompt_paths(
833
+ history, system_prompt=None, image_token=None, doc_instruction=None,
834
+ ):
835
+ """
836
+ Aggregate gradio history into openai conversations
837
+ history = [
838
+ ["Hello", "Response"],
839
+ [(file,), None],
840
+ ]
841
+ --->
842
+ [
843
+ {"role": "user", "content": ...}
844
+ ]
845
+ """
846
+ global MODEL_ENGINE
847
+ image_token = image_token or IMAGE_TOKEN
848
+ doc_instruction = doc_instruction or DOC_INSTRUCTION
849
+ conversations = []
850
+ image_paths = []
851
+ for i, his in enumerate(history):
852
+ prompt, response = his
853
+ last_turn = conversations[-1] if len(conversations) > 0 else None
854
+ if prompt is not None:
855
+ if isinstance(prompt, tuple):
856
+ file_path = prompt[0]
857
+ if is_doc(file_path):
858
+ content = doc_file_to_instruct_content(file_path, doc_instruction)
859
+ if last_turn is not None and last_turn['role'] == 'user':
860
+ last_turn['content'] += f"{content}"
861
+ else:
862
+ # last_turn None or last_turn['role'] == 'assistant'
863
+ conversations.append({
864
+ "role": "user",
865
+ "content": f"{content}"
866
+ })
867
+ else:
868
+ if last_turn is not None and last_turn['role'] == 'user':
869
+ last_turn['content'] += f" {image_token}"
870
+ else:
871
+ # last_turn None or last_turn['role'] == 'assistant'
872
+ conversations.append({
873
+ "role": "user",
874
+ "content": f"{image_token}"
875
+ })
876
+ image_paths.append(file_path)
877
+ else:
878
+ assert prompt is not None and isinstance(prompt, str)
879
+ if last_turn is not None and last_turn['role'] == 'user':
880
+ last_turn['content'] += f"\n{prompt}"
881
+ else:
882
+ conversations.append({
883
+ "role": "user",
884
+ "content": prompt,
885
+ })
886
+ if response is not None:
887
+ assert isinstance(response, str)
888
+ conversations.append({
889
+ "role": "assistant",
890
+ "content": response,
891
+ })
892
+
893
+ if conversations[0]['role'] != 'system':
894
+ system_prompt = system_prompt or SYSTEM_PROMPT
895
+ conversations = [{"role": "system", "content": system_prompt}] + conversations
896
+
897
+ full_prompt = MODEL_ENGINE.apply_chat_template(
898
+ conversations,
899
+ add_generation_prompt=True
900
+ )
901
+ return full_prompt, image_paths, conversations
902
+
903
+
904
+ def vision_chat_response_stream_multiturn_engine(
905
+ history: List[Tuple[str, str]],
906
+ temperature: float,
907
+ max_tokens: int,
908
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
909
+ image_token: Optional[str] = IMAGE_TOKEN,
910
+ ):
911
+ global MODEL_ENGINE
912
+ temperature = float(temperature)
913
+ # ! remove frequency_penalty
914
+ # frequency_penalty = float(frequency_penalty)
915
+ max_tokens = int(max_tokens)
916
+ # ! skip safety
917
+ if DATETIME_FORMAT in system_prompt:
918
+ # ! This sometime works sometimes dont
919
+ system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
920
+ # ! history now can have multimodal
921
+
922
+ full_prompt, image_paths, conversations = gradio_history_to_vision_conversation_prompt_paths(
923
+ history=history, system_prompt=system_prompt, image_token=image_token
924
+ )
925
+
926
+ if hasattr(MODEL_ENGINE, "get_multimodal_tokens"):
927
+ num_tokens = MODEL_ENGINE.get_multimodal_tokens(full_prompt, image_paths=image_paths)
928
+ else:
929
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
930
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
931
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
932
+
933
+ print(f'{image_paths=}')
934
+ print(full_prompt)
935
+ outputs = None
936
+ response = None
937
+ num_tokens = -1
938
+ for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
939
+ prompt=full_prompt,
940
+ temperature=temperature,
941
+ max_tokens=max_tokens,
942
+ image_paths=image_paths,
943
+ )):
944
+ if isinstance(outputs, tuple):
945
+ response, num_tokens = outputs
946
+ else:
947
+ response, num_tokens = outputs, -1
948
+ yield response, num_tokens
949
+
950
+ if response is not None:
951
+ yield response, num_tokens
952
+
953
+
954
+ def doc_chat_response_stream_multiturn_engine(
955
+ history: List[Tuple[str, str]],
956
+ temperature: float,
957
+ max_tokens: int,
958
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
959
+ doc_instruction: Optional[str] = DOC_INSTRUCTION,
960
+ ):
961
+ global MODEL_ENGINE
962
+ temperature = float(temperature)
963
+ # ! remove frequency_penalty
964
+ # frequency_penalty = float(frequency_penalty)
965
+ max_tokens = int(max_tokens)
966
+ # ! skip safety
967
+ if DATETIME_FORMAT in system_prompt:
968
+ # ! This sometime works sometimes dont
969
+ system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
970
+ # ! history now can have multimodal
971
+
972
+ full_prompt, conversations = gradio_history_to_doc_conversation_prompt(
973
+ history=history, system_prompt=system_prompt, doc_instruction=doc_instruction
974
+ )
975
+
976
+ # ! length checked
977
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
978
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
979
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
980
+
981
+ print(full_prompt)
982
+ outputs = None
983
+ response = None
984
+ num_tokens = -1
985
+ for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
986
+ prompt=full_prompt,
987
+ temperature=temperature,
988
+ max_tokens=max_tokens,
989
+ # image_paths=image_paths,
990
+ )):
991
+ if isinstance(outputs, tuple):
992
+ response, num_tokens = outputs
993
+ else:
994
+ response, num_tokens = outputs, -1
995
+ yield response, num_tokens
996
+
997
+ if response is not None:
998
+ yield response, num_tokens
999
+
1000
+
1001
+
1002
+
1003
+ def vision_doc_chat_response_stream_multiturn_engine(
1004
+ history: List[Tuple[str, str]],
1005
+ temperature: float,
1006
+ max_tokens: int,
1007
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
1008
+ image_token: Optional[str] = IMAGE_TOKEN,
1009
+ doc_instruction: Optional[str] = DOC_INSTRUCTION,
1010
+ ):
1011
+ global MODEL_ENGINE
1012
+ temperature = float(temperature)
1013
+ # ! remove frequency_penalty
1014
+ # frequency_penalty = float(frequency_penalty)
1015
+ max_tokens = int(max_tokens)
1016
+ # ! skip safety
1017
+ if DATETIME_FORMAT in system_prompt:
1018
+ # ! This sometime works sometimes dont
1019
+ system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
1020
+ # ! history now can have multimodal
1021
+
1022
+ full_prompt, image_paths, conversations = gradio_history_to_vision_doc_conversation_prompt_paths(
1023
+ history=history, system_prompt=system_prompt, image_token=image_token, doc_instruction=doc_instruction
1024
+ )
1025
+
1026
+ # ! length check
1027
+ if hasattr(MODEL_ENGINE, "get_multimodal_tokens"):
1028
+ num_tokens = MODEL_ENGINE.get_multimodal_tokens(full_prompt, image_paths=image_paths)
1029
+ else:
1030
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
1031
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
1032
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
1033
+
1034
+ print(full_prompt)
1035
+ print(f'{image_paths=}')
1036
+ outputs = None
1037
+ response = None
1038
+ num_tokens = -1
1039
+ for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
1040
+ prompt=full_prompt,
1041
+ temperature=temperature,
1042
+ max_tokens=max_tokens,
1043
+ image_paths=image_paths,
1044
+ )):
1045
+ if isinstance(outputs, tuple):
1046
+ response, num_tokens = outputs
1047
+ else:
1048
+ response, num_tokens = outputs, -1
1049
+ yield response, num_tokens
1050
+
1051
+ if response is not None:
1052
+ yield response, num_tokens
1053
+
1054
+
1055
+
1056
+ @register_demo
1057
+ class VisionChatInterfaceDemo(ChatInterfaceDemo):
1058
+ """
1059
+ Accept vision image
1060
+ """
1061
+
1062
+ @property
1063
+ def tab_name(self):
1064
+ return "Vision Chat"
1065
+
1066
+ @property
1067
+ def examples(self):
1068
+ return [
1069
+ ["What's strange about this image?", "assets/dog_monalisa.jpeg",],
1070
+ ["Explain why the sky is blue.", None,],
1071
+ ]
1072
+
1073
+ def create_demo(
1074
+ self,
1075
+ title: str | None = None,
1076
+ description: str | None = None,
1077
+ **kwargs
1078
+ ) -> gr.Blocks:
1079
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
1080
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
1081
+ temperature = kwargs.get("temperature", TEMPERATURE)
1082
+ model_name = kwargs.get("model_name", MODEL_NAME)
1083
+ description = description or """Upload an image to ask question about it."""
1084
+
1085
+ def add_multimodal_fn() -> List[Component]:
1086
+ image_input = gr.Image(label="Input Image", type="filepath", )
1087
+ return [image_input]
1088
+
1089
+ additional_inputs = [
1090
+ gr.Number(value=temperature, label='Temperature', min_width=20),
1091
+ gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
1092
+ gr.Textbox(value=system_prompt, label='System prompt', lines=1),
1093
+ gr.Textbox(value=IMAGE_TOKEN, label='Visual token', lines=1, interactive=IMAGE_TOKEN_INTERACTIVE, min_width=20),
1094
+ ]
1095
+ def render_additional_inputs_fn():
1096
+ with Row():
1097
+ additional_inputs[0].render()
1098
+ additional_inputs[1].render()
1099
+ additional_inputs[3].render()
1100
+ additional_inputs[2].render()
1101
+
1102
+ demo_chat = MultiModalChatInterface(
1103
+ vision_chat_response_stream_multiturn_engine,
1104
+ chatbot=gr.Chatbot(
1105
+ label=model_name,
1106
+ bubble_full_width=False,
1107
+ latex_delimiters=[
1108
+ { "left": "$", "right": "$", "display": False},
1109
+ { "left": "$$", "right": "$$", "display": True},
1110
+ ],
1111
+ show_copy_button=True,
1112
+ layout="panel" if USE_PANEL else "bubble",
1113
+ height=CHATBOT_HEIGHT,
1114
+ ),
1115
+ # textbox=gr.Textbox(placeholder='Type message', lines=4, max_lines=128, min_width=200),
1116
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
1117
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1118
+ # ! consider preventing the stop button
1119
+ # stop_btn=None,
1120
+ add_multimodal_fn=add_multimodal_fn,
1121
+ title=title,
1122
+ description=description,
1123
+ additional_inputs=additional_inputs,
1124
+ render_additional_inputs_fn=render_additional_inputs_fn,
1125
+ additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
1126
+ examples=self.examples,
1127
+ cache_examples=False,
1128
+ css=CSS,
1129
+ )
1130
+ return demo_chat
1131
+
1132
+
1133
+ def add_document_upload():
1134
+ file_input = gr.File(label='Upload pdf, docx, txt', file_count='single', file_types=['pdf', 'docx', 'txt'])
1135
+ # with Group():
1136
+ # file_input = gr.Textbox(value=None, label='Document path', lines=1, interactive=False)
1137
+ # upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt'], file_count="single")
1138
+ # upload_button.upload(lambda x: x.name, upload_button, file_input)
1139
+ return file_input
1140
+
1141
+
1142
+ @register_demo
1143
+ class DocChatInterfaceDemo(ChatInterfaceDemo):
1144
+ """
1145
+ Accept document (full length no RAG)
1146
+ """
1147
+ @property
1148
+ def tab_name(self):
1149
+ return "Doc Chat"
1150
+
1151
+ @property
1152
+ def examples(self):
1153
+ return [
1154
+ ["Summarize the document", "assets/attention_short.pdf",],
1155
+ ["Explain why the sky is blue.", None,],
1156
+ ]
1157
+
1158
+ def create_demo(
1159
+ self,
1160
+ title: str | None = None,
1161
+ description: str | None = None,
1162
+ **kwargs
1163
+ ) -> gr.Blocks:
1164
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
1165
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
1166
+ temperature = kwargs.get("temperature", TEMPERATURE)
1167
+ model_name = kwargs.get("model_name", MODEL_NAME)
1168
+ # frequence_penalty = FREQUENCE_PENALTY
1169
+ # presence_penalty = PRESENCE_PENALTY
1170
+ description = description or """Upload a short document to ask question about it."""
1171
+
1172
+ def add_multimodal_fn() -> List[Component]:
1173
+ file_input = add_document_upload()
1174
+ # image_input = gr.Image(label="Input Image", type="filepath", )
1175
+ return [file_input]
1176
+
1177
+ additional_inputs = [
1178
+ gr.Number(value=temperature, label='Temperature', min_width=20),
1179
+ gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
1180
+ gr.Textbox(value=system_prompt, label='System prompt', lines=1),
1181
+ gr.Textbox(value=DOC_INSTRUCTION, label='Doc instruction', lines=1),
1182
+ ]
1183
+ def render_additional_inputs_fn():
1184
+ with Row():
1185
+ additional_inputs[0].render()
1186
+ additional_inputs[1].render()
1187
+ additional_inputs[2].render()
1188
+ additional_inputs[3].render()
1189
+
1190
+ demo_chat = MultiModalChatInterface(
1191
+ doc_chat_response_stream_multiturn_engine,
1192
+ chatbot=gr.Chatbot(
1193
+ label=model_name,
1194
+ bubble_full_width=False,
1195
+ latex_delimiters=[
1196
+ { "left": "$", "right": "$", "display": False},
1197
+ { "left": "$$", "right": "$$", "display": True},
1198
+ ],
1199
+ show_copy_button=True,
1200
+ layout="panel" if USE_PANEL else "bubble",
1201
+ height=CHATBOT_HEIGHT,
1202
+ ),
1203
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
1204
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1205
+ # ! consider preventing the stop button
1206
+ add_multimodal_fn=add_multimodal_fn,
1207
+ title=title,
1208
+ description=description,
1209
+ additional_inputs=additional_inputs,
1210
+ render_additional_inputs_fn=render_additional_inputs_fn,
1211
+ additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
1212
+ examples=self.examples,
1213
+ cache_examples=False,
1214
+ css=CSS,
1215
+ )
1216
+ return demo_chat
1217
+
1218
+
1219
+ @register_demo
1220
+ class VisionDocChatInterfaceDemo(ChatInterfaceDemo):
1221
+ """
1222
+ Accept either vision image or document (full length no RAG)
1223
+ """
1224
+ @property
1225
+ def tab_name(self):
1226
+ return "Vision Doc Chat"
1227
+
1228
+ @property
1229
+ def examples(self):
1230
+ return [
1231
+ ["What's strange about this image?", None, "assets/dog_monalisa.jpeg",],
1232
+ ["Summarize the document", "assets/attention_short.pdf", None,],
1233
+ ["Explain why the sky is blue.", None, None],
1234
+ ]
1235
+
1236
+ def create_demo(
1237
+ self,
1238
+ title: str | None = None,
1239
+ description: str | None = None,
1240
+ **kwargs
1241
+ ) -> gr.Blocks:
1242
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
1243
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
1244
+ temperature = kwargs.get("temperature", TEMPERATURE)
1245
+ model_name = kwargs.get("model_name", MODEL_NAME)
1246
+ # frequence_penalty = FREQUENCE_PENALTY
1247
+ # presence_penalty = PRESENCE_PENALTY
1248
+ description = description or """Upload either an image or short document to ask question about it."""
1249
+
1250
+ def add_multimodal_fn() -> List[Component]:
1251
+ file_input = add_document_upload()
1252
+ image_input = gr.Image(label="Input Image", type="filepath", )
1253
+ return [file_input, image_input]
1254
+
1255
+ additional_inputs = [
1256
+ gr.Number(value=temperature, label='Temperature', min_width=20),
1257
+ gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
1258
+ gr.Textbox(value=system_prompt, label='System prompt', lines=1),
1259
+ gr.Textbox(value=IMAGE_TOKEN, label='Visual token', lines=1, interactive=IMAGE_TOKEN_INTERACTIVE, min_width=2),
1260
+ gr.Textbox(value=DOC_INSTRUCTION, label='Doc instruction', lines=1),
1261
+ ]
1262
+ def render_additional_inputs_fn():
1263
+ with Row():
1264
+ additional_inputs[0].render()
1265
+ additional_inputs[1].render()
1266
+ additional_inputs[3].render()
1267
+ additional_inputs[2].render()
1268
+ additional_inputs[4].render()
1269
+
1270
+ demo_chat = MultiModalChatInterface(
1271
+ vision_doc_chat_response_stream_multiturn_engine,
1272
+ chatbot=gr.Chatbot(
1273
+ label=MODEL_NAME,
1274
+ bubble_full_width=False,
1275
+ latex_delimiters=[
1276
+ { "left": "$", "right": "$", "display": False},
1277
+ { "left": "$$", "right": "$$", "display": True},
1278
+ ],
1279
+ show_copy_button=True,
1280
+ layout="panel" if USE_PANEL else "bubble",
1281
+ height=CHATBOT_HEIGHT,
1282
+ ),
1283
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
1284
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1285
+ add_multimodal_fn=add_multimodal_fn,
1286
+ title=title,
1287
+ description=description,
1288
+ additional_inputs=additional_inputs,
1289
+ render_additional_inputs_fn=render_additional_inputs_fn,
1290
+ additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
1291
+ examples=self.examples,
1292
+ cache_examples=False,
1293
+ css=CSS,
1294
+ )
1295
+ return demo_chat
multipurpose_chatbot/demos/multimodal_preference_interface.py ADDED
@@ -0,0 +1,794 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+ from gradio.components.base import Component
25
+
26
+ from .base_demo import register_demo, get_demo_class, BaseDemo
27
+
28
+
29
+ from .chat_interface import (
30
+ SYSTEM_PROMPT,
31
+ MODEL_NAME,
32
+ MAX_TOKENS,
33
+ TEMPERATURE,
34
+ CHAT_EXAMPLES,
35
+ gradio_history_to_openai_conversations,
36
+ gradio_history_to_conversation_prompt,
37
+ DATETIME_FORMAT,
38
+ get_datetime_string,
39
+ chat_response_stream_multiturn_engine,
40
+ ChatInterfaceDemo,
41
+ CustomizedChatInterface,
42
+ )
43
+
44
+ from gradio.events import Events
45
+
46
+ import inspect
47
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
48
+
49
+ import anyio
50
+ from gradio_client import utils as client_utils
51
+ from gradio_client.documentation import document
52
+
53
+ from gradio.blocks import Blocks
54
+ from gradio.components import (
55
+ Button,
56
+ Chatbot,
57
+ Component,
58
+ Markdown,
59
+ State,
60
+ Textbox,
61
+ get_component_instance,
62
+ )
63
+ from gradio.events import Dependency, on
64
+ from gradio.helpers import create_examples as Examples # noqa: N812
65
+ from gradio.helpers import special_args
66
+ from gradio.layouts import Accordion, Group, Row
67
+ from gradio.routes import Request
68
+ from gradio.themes import ThemeClass as Theme
69
+ from gradio.utils import SyncToAsyncIterator, async_iteration
70
+
71
+ from ..globals import MODEL_ENGINE
72
+
73
+ from ..configs import (
74
+ USE_PANEL,
75
+ IMAGE_TOKEN,
76
+ IMAGE_TOKEN_INTERACTIVE,
77
+ CHATBOT_HEIGHT,
78
+ ALLOWED_PATHS,
79
+ )
80
+
81
+
82
+ from .multimodal_chat_interface import (
83
+ DOC_INSTRUCTION,
84
+ DOC_TEMPLATE,
85
+ CSS,
86
+ undo_history,
87
+ undo_history_until_last_assistant_turn,
88
+ MultiModalChatInterface,
89
+ gradio_history_to_conversation_prompt,
90
+ gradio_history_to_openai_conversations,
91
+ gradio_history_to_vision_conversation_prompt_paths,
92
+ gradio_history_to_doc_conversation_prompt,
93
+ gradio_history_to_vision_doc_conversation_prompt_paths,
94
+ VisionChatInterfaceDemo,
95
+ vision_chat_response_stream_multiturn_engine,
96
+ )
97
+
98
+ import glob
99
+ from pathlib import Path
100
+ from gradio import utils as gradio_utils
101
+
102
+ PREF_DIR = os.environ.get("PREF_DIR", "./tmp")
103
+ PREFERENCE_MAKE_DATA_PATH = os.environ.get("PREFERENCE_MAKE_DATA_PATH", "assets/example_pref.json")
104
+
105
+ IMAGE_DIR = os.environ.get("IMAGE_DIR", "./tmp_image")
106
+
107
+ EXAMPLE_IMAGE_PATHS = [
108
+ x
109
+ for x in glob.glob(os.path.join(IMAGE_DIR, "*"))
110
+ ]
111
+ print(f'IMAGES: {EXAMPLE_IMAGE_PATHS[:3]=}')
112
+
113
+
114
+ # ! Existing images
115
+
116
+ IMAGE_GLOB_ROOT = "/mnt/workspace/workgroup/phi/raw_data/multimodal_seallm/processed/sft/dpo_examples"
117
+ # ALLOWED_PATHS.append(IMAGE_GLOB_ROOT)
118
+ IMAGE_GLOBS = {
119
+ # "geometry": "geo3k/train/*/img_diagram.png",
120
+ "Geometry": ["geoqa_plus/*png", "Ask question about to solve the puzzle, calculating angles, find values, ... Provide extra information in the question (e.g 'Angle 1 = 30 degrees, find angle 2 from image.')"],
121
+ "Everyday": ["gqa/images/*", "Ask question to (1) describe, (2) find details, (3) negation (e.g 'Where's the cat?' while there is no cat in image.), (4) write stories ...."],
122
+ "OCR (read text)": ["ocr_vqa/images/*", "Ask question (1) full OCR description, (2) read specific details (e.g 'Who wrote the book?')."],
123
+ "OpenViVQA": ["OpenViVQA/training-images/*", "Only vietnamese, (1) full OCR description, (2) read specific details, (3) image description and question answering"],
124
+ "Text-VQA": ["textvqa/train_images/*", "Ask question to (1) describe, (2) find details, (3) negation (e.g 'Where's the cat?' while there is no cat in image.), (4) write stories, (5) reasoning"],
125
+ "Landmarks": ["web-landmark/images/*", "Ask question to (1) Where is landmarks (2) What to do at that place (3) Write stories, (4) give advise for tourists..."],
126
+ "Everyday-VG2": ["vg/VG_100K_2/*", "Same with Everyday"],
127
+ }
128
+
129
+ IMAGE_CUT_OFF_BEGIN = 0
130
+ IMAGE_CUT_OFF = 100
131
+ # IMAGE_CUT_OFF = 20
132
+
133
+ IMAGE_GLOB_PATHS = {}
134
+ IMAGE_GLOB_DESCS = {}
135
+ for k, v in IMAGE_GLOBS.items():
136
+ glob_p, description = v
137
+ paths = []
138
+ for i, p in enumerate(glob.glob(os.path.join(IMAGE_GLOB_ROOT, glob_p))):
139
+ if i < IMAGE_CUT_OFF_BEGIN:
140
+ continue
141
+ if i >= IMAGE_CUT_OFF + IMAGE_CUT_OFF_BEGIN:
142
+ break
143
+ paths.append(p)
144
+ IMAGE_GLOB_PATHS[k] = paths
145
+ IMAGE_GLOB_DESCS[k] = description
146
+
147
+ print(IMAGE_GLOB_PATHS['Geometry'][:10])
148
+
149
+
150
+ def read_json(json_file):
151
+ print(f'Reading : {json_file}')
152
+ with open(json_file, 'r', encoding='utf-8') as f:
153
+ rows = json.load(f)
154
+ return rows
155
+
156
+
157
+ def write_json(data, json_file):
158
+ with open(json_file, 'w', encoding='utf-8') as f:
159
+ json.dump(data, f, indent=4, ensure_ascii=False)
160
+
161
+
162
+ def convert_pref_data_to_openai_format(rows_dict):
163
+ for key, r in rows_dict.items():
164
+ if "conversation_prefix" in r:
165
+ assert "responses" in r, f'invalid: {r}'
166
+ continue
167
+ history = r['history']
168
+ conversations = []
169
+ for user, assistant in history:
170
+ conversations.append({"role": "user", "content": user.strip()})
171
+ conversations.append({"role": "assistant", "content": assistant.strip()})
172
+ r['conversation_prefix'] = conversations[:-1]
173
+ r['responses'] = [conversations[-1]]
174
+ r['original_response'] = conversations[-1]
175
+ if "lang" not in r:
176
+ r['lang'] = key[-2:]
177
+ # missing an item in responses
178
+ lang_set = list(set([r['lang'] for r in rows_dict.values()]))
179
+ return rows_dict, lang_set
180
+
181
+
182
+ def convert_mm_pref_data_to_openai_format(rows_dict):
183
+ pass
184
+
185
+
186
+ PREFERENCE_RATE_DICT = None
187
+ LANG_SET = ["en", "vi", "id", 'ms', "th", "zh", 'lo', 'km', 'tl', 'my']
188
+ if PREFERENCE_MAKE_DATA_PATH is not None and os.path.exists(PREFERENCE_MAKE_DATA_PATH):
189
+ print(f'Loading {PREFERENCE_MAKE_DATA_PATH}')
190
+ PREFERENCE_RATE_DICT = read_json(PREFERENCE_MAKE_DATA_PATH)
191
+ PREFERENCE_RATE_DICT, _LANG_SET = convert_pref_data_to_openai_format(PREFERENCE_RATE_DICT)
192
+ LANG_SET = LANG_SET + [l for l in _LANG_SET if l not in LANG_SET]
193
+
194
+
195
+
196
+
197
+
198
+ @document()
199
+ class CustomJsonlLogger(gr.FlaggingCallback):
200
+ def __init__(self):
201
+ self.num_lines = 0
202
+
203
+ def setup(
204
+ self,
205
+ components: list[Component],
206
+ flagging_dir: Union[str, Path],
207
+ ):
208
+ self.components = components
209
+ self.flagging_dir = flagging_dir
210
+ os.makedirs(flagging_dir, exist_ok=True)
211
+ flagging_dir = self.flagging_dir
212
+ log_filepath = Path(flagging_dir) / "log.jsonl"
213
+ if Path(log_filepath).exists():
214
+ with open(log_filepath, "rb") as f:
215
+ self.num_lines = sum(1 for _ in f)
216
+ else:
217
+ self.num_lines = 0
218
+
219
+ def flag(
220
+ self,
221
+ flag_data: list[Any],
222
+ flag_option: str = "",
223
+ username: Union[str, None] = None,
224
+ ) -> int:
225
+ import datetime
226
+ flagging_dir = self.flagging_dir
227
+ log_filepath = Path(flagging_dir) / "log.jsonl"
228
+ is_new = not Path(log_filepath).exists()
229
+ headers = [
230
+ getattr(component, "label", None) or f"component {idx}"
231
+ for idx, component in enumerate(self.components)
232
+ ] + [
233
+ "flag",
234
+ "username",
235
+ "timestamp",
236
+ ]
237
+
238
+ csv_data = []
239
+ for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
240
+ save_dir = Path(
241
+ flagging_dir
242
+ ) / client_utils.strip_invalid_filename_characters(
243
+ getattr(component, "label", None) or f"component {idx}"
244
+ )
245
+ if gradio_utils.is_update(sample):
246
+ csv_data.append(str(sample))
247
+ else:
248
+ csv_data.append(
249
+ component.flag(sample, flag_dir=save_dir)
250
+ if sample is not None
251
+ else ""
252
+ )
253
+ csv_data.append(flag_option)
254
+ csv_data.append(username if username is not None else "")
255
+ csv_data.append(str(datetime.datetime.now()))
256
+
257
+ json_obj = {}
258
+ for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
259
+ save_dir = Path(
260
+ flagging_dir
261
+ ) / client_utils.strip_invalid_filename_characters(
262
+ getattr(component, "label", None) or f"component {idx}"
263
+ )
264
+ label = getattr(component, "label", None) or f"component {idx}"
265
+ if gradio_utils.is_update(sample):
266
+ value = str(sample)
267
+ else:
268
+ value = component.flag(sample, flag_dir=save_dir) if sample is not None else None
269
+ json_obj[label] = value
270
+
271
+ json_obj['flag'] = flag_option
272
+ json_obj['username'] = username if username is not None else ""
273
+ json_obj['timestamp'] = str(datetime.datetime.now())
274
+
275
+ with open(log_filepath, "a", encoding="utf-8") as jsonl_file:
276
+ jsonl_file.write(json.dumps(json_obj, ensure_ascii=False) + "\n")
277
+
278
+ self.num_lines += 1
279
+ return self.num_lines
280
+
281
+ @document()
282
+ class VisionJsonlLogger(CustomJsonlLogger):
283
+ # ! must save the image
284
+ def flag(
285
+ self,
286
+ flag_data: list[Any],
287
+ flag_option: str = "",
288
+ username: Union[str, None] = None,
289
+ ) -> int:
290
+ import datetime
291
+ from shutil import copyfile
292
+ flagging_dir = self.flagging_dir
293
+ log_filepath = Path(flagging_dir) / "log.jsonl"
294
+ image_dir = Path(flagging_dir) / "images"
295
+ is_new = not Path(log_filepath).exists()
296
+ os.makedirs(image_dir, exist_ok=True)
297
+ headers = [
298
+ getattr(component, "label", None) or f"component {idx}"
299
+ for idx, component in enumerate(self.components)
300
+ ] + [
301
+ "flag",
302
+ "username",
303
+ "timestamp",
304
+ ]
305
+
306
+ csv_data = []
307
+ for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
308
+ save_dir = Path(
309
+ flagging_dir
310
+ ) / client_utils.strip_invalid_filename_characters(
311
+ getattr(component, "label", None) or f"component {idx}"
312
+ )
313
+ if gradio_utils.is_update(sample):
314
+ csv_data.append(str(sample))
315
+ else:
316
+ csv_data.append(
317
+ component.flag(sample, flag_dir=save_dir)
318
+ if sample is not None
319
+ else ""
320
+ )
321
+ csv_data.append(flag_option)
322
+ csv_data.append(username if username is not None else "")
323
+ csv_data.append(str(datetime.datetime.now()))
324
+
325
+ json_obj = {}
326
+ for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
327
+ save_dir = Path(
328
+ flagging_dir
329
+ ) / client_utils.strip_invalid_filename_characters(
330
+ getattr(component, "label", None) or f"component {idx}"
331
+ )
332
+ label = getattr(component, "label", None) or f"component {idx}"
333
+ if gradio_utils.is_update(sample):
334
+ value = str(sample)
335
+ else:
336
+ value = component.flag(sample, flag_dir=save_dir) if sample is not None else None
337
+ if isinstance(value, list):
338
+ # Expecting history
339
+ from .multimodal_chat_interface import gradio_history_to_vision_conversations_paths
340
+ conversations, image_paths = gradio_history_to_vision_conversations_paths(value)
341
+ new_paths = [
342
+ os.path.join(image_dir, str(datetime.datetime.now()) + os.path.basename(p))
343
+ for p in image_paths
344
+ ]
345
+ for np, ip in zip(new_paths, image_paths):
346
+ copyfile(ip, np)
347
+ json_obj[label] = conversations
348
+ json_obj[label + "-images"] = new_paths
349
+ else:
350
+ json_obj[label] = value
351
+
352
+ json_obj['flag'] = flag_option
353
+ json_obj['username'] = username if username is not None else ""
354
+ json_obj['timestamp'] = str(datetime.datetime.now())
355
+
356
+ with open(log_filepath, "a", encoding="utf-8") as jsonl_file:
357
+ jsonl_file.write(json.dumps(json_obj, ensure_ascii=False) + "\n")
358
+
359
+ self.num_lines += 1
360
+ return self.num_lines
361
+
362
+
363
+
364
+
365
+
366
+ def get_preference_radio():
367
+ pref_choice = gr.Radio(
368
+ ['1 Better', '2 Better', 'Add best', 'dirty/undecided'],
369
+ label='preference',
370
+ info="Indicate if 1 or 2 is better. If both not excellent, pick 'Add best' and write the better one below. If question or answer is problematic, cannot decide, then choose dirty/undecided."
371
+ )
372
+ return pref_choice
373
+
374
+
375
+
376
+ def vision_submit_vision_response_stream_multiturn_engine_yhistory(
377
+ message: str,
378
+ input_image: str,
379
+ history: List[List[str]],
380
+ temperature: float,
381
+ max_tokens: int,
382
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
383
+ image_token: Optional[str] = IMAGE_TOKEN,
384
+ ):
385
+ # ! Add message and input_image into the history and submit
386
+ message = message.strip()
387
+ if message == "":
388
+ gr.Warning(f'Input text cannot be empty')
389
+ yield history
390
+
391
+ new_history = history
392
+ if input_image is not None and os.path.exists(input_image):
393
+ # ! image exist, so add message if it's not empty
394
+ new_history = new_history + [[(input_image,), None]]
395
+ if message != "":
396
+ new_history = new_history + [[message, None]]
397
+ else:
398
+ # ! message cannot be empty if there is no input_image
399
+ if message == "":
400
+ gr.Warning(f'Input text cannot be empty!')
401
+ yield history
402
+ return
403
+ else:
404
+ new_history = new_history + [[message, None]]
405
+
406
+ yield new_history
407
+
408
+ # ! yield current history
409
+ # use vision_chat_response_stream_multiturn_engine
410
+ response = None
411
+ for response, num_tokens in vision_chat_response_stream_multiturn_engine(
412
+ history=new_history,
413
+ temperature=temperature, max_tokens=max_tokens, system_prompt=system_prompt,
414
+ image_token=image_token,
415
+ ):
416
+ yield new_history[:-1] + [[message, response]]
417
+
418
+ if response is not None:
419
+ yield new_history[:-1] + [[message, response]]
420
+
421
+
422
+ def vision_submit_2_histories(
423
+ message: str,
424
+ input_image: str,
425
+ history1: List[List[str]],
426
+ history2: List[List[str]],
427
+ temperature: float,
428
+ max_tokens: int,
429
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
430
+ image_token: Optional[str] = IMAGE_TOKEN,
431
+ ):
432
+ # need to yield 2 history
433
+ new_history1 = history1
434
+ new_history2 = history2
435
+ for his in vision_submit_vision_response_stream_multiturn_engine_yhistory(
436
+ message, input_image, history1, temperature, max_tokens, system_prompt, image_token,
437
+ ):
438
+ new_history1 = his
439
+ yield new_history1, new_history2
440
+
441
+ for his in vision_submit_vision_response_stream_multiturn_engine_yhistory(
442
+ message, input_image, history2, temperature, max_tokens, system_prompt, image_token,
443
+ ):
444
+ new_history2 = his
445
+ yield new_history1, new_history2
446
+
447
+
448
+ def undo_history_until_last_assistant_turn_message(history):
449
+ history = undo_history(history)
450
+ while len(history) > 0 and history[-1][-1] is None:
451
+ history = undo_history(history)
452
+ return history, history
453
+
454
+
455
+
456
+ def replace_last_response(input_text: str, history: List[Tuple[str, str]]):
457
+ # replace the last response with input_text
458
+ input_text = input_text.strip()
459
+ if input_text == "":
460
+ gr.Warning(f'prompt empty! dont send empty prompt')
461
+ return "", history
462
+ if len(history) == 0:
463
+ gr.Warning(f'History empty, cannot replace')
464
+ return input_text, history
465
+ history[-1][-1] = input_text
466
+ return "", history
467
+
468
+
469
+ # def load_image_from_gallery(selected_state: gr.SelectData):
470
+ # convo = sft_data_list[selected_state.index]
471
+ # dirname = sft_dirname
472
+ # image_path = os.path.join(dirname, convo['image'])
473
+ # return image_path
474
+
475
+ def load_image_from_gallery(data_list, selected_state: gr.SelectData):
476
+ image_path = data_list[selected_state.index]
477
+ # dirname = sft_dirname
478
+ # image_path = os.path.join(dirname, convo['image'])
479
+ return image_path
480
+
481
+
482
+ @register_demo
483
+ class VisionLivePreferencePickDemo(VisionChatInterfaceDemo):
484
+ @property
485
+ def examples(self):
486
+ return [
487
+ ["What's strange about this image?", "assets/dog_monalisa.jpeg",],
488
+ ["Explain why the sky is blue.", None,],
489
+ ]
490
+
491
+ @property
492
+ def tab_name(self):
493
+ return "Vision Live Preference"
494
+
495
+ def create_demo(
496
+ self,
497
+ title: str | None = None,
498
+ description: str | None = None,
499
+ **kwargs
500
+ ) -> gr.Blocks:
501
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
502
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
503
+ temperature = kwargs.get("temperature", TEMPERATURE)
504
+ model_name = kwargs.get("model_name", MODEL_NAME)
505
+
506
+ log_folder = os.path.join(PREF_DIR, "live_preference_pick")
507
+ description = f"""
508
+ ## Live generation preference picking
509
+ Live generation is similar to the Preference Picking demo, except that linguists can come up with questions/prompts **on their own** instead of pre-existing data.
510
+
511
+ PREF_DIR: {log_folder}
512
+ """
513
+
514
+ instruction_content = f"""
515
+ ### Tasks
516
+ You are enabled to freely build 2 different conversations using the model and pick the better conversations.
517
+ You can also create best responses if model's generated ones are not good.
518
+
519
+ ### Requirements
520
+ The 2 conversations must share at least the first user query. Other than that, the length, number of turns, user queries (except the first one) can vary.
521
+ For example:
522
+ ```
523
+ # Valid conversation pairs
524
+ "User: Hello, 1+1=?" -> "Bot: 1+1=2" -> "User: what about 123+13?" -> "Bot: 123+13=136"
525
+ -> "Bot: I dont know"
526
+
527
+ "User: Hello, 1+1=?" -> "Bot: 1+1=2" -> "User: what about 123+13?" -> "Bot: 123+13=136"
528
+ -> "Bot: 1+1=3" -> "User: that's wrong!" -> "Bot: Im sorry man."
529
+ ```
530
+
531
+ ```
532
+ # Invalid pairs:
533
+ "User: Hello, 1+1=?" -> "Bot: 1+1=2"
534
+ "User: Tell me a joke" -> "Bot: here is the joke for your..."
535
+ ```
536
+
537
+ ### Steps to proceed:
538
+ There are multiple buttons:
539
+ * `Submit both`: Submit the text prompt to both chatboxes, expect different (or same) answers.
540
+ * `Regenerate`: Regenerate the responses of both chatboxes from the last user queries.
541
+ * `Clear`: Clear both chatboxes.
542
+
543
+ The following numbered buttons (1 or 2) is applied to only Bot-1 or Bot-2 respectively.
544
+ * `Submit-1`: Submit the text prompt only one chatbot (1 or 2).
545
+ * `Undo-1`: Undo the last generation (both last response and query)
546
+ * `Regen-1`: Regenerate the last response.
547
+ * `Replace-1`: Replace the last response with a better response (in case the last response is incorrect, unsatisfactory)
548
+
549
+ """
550
+ callback = VisionJsonlLogger()
551
+ with gr.Blocks(css=CSS) as pdemo:
552
+ gr.Markdown(description)
553
+
554
+ with gr.Accordion(label="Instructions and Guidelines", open=False):
555
+ gr.Markdown(instruction_content)
556
+
557
+ with gr.Accordion(label="Additional input", open=False):
558
+ temp = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
559
+ length = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
560
+ # freq_pen = gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens')
561
+ # pres_pen = gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens')
562
+ # stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation.', lines=1)
563
+ system_prompt = gr.Textbox(value=system_prompt, label='system_prompt', lines=1)
564
+
565
+
566
+ with gr.Row():
567
+ chatbot_1 = gr.Chatbot(
568
+ [],
569
+ label="Bot-1",
570
+ elem_id="chatbot-1",
571
+ bubble_full_width=False,
572
+ latex_delimiters=[
573
+ # { "left": "$", "right": "$", "display": False},
574
+ { "left": "$$", "right": "$$", "display": True},
575
+ ],
576
+ show_copy_button=True,
577
+ layout="panel" if USE_PANEL else "bubble",
578
+ height=CHATBOT_HEIGHT,
579
+ )
580
+ chatbot_2 = gr.Chatbot(
581
+ [],
582
+ label="Bot-2",
583
+ elem_id="chatbot-2",
584
+ bubble_full_width=False,
585
+ latex_delimiters=[
586
+ # { "left": "$", "right": "$", "display": False},
587
+ { "left": "$$", "right": "$$", "display": True},
588
+ ],
589
+ show_copy_button=True,
590
+ layout="panel" if USE_PANEL else "bubble",
591
+ height=CHATBOT_HEIGHT,
592
+ )
593
+
594
+ with gr.Row():
595
+ input_text = gr.Textbox(
596
+ scale=6,
597
+ lines=12,
598
+ # lines=4,
599
+ max_lines=40,
600
+ show_label=False,
601
+ placeholder="Enter text and press enter, or upload an image",
602
+ container=False,
603
+ )
604
+ # submit will submit the same input text to both responses
605
+ input_image = gr.Image(
606
+ label="input_image", type="filepath", scale=3,
607
+ # height=250,
608
+ )
609
+ with gr.Row():
610
+ gen_submit = gr.Button('Send both', scale=1, variant='primary')
611
+ # regenerate should not care about input_text, it just undo the previous history
612
+ # regen_submit = gr.Button('Regenerate', scale=1)
613
+ clear_btn = gr.Button('Clear', scale=1)
614
+ # submit
615
+ with gr.Row():
616
+ chat1_submit = gr.Button('Send-1', variant='primary')
617
+ chat1_undo = gr.Button('Undo-1')
618
+ # chat1_regenerate = gr.Button('Regen-1')
619
+ chat1_replace = gr.Button('Replace-1')
620
+
621
+ chat2_submit = gr.Button('Send-2', variant='primary')
622
+ chat2_undo = gr.Button('Undo-2')
623
+ # chat2_regenerate = gr.Button('Regen-2')
624
+ chat2_replace = gr.Button('Replace-2')
625
+ gr.Markdown(f'**Do not click `Record Choice` twice with the same data sample!**')
626
+ with gr.Row():
627
+ pref_choice = get_preference_radio()
628
+
629
+ # with gr.Row():
630
+ # text_replace = gr.Textbox(
631
+ # placeholder="If both responses are not good, write a better response here. Only apply to the last response.",
632
+ # lines=2,
633
+ # max_lines=30,
634
+ # scale=6,
635
+ # label="best_response"
636
+ # )
637
+ submit_choice_btn = gr.Button('Record Choice', variant='secondary')
638
+
639
+
640
+ from functools import partial
641
+
642
+ with gr.Row():
643
+ gr.Examples(
644
+ label="Random images",
645
+ examples=[[x] for x in EXAMPLE_IMAGE_PATHS],
646
+ inputs=input_image,
647
+ cache_examples=False,
648
+ examples_per_page=100,
649
+ )
650
+
651
+ for k, plist in IMAGE_GLOB_PATHS.items():
652
+ print(f'{k}: {plist[:5]}')
653
+ gr.Markdown(f"{k}: {IMAGE_GLOB_DESCS[k]}")
654
+ gallery = gr.Gallery(
655
+ label=k,
656
+ value=plist,
657
+ allow_preview=False,
658
+ columns=10,
659
+ # rows=2,
660
+ height=250,
661
+ )
662
+ def _load_image_from_gallery(selected_state: gr.SelectData):
663
+ image_path = selected_state.value['image']['path']
664
+ print(f'Select: {image_path}')
665
+ return image_path
666
+ gallery.select(
667
+ _load_image_from_gallery,
668
+ # lambda select: plist[select.index],
669
+ # inputs=,
670
+ outputs=[input_image],
671
+ queue=False
672
+ )
673
+
674
+ # ! events for submit choices
675
+ submit_choice_btn.click(
676
+ lambda: gr.Button(value="Saving...", interactive=False, variant='stop'),
677
+ None,
678
+ submit_choice_btn,
679
+ queue=False,
680
+ api_name=False,
681
+ )
682
+ visual_feedback = True
683
+ def flag_method(request: gr.Request, *args):
684
+ # ! must save the image somewhere
685
+ try:
686
+ callback.flag(args)
687
+ except Exception as e:
688
+ print(f"Error while flagging: {e}")
689
+ if visual_feedback:
690
+ return "Error!"
691
+ if not visual_feedback:
692
+ return
693
+ gr.Info(f'Saving preference sucessful ({args[0]})')
694
+ time.sleep(1) # to provide enough time for the user to observe button change
695
+ return gr.Button(value="Record Choice", interactive=True)
696
+
697
+ callback.setup([chatbot_1, chatbot_2, pref_choice], log_folder)
698
+ submit_choice_btn.click(
699
+ flag_method, [chatbot_1, chatbot_2, pref_choice], submit_choice_btn,
700
+ preprocess=False, queue=False, api_name=False
701
+ )
702
+
703
+ # ! button evenrs
704
+ from gradio.events import Dependency, EventListenerMethod, on
705
+ generate_sub_events_both = [input_text.submit, gen_submit.click]
706
+ on(
707
+ generate_sub_events_both,
708
+ vision_submit_2_histories,
709
+ [
710
+ input_text, input_image, chatbot_1, chatbot_2,
711
+ temp, length, system_prompt
712
+ ],
713
+ [chatbot_1, chatbot_2],
714
+ api_name=False,
715
+ queue=True,
716
+ ).then(
717
+ lambda mes, img: ("", None),
718
+ [input_text, input_image],
719
+ [input_text, input_image],
720
+ api_name=False,
721
+ queue=False,
722
+ )
723
+ clear_btn.click(
724
+ lambda c1, c2, txt, img: ([], [], "", None),
725
+ [chatbot_1, chatbot_2, input_text, input_image],
726
+ [chatbot_1, chatbot_2, input_text, input_image],
727
+ api_name=False,
728
+ queue=True,
729
+ )
730
+ chat1_submit.click(
731
+ vision_submit_vision_response_stream_multiturn_engine_yhistory,
732
+ [
733
+ input_text, input_image, chatbot_1,
734
+ temp, length, system_prompt,
735
+ ],
736
+ [chatbot_1],
737
+ api_name=False,
738
+ queue=True,
739
+ ).then(
740
+ lambda mes, img: ("", None),
741
+ [input_text, input_image],
742
+ [input_text, input_image],
743
+ api_name=False,
744
+ queue=False,
745
+ )
746
+ chat2_submit.click(
747
+ vision_submit_vision_response_stream_multiturn_engine_yhistory,
748
+ [
749
+ input_text, input_image, chatbot_2,
750
+ temp, length, system_prompt,
751
+ ],
752
+ [chatbot_2],
753
+ api_name=False,
754
+ queue=True,
755
+ ).then(
756
+ lambda mes, img: ("", None),
757
+ [input_text, input_image],
758
+ [input_text, input_image],
759
+ api_name=False,
760
+ queue=False,
761
+ )
762
+ chat1_undo.click(
763
+ undo_history_until_last_assistant_turn,
764
+ chatbot_1,
765
+ [chatbot_1, input_text],
766
+ api_name=False,
767
+ queue=True,
768
+ )
769
+ chat2_undo.click(
770
+ undo_history_until_last_assistant_turn,
771
+ chatbot_2,
772
+ [chatbot_2, input_text],
773
+ api_name=False,
774
+ queue=True,
775
+ )
776
+ chat1_replace.click(
777
+ replace_last_response,
778
+ [input_text, chatbot_1],
779
+ [input_text, chatbot_1],
780
+ api_name=False,
781
+ queue=True,
782
+ )
783
+ chat2_replace.click(
784
+ replace_last_response,
785
+ [input_text, chatbot_2],
786
+ [input_text, chatbot_2],
787
+ api_name=False,
788
+ queue=True,
789
+ )
790
+
791
+
792
+
793
+
794
+ return pdemo
multipurpose_chatbot/demos/rag_chat_interface.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+ from gradio.themes import ThemeClass as Theme
25
+
26
+ from .base_demo import register_demo, get_demo_class, BaseDemo
27
+
28
+ import inspect
29
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
30
+
31
+ import anyio
32
+ from gradio_client import utils as client_utils
33
+ from gradio_client.documentation import document
34
+
35
+ from gradio.blocks import Blocks
36
+ from gradio.components import (
37
+ Button,
38
+ Chatbot,
39
+ Component,
40
+ Markdown,
41
+ State,
42
+ Textbox,
43
+ get_component_instance,
44
+ )
45
+ from gradio.events import Dependency, on
46
+ from gradio.helpers import create_examples as Examples # noqa: N812
47
+ from gradio.helpers import special_args
48
+ from gradio.layouts import Accordion, Group, Row
49
+ from gradio.routes import Request
50
+ from gradio.themes import ThemeClass as Theme
51
+ from gradio.utils import SyncToAsyncIterator, async_iteration
52
+
53
+
54
+ from ..globals import MODEL_ENGINE, RAG_CURRENT_FILE, RAG_EMBED, load_embeddings, get_rag_embeddings
55
+
56
+ from .chat_interface import (
57
+ SYSTEM_PROMPT,
58
+ MODEL_NAME,
59
+ MAX_TOKENS,
60
+ TEMPERATURE,
61
+ CHAT_EXAMPLES,
62
+ gradio_history_to_openai_conversations,
63
+ gradio_history_to_conversation_prompt,
64
+ DATETIME_FORMAT,
65
+ get_datetime_string,
66
+ format_conversation,
67
+ chat_response_stream_multiturn_engine,
68
+ ChatInterfaceDemo,
69
+ CustomizedChatInterface,
70
+ )
71
+
72
+ from ..configs import (
73
+ CHUNK_SIZE,
74
+ CHUNK_OVERLAP,
75
+ RAG_EMBED_MODEL_NAME,
76
+ )
77
+
78
+ RAG_CURRENT_VECTORSTORE = None
79
+
80
+
81
+ def load_document_split_vectorstore(file_path):
82
+ global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
83
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
84
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
85
+ from langchain_community.vectorstores import Chroma, FAISS
86
+ from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
87
+ splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
88
+ if file_path.endswith('.pdf'):
89
+ loader = PyPDFLoader(file_path)
90
+ elif file_path.endswith('.docx'):
91
+ loader = Docx2txtLoader(file_path)
92
+ elif file_path.endswith('.txt'):
93
+ loader = TextLoader(file_path)
94
+ splits = loader.load_and_split(splitter)
95
+ RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
96
+ return RAG_CURRENT_VECTORSTORE
97
+
98
+ def docs_to_context_content(docs: List[Any]):
99
+ content = "\n".join([d.page_content for d in docs])
100
+ return content
101
+
102
+
103
+ DOC_TEMPLATE = """###
104
+ {content}
105
+ ###
106
+
107
+ """
108
+
109
+ DOC_INSTRUCTION = """Answer the following query exclusively based on the information provided in the document above. \
110
+ If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
111
+ """
112
+
113
+
114
+ def docs_to_rag_context(docs: List[Any], doc_instruction=None):
115
+ doc_instruction = doc_instruction or DOC_INSTRUCTION
116
+ content = docs_to_context_content(docs)
117
+ context = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=content)
118
+ return context
119
+
120
+
121
+ def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
122
+ doc_context = None
123
+ if file_input is not None:
124
+ if file_input == RAG_CURRENT_FILE:
125
+ # reuse
126
+ vectorstore = RAG_CURRENT_VECTORSTORE
127
+ print(f'Reuse vectorstore: {file_input}')
128
+ else:
129
+ vectorstore = load_document_split_vectorstore(file_input)
130
+ print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
131
+ RAG_CURRENT_FILE = file_input
132
+ docs = vectorstore.similarity_search(message, k=rag_num_docs)
133
+ doc_context = docs_to_rag_context(docs)
134
+ return doc_context
135
+
136
+
137
+ def chat_response_stream_multiturn_doc_engine(
138
+ message: str,
139
+ history: List[Tuple[str, str]],
140
+ file_input: Optional[str] = None,
141
+ temperature: float = 0.7,
142
+ max_tokens: int = 1024,
143
+ system_prompt: Optional[str] = SYSTEM_PROMPT,
144
+ rag_num_docs: Optional[int] = 3,
145
+ doc_instruction: Optional[str] = DOC_INSTRUCTION,
146
+ # profile: Optional[gr.OAuthProfile] = None,
147
+ ):
148
+ global MODEL_ENGINE, RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
149
+ if len(message) == 0:
150
+ raise gr.Error("The message cannot be empty!")
151
+
152
+ rag_num_docs = int(rag_num_docs)
153
+ doc_instruction = doc_instruction or DOC_INSTRUCTION
154
+ doc_context = None
155
+ if file_input is not None:
156
+ if file_input == RAG_CURRENT_FILE:
157
+ # reuse
158
+ vectorstore = RAG_CURRENT_VECTORSTORE
159
+ print(f'Reuse vectorstore: {file_input}')
160
+ else:
161
+ vectorstore = load_document_split_vectorstore(file_input)
162
+ print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
163
+ RAG_CURRENT_FILE = file_input
164
+ docs = vectorstore.similarity_search(message, k=rag_num_docs)
165
+ # doc_context = docs_to_rag_context(docs)
166
+ rag_content = docs_to_context_content(docs)
167
+ doc_context = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=rag_content)
168
+
169
+ if doc_context is not None:
170
+ message = f"{doc_context}\n\n{message}"
171
+
172
+ for response, num_tokens in chat_response_stream_multiturn_engine(
173
+ message, history, temperature, max_tokens, system_prompt
174
+ ):
175
+ # ! yield another content which is doc_context
176
+ yield response, num_tokens, doc_context
177
+
178
+
179
+
180
+ class RagChatInterface(CustomizedChatInterface):
181
+ def __init__(
182
+ self,
183
+ fn: Callable[..., Any],
184
+ *,
185
+ chatbot: gr.Chatbot | None = None,
186
+ textbox: gr.Textbox | None = None,
187
+ additional_inputs: str | Component | list[str | Component] | None = None,
188
+ additional_inputs_accordion_name: str | None = None,
189
+ additional_inputs_accordion: str | gr.Accordion | None = None,
190
+ render_additional_inputs_fn: Callable | None = None,
191
+ examples: list[str] | None = None,
192
+ cache_examples: bool | None = None,
193
+ title: str | None = None,
194
+ description: str | None = None,
195
+ theme: Theme | str | None = None,
196
+ css: str | None = None,
197
+ js: str | None = None,
198
+ head: str | None = None,
199
+ analytics_enabled: bool | None = None,
200
+ submit_btn: str | Button | None = "Submit",
201
+ stop_btn: str | Button | None = "Stop",
202
+ retry_btn: str | Button | None = "🔄 Retry",
203
+ undo_btn: str | Button | None = "↩️ Undo",
204
+ clear_btn: str | Button | None = "🗑️ Clear",
205
+ autofocus: bool = True,
206
+ concurrency_limit: int | Literal['default'] | None = "default",
207
+ fill_height: bool = True
208
+ ):
209
+ try:
210
+ super(gr.ChatInterface, self).__init__(
211
+ analytics_enabled=analytics_enabled,
212
+ mode="chat_interface",
213
+ css=css,
214
+ title=title or "Gradio",
215
+ theme=theme,
216
+ js=js,
217
+ head=head,
218
+ fill_height=fill_height,
219
+ )
220
+ except Exception as e:
221
+ # Handling some old gradio version with out fill_height
222
+ super(gr.ChatInterface, self).__init__(
223
+ analytics_enabled=analytics_enabled,
224
+ mode="chat_interface",
225
+ css=css,
226
+ title=title or "Gradio",
227
+ theme=theme,
228
+ js=js,
229
+ head=head,
230
+ # fill_height=fill_height,
231
+ )
232
+ self.concurrency_limit = concurrency_limit
233
+ self.fn = fn
234
+ self.render_additional_inputs_fn = render_additional_inputs_fn
235
+ self.is_async = inspect.iscoroutinefunction(
236
+ self.fn
237
+ ) or inspect.isasyncgenfunction(self.fn)
238
+ self.is_generator = inspect.isgeneratorfunction(
239
+ self.fn
240
+ ) or inspect.isasyncgenfunction(self.fn)
241
+ self.examples = examples
242
+ if self.space_id and cache_examples is None:
243
+ self.cache_examples = True
244
+ else:
245
+ self.cache_examples = cache_examples or False
246
+ self.buttons: list[Button | None] = []
247
+
248
+ if additional_inputs:
249
+ if not isinstance(additional_inputs, list):
250
+ additional_inputs = [additional_inputs]
251
+ self.additional_inputs = [
252
+ get_component_instance(i)
253
+ for i in additional_inputs # type: ignore
254
+ ]
255
+ else:
256
+ self.additional_inputs = []
257
+ if additional_inputs_accordion_name is not None:
258
+ print(
259
+ "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
260
+ )
261
+ self.additional_inputs_accordion_params = {
262
+ "label": additional_inputs_accordion_name
263
+ }
264
+ if additional_inputs_accordion is None:
265
+ self.additional_inputs_accordion_params = {
266
+ "label": "Additional Inputs",
267
+ "open": False,
268
+ }
269
+ elif isinstance(additional_inputs_accordion, str):
270
+ self.additional_inputs_accordion_params = {
271
+ "label": additional_inputs_accordion
272
+ }
273
+ elif isinstance(additional_inputs_accordion, Accordion):
274
+ self.additional_inputs_accordion_params = (
275
+ additional_inputs_accordion.recover_kwargs(
276
+ additional_inputs_accordion.get_config()
277
+ )
278
+ )
279
+ else:
280
+ raise ValueError(
281
+ f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
282
+ )
283
+
284
+ with self:
285
+ if title:
286
+ Markdown(
287
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
288
+ )
289
+ if description:
290
+ Markdown(description)
291
+
292
+ if chatbot:
293
+ self.chatbot = chatbot.render()
294
+ else:
295
+ self.chatbot = Chatbot(
296
+ label="Chatbot", scale=1, height=200 if fill_height else None
297
+ )
298
+
299
+ with Row():
300
+ for btn in [retry_btn, undo_btn, clear_btn]:
301
+ if btn is not None:
302
+ if isinstance(btn, Button):
303
+ btn.render()
304
+ elif isinstance(btn, str):
305
+ btn = Button(btn, variant="secondary", size="sm")
306
+ else:
307
+ raise ValueError(
308
+ f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
309
+ )
310
+ self.buttons.append(btn) # type: ignore
311
+
312
+ with Group():
313
+ with Row():
314
+ if textbox:
315
+ textbox.container = False
316
+ textbox.show_label = False
317
+ textbox_ = textbox.render()
318
+ assert isinstance(textbox_, Textbox)
319
+ self.textbox = textbox_
320
+ else:
321
+ self.textbox = Textbox(
322
+ container=False,
323
+ show_label=False,
324
+ label="Message",
325
+ placeholder="Type a message...",
326
+ scale=7,
327
+ autofocus=autofocus,
328
+ )
329
+ if submit_btn is not None:
330
+ if isinstance(submit_btn, Button):
331
+ submit_btn.render()
332
+ elif isinstance(submit_btn, str):
333
+ submit_btn = Button(
334
+ submit_btn,
335
+ variant="primary",
336
+ scale=2,
337
+ min_width=150,
338
+ )
339
+ else:
340
+ raise ValueError(
341
+ f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
342
+ )
343
+ if stop_btn is not None:
344
+ if isinstance(stop_btn, Button):
345
+ stop_btn.visible = False
346
+ stop_btn.render()
347
+ elif isinstance(stop_btn, str):
348
+ stop_btn = Button(
349
+ stop_btn,
350
+ variant="stop",
351
+ visible=False,
352
+ scale=2,
353
+ min_width=150,
354
+ )
355
+ else:
356
+ raise ValueError(
357
+ f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
358
+ )
359
+ self.num_tokens = Textbox(
360
+ container=False,
361
+ label="num_tokens",
362
+ placeholder="0 tokens",
363
+ scale=1,
364
+ interactive=False,
365
+ # autofocus=autofocus,
366
+ min_width=10
367
+ )
368
+ self.buttons.extend([submit_btn, stop_btn]) # type: ignore
369
+
370
+ self.fake_api_btn = Button("Fake API", visible=False)
371
+ self.fake_response_textbox = Textbox(label="Response", visible=False)
372
+ (
373
+ self.retry_btn,
374
+ self.undo_btn,
375
+ self.clear_btn,
376
+ self.submit_btn,
377
+ self.stop_btn,
378
+ ) = self.buttons
379
+
380
+ if examples:
381
+ if self.is_generator:
382
+ examples_fn = self._examples_stream_fn
383
+ else:
384
+ examples_fn = self._examples_fn
385
+
386
+ self.examples_handler = Examples(
387
+ examples=examples,
388
+ inputs=[self.textbox] + self.additional_inputs,
389
+ outputs=self.chatbot,
390
+ fn=examples_fn,
391
+ )
392
+
393
+ any_unrendered_inputs = any(
394
+ not inp.is_rendered for inp in self.additional_inputs
395
+ )
396
+ if self.additional_inputs and any_unrendered_inputs:
397
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
398
+ if self.render_additional_inputs_fn is not None:
399
+ self.render_additional_inputs_fn()
400
+ else:
401
+ for input_component in self.additional_inputs:
402
+ if not input_component.is_rendered:
403
+ input_component.render()
404
+
405
+ self.rag_content = gr.Textbox(
406
+ scale=4,
407
+ lines=16,
408
+ label='Retrieved RAG context',
409
+ placeholder="Rag context and instrution will show up here",
410
+ interactive=False
411
+ )
412
+
413
+ # The example caching must happen after the input components have rendered
414
+ if cache_examples:
415
+ client_utils.synchronize_async(self.examples_handler.cache)
416
+
417
+ self.saved_input = State()
418
+ self.chatbot_state = (
419
+ State(self.chatbot.value) if self.chatbot.value else State([])
420
+ )
421
+
422
+ self._setup_events()
423
+ self._setup_api()
424
+
425
+ def _setup_events(self) -> None:
426
+ from gradio.components import State
427
+ has_on = False
428
+ try:
429
+ from gradio.events import Dependency, EventListenerMethod, on
430
+ has_on = True
431
+ except ImportError as ie:
432
+ has_on = False
433
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
434
+ if not self.is_generator:
435
+ raise NotImplementedError(f'should use generator')
436
+
437
+ if has_on:
438
+ # new version
439
+ submit_triggers = (
440
+ [self.textbox.submit, self.submit_btn.click]
441
+ if self.submit_btn
442
+ else [self.textbox.submit]
443
+ )
444
+ submit_event = (
445
+ on(
446
+ submit_triggers,
447
+ self._clear_and_save_textbox,
448
+ [self.textbox],
449
+ [self.textbox, self.saved_input],
450
+ api_name=False,
451
+ queue=False,
452
+ )
453
+ .then(
454
+ self._display_input,
455
+ [self.saved_input, self.chatbot_state],
456
+ [self.chatbot, self.chatbot_state],
457
+ api_name=False,
458
+ queue=False,
459
+ )
460
+ .then(
461
+ submit_fn,
462
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
463
+ [self.chatbot, self.chatbot_state, self.num_tokens, self.rag_content],
464
+ api_name=False,
465
+ )
466
+ )
467
+ self._setup_stop_events(submit_triggers, submit_event)
468
+ else:
469
+ raise ValueError(f'Better install new gradio version than 3.44.0')
470
+
471
+ if self.retry_btn:
472
+ retry_event = (
473
+ self.retry_btn.click(
474
+ self._delete_prev_fn,
475
+ [self.chatbot_state],
476
+ [self.chatbot, self.saved_input, self.chatbot_state],
477
+ api_name=False,
478
+ queue=False,
479
+ )
480
+ .then(
481
+ self._display_input,
482
+ [self.saved_input, self.chatbot_state],
483
+ [self.chatbot, self.chatbot_state],
484
+ api_name=False,
485
+ queue=False,
486
+ )
487
+ .then(
488
+ submit_fn,
489
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
490
+ [self.chatbot, self.chatbot_state, self.num_tokens, self.rag_content],
491
+ api_name=False,
492
+ )
493
+ )
494
+ self._setup_stop_events([self.retry_btn.click], retry_event)
495
+
496
+ if self.undo_btn:
497
+ self.undo_btn.click(
498
+ self._delete_prev_fn,
499
+ [self.chatbot_state],
500
+ [self.chatbot, self.saved_input, self.chatbot_state],
501
+ api_name=False,
502
+ queue=False,
503
+ ).then(
504
+ lambda x: x,
505
+ [self.saved_input],
506
+ [self.textbox],
507
+ api_name=False,
508
+ queue=False,
509
+ )
510
+ # Reconfigure clear_btn to stop and clear text box
511
+
512
+ async def _stream_fn(
513
+ self,
514
+ message: str,
515
+ history_with_input,
516
+ request: Request,
517
+ *args,
518
+ ) -> AsyncGenerator:
519
+ history = history_with_input[:-1]
520
+ inputs, _, _ = special_args(
521
+ self.fn, inputs=[message, history, *args], request=request
522
+ )
523
+
524
+ if self.is_async:
525
+ generator = self.fn(*inputs)
526
+ else:
527
+ generator = await anyio.to_thread.run_sync(
528
+ self.fn, *inputs, limiter=self.limiter
529
+ )
530
+ generator = SyncToAsyncIterator(generator, self.limiter)
531
+
532
+ # ! In case of error, yield the previous history & undo any generation before raising error
533
+ try:
534
+ first_response_pack = await async_iteration(generator)
535
+ if isinstance(first_response_pack, (tuple, list)):
536
+ first_response, num_tokens, rag_content = first_response_pack
537
+ else:
538
+ first_response, num_tokens, rag_content = first_response_pack, -1, ""
539
+ update = history + [[message, first_response]]
540
+ yield update, update, f"{num_tokens} toks", rag_content
541
+ except StopIteration:
542
+ update = history + [[message, None]]
543
+ yield update, update, "NaN toks", ""
544
+ except Exception as e:
545
+ yield history, history, "NaN toks", ""
546
+ raise e
547
+
548
+ try:
549
+ async for response_pack in generator:
550
+ if isinstance(response_pack, (tuple, list)):
551
+ response, num_tokens, rag_content = response_pack
552
+ else:
553
+ response, num_tokens, rag_content = response_pack, "NaN toks", ""
554
+ update = history + [[message, response]]
555
+ yield update, update, f"{num_tokens} toks", rag_content
556
+ except Exception as e:
557
+ yield history, history, "NaN toks", ""
558
+ raise e
559
+
560
+
561
+
562
+ @register_demo
563
+ class RagChatInterfaceDemo(ChatInterfaceDemo):
564
+
565
+ @property
566
+ def examples(self):
567
+ return [
568
+ ["Explain how attention works.", "assets/attention_all_you_need.pdf"],
569
+ ["Explain why the sky is blue.", None],
570
+ ]
571
+
572
+ @property
573
+ def tab_name(self):
574
+ return "RAG Chat"
575
+
576
+ def create_demo(
577
+ self,
578
+ title: str | None = None,
579
+ description: str | None = None,
580
+ **kwargs
581
+ ) -> gr.Blocks:
582
+ load_embeddings()
583
+ global RAG_EMBED
584
+ # assert RAG_EMBED is not None
585
+ print(F'{RAG_EMBED=}')
586
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
587
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
588
+ temperature = kwargs.get("temperature", TEMPERATURE)
589
+ model_name = kwargs.get("model_name", MODEL_NAME)
590
+ rag_num_docs = kwargs.get("rag_num_docs", 3)
591
+
592
+ from ..configs import RAG_EMBED_MODEL_NAME
593
+
594
+ description = description or f"""Upload a long document to ask question about it with RAG. Embedding model {RAG_EMBED_MODEL_NAME}"""
595
+
596
+ additional_inputs = [
597
+ gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt']),
598
+ gr.Number(value=temperature, label='Temperature', min_width=20),
599
+ gr.Number(value=max_tokens, label='Max tokens', min_width=20),
600
+ gr.Textbox(value=system_prompt, label='System prompt', lines=2),
601
+ gr.Number(value=rag_num_docs, label='RAG Top-K', min_width=20),
602
+ gr.Textbox(value=DOC_INSTRUCTION, label='RAG instruction'),
603
+ ]
604
+ def render_additional_inputs_fn():
605
+ additional_inputs[0].render()
606
+ with Row():
607
+ additional_inputs[1].render()
608
+ additional_inputs[2].render()
609
+ additional_inputs[4].render()
610
+ additional_inputs[3].render()
611
+ additional_inputs[5].render()
612
+
613
+ demo_chat = RagChatInterface(
614
+ chat_response_stream_multiturn_doc_engine,
615
+ chatbot=gr.Chatbot(
616
+ label=model_name,
617
+ bubble_full_width=False,
618
+ latex_delimiters=[
619
+ { "left": "$", "right": "$", "display": False},
620
+ { "left": "$$", "right": "$$", "display": True},
621
+ ],
622
+ show_copy_button=True,
623
+ ),
624
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
625
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
626
+ # ! consider preventing the stop button
627
+ # stop_btn=None,
628
+ title=title,
629
+ description=description,
630
+ additional_inputs=additional_inputs,
631
+ render_additional_inputs_fn=render_additional_inputs_fn,
632
+ additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
633
+ examples=self.examples,
634
+ cache_examples=False,
635
+ )
636
+ return demo_chat
637
+
638
+
multipurpose_chatbot/demos/text_completion.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from gradio.themes import ThemeClass as Theme
3
+ import numpy as np
4
+ import argparse
5
+ import gradio as gr
6
+ from typing import Any, Iterator
7
+ from typing import Iterator, List, Optional, Tuple
8
+ import filelock
9
+ import glob
10
+ import json
11
+ import time
12
+ from gradio.routes import Request
13
+ from gradio.utils import SyncToAsyncIterator, async_iteration
14
+ from gradio.helpers import special_args
15
+ import anyio
16
+ from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
17
+
18
+ from gradio_client.documentation import document, set_documentation_group
19
+ from gradio.components import Button, Component
20
+ from gradio.events import Dependency, EventListenerMethod
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+
25
+
26
+ import inspect
27
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
28
+
29
+ import anyio
30
+ from gradio_client import utils as client_utils
31
+ from gradio_client.documentation import document
32
+
33
+ from gradio.blocks import Blocks
34
+ from gradio.components import (
35
+ Button,
36
+ Chatbot,
37
+ Component,
38
+ Markdown,
39
+ State,
40
+ Textbox,
41
+ get_component_instance,
42
+ )
43
+ from gradio.events import Dependency, on
44
+ from gradio.helpers import create_examples as Examples # noqa: N812
45
+ from gradio.helpers import special_args
46
+ from gradio.layouts import Accordion, Group, Row
47
+ from gradio.routes import Request
48
+ from gradio.themes import ThemeClass as Theme
49
+ from gradio.utils import SyncToAsyncIterator, async_iteration
50
+
51
+
52
+ from .base_demo import register_demo, get_demo_class, BaseDemo
53
+
54
+
55
+ from ..configs import (
56
+ SYSTEM_PROMPT,
57
+ MODEL_NAME,
58
+ MAX_TOKENS,
59
+ TEMPERATURE,
60
+ )
61
+
62
+ from ..globals import MODEL_ENGINE
63
+
64
+
65
+ def generate_text_completion_stream_engine(
66
+ message: str,
67
+ temperature: float,
68
+ max_tokens: int,
69
+ stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
70
+ ):
71
+ global MODEL_ENGINE
72
+ temperature = float(temperature)
73
+ # ! remove frequency_penalty
74
+ # frequency_penalty = float(frequency_penalty)
75
+ max_tokens = int(max_tokens)
76
+ # message = message.strip()
77
+ stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
78
+ stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>', '<|im_end|>']))
79
+ if message.strip() != message:
80
+ gr.Warning(f'There are preceding/trailing spaces in the message, may lead to unexpected behavior')
81
+ if len(message) == 0:
82
+ raise gr.Error("The message cannot be empty!")
83
+ num_tokens = len(MODEL_ENGINE.tokenizer.encode(message))
84
+ if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
85
+ raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
86
+
87
+ outputs = None
88
+ response = None
89
+ num_tokens = -1
90
+ for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
91
+ prompt=message,
92
+ temperature=temperature,
93
+ max_tokens=max_tokens,
94
+ stop_strings=stop_strings,
95
+ )):
96
+ if isinstance(outputs, tuple):
97
+ response, num_tokens = outputs
98
+ else:
99
+ response, num_tokens = outputs, -1
100
+ yield message + response, f"{num_tokens} tokens"
101
+
102
+ if response is not None:
103
+ yield message + response, f"{num_tokens} tokens"
104
+
105
+
106
+ @register_demo
107
+ class TextCompletionDemo(BaseDemo):
108
+ @property
109
+ def tab_name(self):
110
+ return "Text Completion"
111
+
112
+ def create_demo(
113
+ self,
114
+ title: str | None = None,
115
+ description: str | None = None,
116
+ **kwargs
117
+ ) -> gr.Blocks:
118
+ system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
119
+ max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
120
+ temperature = kwargs.get("temperature", TEMPERATURE)
121
+ model_name = kwargs.get("model_name", MODEL_NAME)
122
+ # frequence_penalty = FREQUENCE_PENALTY
123
+ # presence_penalty = PRESENCE_PENALTY
124
+ max_tokens = max_tokens // 2
125
+
126
+ description = description or f"""Put any context string (like few-shot prompts)"""
127
+
128
+ with gr.Blocks() as demo_text_completion:
129
+ if title:
130
+ gr.Markdown(title)
131
+ if description:
132
+ gr.Markdown(description)
133
+ with gr.Row():
134
+ txt = gr.Textbox(
135
+ scale=4,
136
+ lines=16,
137
+ show_label=False,
138
+ placeholder="Enter any free form text and submit",
139
+ container=False,
140
+ )
141
+ with gr.Row():
142
+ submit_button = gr.Button('Submit', variant='primary', scale=9)
143
+ stop_button = gr.Button('Stop', variant='stop', scale=9, visible=False)
144
+ num_tokens = Textbox(
145
+ container=False,
146
+ show_label=False,
147
+ label="num_tokens",
148
+ placeholder="0 tokens",
149
+ scale=1,
150
+ interactive=False,
151
+ min_width=10
152
+ )
153
+ with gr.Row():
154
+ temp_input = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
155
+ length_input = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
156
+ stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>,<|im_end|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
157
+ examples = gr.Examples(
158
+ examples=[
159
+ ["The following is the recite the declaration of independence:",]
160
+ ],
161
+ inputs=[txt, temp_input, length_input, stop_strings],
162
+ # outputs=[txt]
163
+ )
164
+ # ! Handle stop button
165
+ submit_trigger = submit_button.click
166
+ submit_event = submit_button.click(
167
+ # submit_trigger,
168
+ generate_text_completion_stream_engine,
169
+ [txt, temp_input, length_input, stop_strings],
170
+ [txt, num_tokens],
171
+ # api_name=False,
172
+ # queue=False,
173
+ )
174
+
175
+ submit_trigger(
176
+ lambda: (
177
+ Button(visible=False), Button(visible=True),
178
+ ),
179
+ None,
180
+ [submit_button, stop_button],
181
+ api_name=False,
182
+ queue=False,
183
+ )
184
+ submit_event.then(
185
+ lambda: (Button(visible=True), Button(visible=False)),
186
+ None,
187
+ [submit_button, stop_button],
188
+ api_name=False,
189
+ queue=False,
190
+ )
191
+ stop_button.click(
192
+ None,
193
+ None,
194
+ None,
195
+ cancels=submit_event,
196
+ api_name=False,
197
+ )
198
+
199
+ return demo_text_completion
multipurpose_chatbot/engines/.DS_Store ADDED
Binary file (6.15 kB). View file
 
multipurpose_chatbot/engines/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from .base_engine import BaseEngine
3
+
4
+ BACKENDS = [
5
+ "mlx",
6
+ "vllm",
7
+ "transformers",
8
+ "llama_cpp",
9
+ # "llava_llama_cpp",
10
+ "debug",
11
+ "sealmmm_transformers",
12
+ ]
13
+
14
+ ENGINE_LOADED = False
15
+
16
+ def load_multipurpose_chatbot_engine(backend: str):
17
+ # ! lazy import other engines
18
+ global ENGINE_LOADED
19
+ assert backend in BACKENDS, f'{backend} not in {BACKENDS}'
20
+ if ENGINE_LOADED:
21
+ raise RuntimeError(f'{ENGINE_LOADED=} this means load_multipurpose_chatbot_engine has already been called! Check your codes.')
22
+ print(f'Load model from {backend}')
23
+ if backend == "mlx":
24
+ from .mlx_engine import MlxEngine
25
+ model_engine = MlxEngine()
26
+ elif backend == 'vllm':
27
+ from .vllm_engine import VllmEngine
28
+ model_engine = VllmEngine()
29
+ elif backend == 'transformers':
30
+ from .transformers_engine import TransformersEngine
31
+ model_engine = TransformersEngine()
32
+ elif backend == 'llama_cpp':
33
+ from .llama_cpp_engine import LlamaCppEngine
34
+ model_engine = LlamaCppEngine()
35
+ # ! llava_llama_cpp currently not done due to bugs
36
+ # elif backend == 'llava_llama_cpp':
37
+ # from .llava_llama_cpp_engine import LlavaLlamaCppEngine
38
+ # model_engine = LlavaLlamaCppEngine()
39
+ elif backend == 'debug':
40
+ from .debug_engine import DebugEngine
41
+ model_engine = DebugEngine()
42
+ elif backend == 'sealmmm_transformers':
43
+ from .sealmmm_engine import SeaLMMMv0Engine
44
+ model_engine = SeaLMMMv0Engine()
45
+ else:
46
+ raise ValueError(f'backend invalid: {BACKENDS} vs {backend}')
47
+
48
+ model_engine.load_model()
49
+ ENGINE_LOADED = True
50
+ return model_engine
51
+ # ! add more llama.cpp engine here.
52
+
53
+
multipurpose_chatbot/engines/base_engine.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from huggingface_hub import snapshot_download
4
+ # ! Avoid importing transformers
5
+ # from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
6
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
7
+ import time
8
+
9
+
10
+ class BaseEngine(object):
11
+ def __init__(self, **kwargs) -> None:
12
+ pass
13
+
14
+ @property
15
+ def max_position_embeddings(self) -> int:
16
+ return 10000
17
+
18
+ @property
19
+ def tokenizer(self):
20
+ raise NotImplementedError
21
+
22
+ def load_model(self, ):
23
+ raise NotImplementedError
24
+
25
+ def apply_chat_template(self, conversations, add_generation_prompt: bool, add_special_tokens=False, **kwargs) -> str:
26
+ """
27
+ return string convo, add_special_tokens should be added later
28
+ """
29
+ bos_token = self.tokenizer.bos_token
30
+ eos_token = self.tokenizer.eos_token
31
+ if not add_special_tokens:
32
+ # prevent bos being added to string
33
+ self.tokenizer.bos_token = ""
34
+ self.tokenizer.eos_token = ""
35
+ full_prompt = self.tokenizer.apply_chat_template(
36
+ conversations, add_generation_prompt=add_generation_prompt,
37
+ tokenize=False,
38
+ )
39
+ self.tokenizer.bos_token = bos_token
40
+ self.tokenizer.eos_token = eos_token
41
+ return full_prompt
42
+
multipurpose_chatbot/engines/debug_engine.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from huggingface_hub import snapshot_download
4
+ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
5
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
6
+ import time
7
+
8
+ from .base_engine import BaseEngine
9
+
10
+ from ..configs import (
11
+ MODEL_PATH,
12
+ )
13
+
14
+ FAKE_MODEL_PATH = os.environ.get("FAKE_MODEL_PATH", MODEL_PATH)
15
+ FAKE_RESPONSE = "Wow that's very very cool, please try again."
16
+
17
+
18
+ class DebugEngine(BaseEngine):
19
+ """
20
+ It will always yield FAKE_RESPONSE
21
+ """
22
+
23
+ def __init__(self, **kwargs) -> None:
24
+ super().__init__(**kwargs)
25
+ self._model = None
26
+ self._tokenizer = None
27
+
28
+ @property
29
+ def tokenizer(self) -> PreTrainedTokenizer:
30
+ if self._tokenizer is None:
31
+ self._tokenizer = AutoTokenizer.from_pretrained(FAKE_MODEL_PATH, trust_remote_code=True)
32
+ return self._tokenizer
33
+
34
+ def load_model(self):
35
+ print(f"Load fake model with tokenizer: {self.tokenizer}")
36
+
37
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
38
+
39
+ num_tokens = len(self.tokenizer.encode(prompt))
40
+ response = FAKE_RESPONSE
41
+ for i in range(len(response)):
42
+ time.sleep(0.01)
43
+ yield response[:i], num_tokens
44
+
45
+ num_tokens = len(self.tokenizer.encode(prompt + response))
46
+ yield response, num_tokens
47
+
48
+ def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
49
+ return [p + " -- Test" for p in prompts]
multipurpose_chatbot/engines/llama_cpp_engine.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import argparse
4
+ import gradio as gr
5
+ from typing import Any, Iterator
6
+ from typing import Iterator, List, Optional, Tuple
7
+ import filelock
8
+ import glob
9
+ import json
10
+ import time
11
+ from gradio.routes import Request
12
+ from gradio.utils import SyncToAsyncIterator, async_iteration
13
+ from gradio.helpers import special_args
14
+ import anyio
15
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
16
+
17
+ from gradio_client.documentation import document, set_documentation_group
18
+
19
+ from typing import List, Optional, Union, Dict, Tuple
20
+ from tqdm.auto import tqdm
21
+ from huggingface_hub import snapshot_download
22
+ import types
23
+
24
+ from gradio.components import Button
25
+ from gradio.events import Dependency, EventListenerMethod
26
+
27
+ import types
28
+ import sys
29
+
30
+ from .base_engine import BaseEngine
31
+
32
+ # ! Remember to use static cache
33
+
34
+ from ..configs import (
35
+ MODEL_PATH,
36
+ DEFAULT_CHAT_TEMPLATE,
37
+ N_CTX,
38
+ N_GPU_LAYERS,
39
+ )
40
+
41
+
42
+
43
+ def encode_tokenize(self, prompt: str, **kwargs):
44
+ """Mimic behavior of transformers tokenizer"""
45
+ prompt_tokens: List[int] = (
46
+ (
47
+ self.tokenize(prompt.encode("utf-8"), special=True)
48
+ if prompt != ""
49
+ else [self.token_bos()]
50
+ )
51
+ if isinstance(prompt, str)
52
+ else prompt
53
+ )
54
+ return prompt_tokens
55
+
56
+
57
+ conversations = [
58
+ {"role": "system", "content": "You are good."},
59
+ {"role": "user", "content": "Hello."},
60
+ {"role": "assistant", "content": "Hi."},
61
+ ]
62
+
63
+
64
+ class LlamaCppEngine(BaseEngine):
65
+ """
66
+ need to create an engine.tokenizer.encode(text) method
67
+ """
68
+ @property
69
+ def max_position_embeddings(self) -> int:
70
+ # raise ValueError
71
+ return self._model.context_params.n_ctx
72
+
73
+ def apply_chat_template(self, conversations, add_generation_prompt: bool, add_special_tokens=False, **kwargs) -> str:
74
+ """
75
+ return string convo, add_special_tokens should be added later
76
+ remember to remove <s> if any,
77
+ """
78
+ from llama_cpp.llama_chat_format import Jinja2ChatFormatter
79
+
80
+ formatter = Jinja2ChatFormatter(
81
+ template=self._model.metadata['tokenizer.chat_template'],
82
+ # bos_token=self._model._model.token_get_text(self._model.token_bos()),
83
+ bos_token="",
84
+ eos_token=self._model._model.token_get_text(self._model.token_eos()),
85
+ add_generation_prompt=add_generation_prompt,
86
+ )
87
+
88
+ full_prompt = formatter(messages=conversations).prompt
89
+ # ! it may has bos
90
+ return full_prompt
91
+
92
+ @property
93
+ def tokenizer(self):
94
+ return self._model
95
+
96
+ def load_model(self):
97
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
98
+
99
+ from llama_cpp import Llama
100
+ self.model_path = MODEL_PATH
101
+ self._model = Llama(
102
+ model_path=self.model_path,
103
+ n_gpu_layers=N_GPU_LAYERS, # Uncomment to use GPU acceleration
104
+ # seed=1337, # Uncomment to set a specific seed
105
+ n_ctx=N_CTX, # Uncomment to increase the context window
106
+ )
107
+ self._tokenizer = self._model
108
+ self._model.encode = types.MethodType(encode_tokenize, self._model)
109
+ print(f'Load model: {self.model_path=} | {N_GPU_LAYERS=} | {N_CTX=}')
110
+
111
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
112
+ stop_strings = list(stop_strings) if stop_strings is not None else []
113
+ stop_strings = list(set(stop_strings + ["</s>", "<|im_end|>"]))
114
+ generator = self._model(
115
+ prompt,
116
+ max_tokens=max_tokens, # Generate up to 32 tokens, set to None to generate up to the end of the context window
117
+ temperature=temperature,
118
+ stop=stop_strings, # Stop generating just before the model would generate a new question
119
+ stream=True,
120
+ )
121
+ response = ""
122
+ num_tokens = len(self.tokenizer.encode(prompt))
123
+ for g in generator:
124
+ response += g['choices'][0]['text']
125
+ yield response, num_tokens
126
+
127
+ if response is not None and len(response) > 0:
128
+ num_tokens = len(self.tokenizer.encode(prompt + response))
129
+ yield response, num_tokens
130
+
131
+
multipurpose_chatbot/engines/llava_llama_cpp_engine.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import argparse
4
+ import gradio as gr
5
+ from typing import Any, Iterator
6
+ from typing import Iterator, List, Optional, Tuple
7
+ import filelock
8
+ import glob
9
+ import json
10
+ import time
11
+ from gradio.routes import Request
12
+ from gradio.utils import SyncToAsyncIterator, async_iteration
13
+ from gradio.helpers import special_args
14
+ import anyio
15
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
16
+
17
+ from gradio_client.documentation import document, set_documentation_group
18
+
19
+ from typing import List, Optional, Union, Dict, Tuple
20
+ from tqdm.auto import tqdm
21
+ from huggingface_hub import snapshot_download
22
+ import types
23
+
24
+ from gradio.components import Button
25
+ from gradio.events import Dependency, EventListenerMethod
26
+
27
+ import types
28
+ import sys
29
+
30
+ from .base_engine import BaseEngine
31
+
32
+ # ! Remember to use static cache
33
+
34
+ from ..configs import (
35
+ MODEL_PATH,
36
+ DEFAULT_CHAT_TEMPLATE,
37
+ N_CTX,
38
+ N_GPU_LAYERS,
39
+ IMAGE_TOKEN,
40
+ IMAGE_TOKEN_INTERACTIVE,
41
+ IMAGE_TOKEN_LENGTH,
42
+ MAX_PACHES,
43
+ )
44
+
45
+ from .llama_cpp_engine import (
46
+ encode_tokenize,
47
+ LlamaCppEngine,
48
+ )
49
+
50
+
51
+
52
+ # resource: https://llama-cpp-python.readthedocs.io/en/latest/#multi-modal-models
53
+
54
+ import base64
55
+
56
+ def image_to_base64_data_uri(file_path):
57
+ with open(file_path, "rb") as img_file:
58
+ base64_data = base64.b64encode(img_file.read()).decode('utf-8')
59
+ return f"data:image/png;base64,{base64_data}"
60
+
61
+
62
+ # file_path = 'file_path.png'
63
+ # data_uri = image_to_base64_data_uri(file_path)
64
+
65
+ # data_uri = image_to_base64_data_uri(file_path)
66
+
67
+ # messages = [
68
+ # {"role": "system", "content": "You are an assistant who perfectly describes images."},
69
+ # {
70
+ # "role": "user",
71
+ # "content": [
72
+ # {"type": "image_url", "image_url": {"url": data_uri }},
73
+ # {"type" : "text", "text": "Describe this image in detail please."}
74
+ # ]
75
+ # }
76
+ # ]
77
+
78
+
79
+ def llava_15_chat_handler_call(
80
+ self,
81
+ *,
82
+ llama: Any,
83
+ # messages: List[Any],
84
+ prompt: Union[str, List[int]],
85
+ image_data_uris: Optional[List[Any]] = None,
86
+ image_token: str = None,
87
+ functions: Optional[List[Any]] = None,
88
+ function_call: Optional[Any] = None,
89
+ tools: Optional[List[Any]] = None,
90
+ tool_choice: Optional[Any] = None,
91
+ temperature: float = 0.2,
92
+ top_p: float = 0.95,
93
+ top_k: int = 40,
94
+ min_p: float = 0.05,
95
+ typical_p: float = 1.0,
96
+ stream: bool = False,
97
+ stop: Optional[Union[str, List[str]]] = [],
98
+ response_format: Optional[
99
+ Any
100
+ ] = None,
101
+ max_tokens: Optional[int] = None,
102
+ presence_penalty: float = 0.0,
103
+ frequency_penalty: float = 0.0,
104
+ repeat_penalty: float = 1.1,
105
+ tfs_z: float = 1.0,
106
+ mirostat_mode: int = 0,
107
+ mirostat_tau: float = 5.0,
108
+ mirostat_eta: float = 0.1,
109
+ model: Optional[str] = None,
110
+ logits_processor: Optional[Any] = None,
111
+ grammar: Optional[Any] = None,
112
+ **kwargs, # type: ignore
113
+ ):
114
+ from llama_cpp.llama_chat_format import (
115
+ ctypes,
116
+ suppress_stdout_stderr,
117
+ )
118
+ assert (
119
+ llama.context_params.logits_all is True
120
+ ) # BUG: logits_all=True is required for llava
121
+ assert self.clip_ctx is not None
122
+ # ! split prompt into different parts
123
+ assert image_token is not None
124
+ prompt_parts = prompt.split(image_token)
125
+ # assert len(prompt_parts)
126
+ assert len(prompt_parts) == len(image_data_uris) + 1, f'invalid {len(prompt_parts)=} != {len(image_data_uris)=}'
127
+ llama.reset()
128
+ prefix = prompt_parts[0]
129
+ remaining_texts = prompt_parts[1:]
130
+ llama.reset()
131
+ llama.eval(llama.tokenize(prefix.encode("utf8"), add_bos=True))
132
+ for index, (image_uri, prompt_p) in enumerate(zip(image_data_uris, remaining_texts)):
133
+ image_bytes = self.load_image(image_uri)
134
+ import array
135
+ data_array = array.array("B", image_bytes)
136
+ c_ubyte_ptr = (
137
+ ctypes.c_ubyte * len(data_array)
138
+ ).from_buffer(data_array)
139
+ with suppress_stdout_stderr(disable=self.verbose):
140
+ embed = (
141
+ self._llava_cpp.llava_image_embed_make_with_bytes(
142
+ self.clip_ctx,
143
+ llama.context_params.n_threads,
144
+ c_ubyte_ptr,
145
+ len(image_bytes),
146
+ )
147
+ )
148
+ try:
149
+ n_past = ctypes.c_int(llama.n_tokens)
150
+ n_past_p = ctypes.pointer(n_past)
151
+ with suppress_stdout_stderr(disable=self.verbose):
152
+ self._llava_cpp.llava_eval_image_embed(
153
+ llama.ctx,
154
+ embed,
155
+ llama.n_batch,
156
+ n_past_p,
157
+ )
158
+ assert llama.n_ctx() >= n_past.value
159
+ llama.n_tokens = n_past.value
160
+ finally:
161
+ with suppress_stdout_stderr(disable=self.verbose):
162
+ self._llava_cpp.llava_image_embed_free(embed)
163
+
164
+ llama.eval(llama.tokenize(prompt_p.encode("utf8"), add_bos=False))
165
+ assert llama.n_ctx() >= llama.n_tokens
166
+
167
+ prompt = llama.input_ids[: llama.n_tokens].tolist()
168
+ # from llava-1.5
169
+ return llama.create_completion(
170
+ prompt=prompt,
171
+ temperature=temperature,
172
+ top_p=top_p,
173
+ top_k=top_k,
174
+ min_p=min_p,
175
+ typical_p=typical_p,
176
+ stream=stream,
177
+ stop=stop,
178
+ max_tokens=max_tokens,
179
+ presence_penalty=presence_penalty,
180
+ frequency_penalty=frequency_penalty,
181
+ repeat_penalty=repeat_penalty,
182
+ tfs_z=tfs_z,
183
+ mirostat_mode=mirostat_mode,
184
+ mirostat_tau=mirostat_tau,
185
+ mirostat_eta=mirostat_eta,
186
+ model=model,
187
+ logits_processor=logits_processor,
188
+ grammar=grammar,
189
+ )
190
+
191
+
192
+
193
+ class LlavaLlamaCppEngine(LlamaCppEngine):
194
+ """
195
+ Still in development, expect BUGS
196
+
197
+ ERROR: could not know why
198
+ objc[61055]: Class GGMLMetalClass is implemented in both miniconda3/envs/native/lib/python3.12/site-packages/llama_cpp/libllama.dylib (0x12cb40290) and miniconda3/envs/native/lib/python3.12/site-packages/llama_cpp/libllava.dylib (0x12d9c8290). One of the two will be used. Which one is undefined.
199
+
200
+ """
201
+ @property
202
+ def image_token(self):
203
+ return IMAGE_TOKEN
204
+
205
+ def get_multimodal_tokens(self, full_prompt, image_paths=None):
206
+ num_tokens = len(self.tokenizer.encode(full_prompt))
207
+ for image_path in image_paths:
208
+ num_tokens += IMAGE_TOKEN_LENGTH * MAX_PACHES
209
+ return num_tokens
210
+
211
+ def load_model(self):
212
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
213
+ from llama_cpp import Llama
214
+ from llama_cpp.llama_chat_format import Llava15ChatHandler
215
+ model_dir = os.path.dirname(MODEL_PATH)
216
+ self.chat_handler = Llava15ChatHandler(clip_model_path=os.path.join(model_dir, "mmproj.bin"))
217
+
218
+ self.chat_handler.__call__ = types.MethodType(llava_15_chat_handler_call, self.chat_handler)
219
+
220
+ self.model_path = MODEL_PATH
221
+ self._model = Llama(
222
+ model_path=self.model_path,
223
+ n_gpu_layers=N_GPU_LAYERS, # Uncomment to use GPU acceleration
224
+ # seed=1337, # Uncomment to set a specific seed
225
+ chat_handler=self.chat_handler,
226
+ n_ctx=N_CTX, # Uncomment to increase the context window
227
+ logits_all=True, # needed to make llava work
228
+ )
229
+ self._tokenizer = self._model
230
+ self._model.encode = types.MethodType(encode_tokenize, self._model)
231
+ print(f'Load model: {self.model_path=} | {N_GPU_LAYERS=} | {N_CTX=}')
232
+
233
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
234
+ image_paths = kwargs.get("image_paths", [])
235
+
236
+ image_data_uris = [
237
+ image_to_base64_data_uri(ip)
238
+ for ip in image_paths
239
+ ]
240
+
241
+ stop_strings = list(stop_strings) if stop_strings is not None else []
242
+ stop_strings = list(set(stop_strings + ["</s>", "<|im_end|>"]))
243
+ # generator = self._model(
244
+ generator = self.chat_handler(
245
+ prompt=prompt,
246
+ image_data_uris=image_data_uris,
247
+ image_token=self.image_token,
248
+ max_tokens=max_tokens, # Generate up to 32 tokens, set to None to generate up to the end of the context window
249
+ temperature=temperature,
250
+ stop=stop_strings, # Stop generating just before the model would generate a new question
251
+ stream=True,
252
+ )
253
+ response = ""
254
+ num_tokens = len(self.tokenizer.encode(prompt))
255
+ for g in generator:
256
+ response += g['choices'][0]['text']
257
+ yield response, num_tokens
258
+
259
+ if response is not None and len(response) > 0:
260
+ num_tokens = len(self.tokenizer.encode(prompt + response))
261
+ yield response, num_tokens
262
+
263
+
264
+ """
265
+
266
+ export MODEL_PATH
267
+ BACKEND=llama_cpp
268
+ MODEL_PATH=/Users/nguyenxuanphi/Desktop/projects/cache/seallms/SeaLLMs/SeaLLM-7B-v2-gguf/seallm-v2.chatml.Q4_K_M.gguf
269
+ N_CTX=4096
270
+ python app.py
271
+
272
+
273
+ export BACKEND=llava_llama_cpp
274
+ export MODEL_PATH=/Users/nguyenxuanphi/Desktop/projects/cache/llava/llava-1.5/ggml-model-q4_k.gguf
275
+ export N_CTX=4096
276
+ export IMAGE_TOKEN="<image>"
277
+ python app.py
278
+
279
+
280
+ """
multipurpose_chatbot/engines/mlx_engine.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+ from huggingface_hub import snapshot_download
6
+ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
7
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
8
+ import time
9
+ from mlx_lm import load, generate
10
+ from mlx_lm.utils import generate_step
11
+
12
+ from .base_engine import BaseEngine
13
+
14
+ from ..configs import (
15
+ MODEL_PATH,
16
+ )
17
+
18
+ def generate_string(
19
+ model: nn.Module,
20
+ tokenizer: PreTrainedTokenizer,
21
+ prompt: str,
22
+ temp: float = 0.0,
23
+ max_tokens: int = 100,
24
+ verbose: bool = False,
25
+ formatter: Callable = None,
26
+ repetition_penalty: Optional[float] = None,
27
+ repetition_context_size: Optional[int] = None,
28
+ stop_strings: Optional[Tuple[str]] = None
29
+ ):
30
+ prompt_tokens = mx.array(tokenizer.encode(prompt))
31
+ stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
32
+ assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}'
33
+
34
+ tic = time.perf_counter()
35
+ tokens = []
36
+ skip = 0
37
+ REPLACEMENT_CHAR = "\ufffd"
38
+
39
+ for (token, prob), n in zip(
40
+ generate_step(
41
+ prompt_tokens,
42
+ model,
43
+ temp,
44
+ repetition_penalty,
45
+ repetition_context_size,
46
+ ),
47
+ range(max_tokens),
48
+ ):
49
+ if token == tokenizer.eos_token_id:
50
+ break
51
+ if n == 0:
52
+ prompt_time = time.perf_counter() - tic
53
+ tic = time.perf_counter()
54
+ tokens.append(token.item())
55
+ if stop_strings is not None:
56
+ token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
57
+ if token_string.strip().endswith(stop_strings):
58
+ break
59
+ token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
60
+ return token_string
61
+
62
+
63
+
64
+ def generate_yield_string(
65
+ model: nn.Module,
66
+ tokenizer: PreTrainedTokenizer,
67
+ prompt: str,
68
+ temp: float = 0.0,
69
+ max_tokens: int = 100,
70
+ verbose: bool = False,
71
+ formatter: Callable = None,
72
+ repetition_penalty: Optional[float] = None,
73
+ repetition_context_size: Optional[int] = None,
74
+ stop_strings: Optional[Tuple[str]] = None
75
+ ):
76
+ """
77
+ Generate text from the model.
78
+ Args:
79
+ model (nn.Module): The language model.
80
+ tokenizer (PreTrainedTokenizer): The tokenizer.
81
+ prompt (str): The string prompt.
82
+ temp (float): The temperature for sampling (default 0).
83
+ max_tokens (int): The maximum number of tokens (default 100).
84
+ verbose (bool): If ``True``, print tokens and timing information
85
+ (default ``False``).
86
+ formatter (Optional[Callable]): A function which takes a token and a
87
+ probability and displays it.
88
+ repetition_penalty (float, optional): The penalty factor for repeating tokens.
89
+ repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
90
+ """
91
+ if verbose:
92
+ print("=" * 10)
93
+ print("Prompt:", prompt)
94
+ stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
95
+ assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}'
96
+ prompt_tokens = mx.array(tokenizer.encode(prompt))
97
+ tic = time.perf_counter()
98
+ tokens = []
99
+ skip = 0
100
+ REPLACEMENT_CHAR = "\ufffd"
101
+ for (token, prob), n in zip(
102
+ generate_step(
103
+ prompt_tokens,
104
+ model,
105
+ temp,
106
+ repetition_penalty,
107
+ repetition_context_size,
108
+ ),
109
+ range(max_tokens),
110
+ ):
111
+ if token == tokenizer.eos_token_id:
112
+ break
113
+ # if n == 0:
114
+ # prompt_time = time.perf_counter() - tic
115
+ # tic = time.perf_counter()
116
+ tokens.append(token.item())
117
+ # if verbose:
118
+ # s = tokenizer.decode(tokens)
119
+ # if formatter:
120
+ # formatter(s[skip:], prob.item())
121
+ # skip = len(s)
122
+ # elif REPLACEMENT_CHAR not in s:
123
+ # print(s[skip:], end="", flush=True)
124
+ # skip = len(s)
125
+ token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
126
+ yield token_string
127
+ if stop_strings is not None and token_string.strip().endswith(stop_strings):
128
+ break
129
+
130
+ # token_count = len(tokens)
131
+ # token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
132
+
133
+ # if verbose:
134
+ # print(token_string[skip:], flush=True)
135
+ # gen_time = time.perf_counter() - tic
136
+ # print("=" * 10)
137
+ # if token_count == 0:
138
+ # print("No tokens generated for this prompt")
139
+ # return
140
+ # prompt_tps = prompt_tokens.size / prompt_time
141
+ # gen_tps = (token_count - 1) / gen_time
142
+ # print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
143
+ # print(f"Generation: {gen_tps:.3f} tokens-per-sec")
144
+
145
+ # return token_string
146
+
147
+
148
+ class MlxEngine(BaseEngine):
149
+
150
+ def __init__(self, **kwargs) -> None:
151
+ super().__init__(**kwargs)
152
+ self._model = None
153
+ self._tokenizer = None
154
+
155
+ @property
156
+ def tokenizer(self) -> PreTrainedTokenizer:
157
+ return self._tokenizer
158
+
159
+ def load_model(self, ):
160
+ model_path = MODEL_PATH
161
+ self._model, self._tokenizer = load(model_path)
162
+ self.model_path = model_path
163
+ print(f'Load MLX model from {model_path}')
164
+
165
+
166
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
167
+ num_tokens = len(self.tokenizer.encode(prompt))
168
+ response = None
169
+ for response in generate_yield_string(
170
+ self._model, self._tokenizer,
171
+ prompt, temp=temperature, max_tokens=max_tokens,
172
+ repetition_penalty=kwargs.get("repetition_penalty", None),
173
+ stop_strings=stop_strings,
174
+ ):
175
+ yield response, num_tokens
176
+ if response is not None:
177
+ full_text = prompt + response
178
+ num_tokens = len(self.tokenizer.encode(full_text))
179
+ yield response, num_tokens
180
+
181
+ def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
182
+ """
183
+ ! MLX does not support
184
+ """
185
+ responses = [
186
+ generate_string(
187
+ self._model, self._tokenizer,
188
+ s, temp=temperature, max_tokens=max_tokens,
189
+ repetition_penalty=kwargs.get("repetition_penalty", None),
190
+ stop_strings=stop_strings,
191
+ )
192
+ for s in prompts
193
+ ]
194
+ return responses
195
+
196
+
197
+
198
+
199
+
200
+
201
+
202
+
multipurpose_chatbot/engines/modeling_sealmm.py ADDED
@@ -0,0 +1,1091 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import nullcontext
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.utils.checkpoint
7
+ from torch import nn
8
+
9
+ from transformers import PreTrainedModel
10
+ from transformers.activations import ACT2FN
11
+ from transformers.cache_utils import Cache
12
+ from transformers.modeling_outputs import ModelOutput
13
+ from transformers.models.clip.configuration_clip import CLIPConfig
14
+ from transformers.utils import (
15
+ add_start_docstrings,
16
+ add_start_docstrings_to_model_forward,
17
+ logging,
18
+ replace_return_docstrings,
19
+ )
20
+ from transformers import AutoModel, AutoModelForCausalLM
21
+ from transformers.models.llava.configuration_llava import LlavaConfig
22
+
23
+ from transformers.models.llava.modeling_llava import (
24
+ LlavaCausalLMOutputWithPast,
25
+ LlavaMultiModalProjector,
26
+ LlavaPreTrainedModel,
27
+ LLAVA_START_DOCSTRING,
28
+ LLAVA_INPUTS_DOCSTRING,
29
+ LlavaForConditionalGeneration,
30
+ )
31
+
32
+ from transformers.models.blip_2.configuration_blip_2 import (
33
+ Blip2Config,
34
+ Blip2QFormerConfig,
35
+ )
36
+ import os
37
+ from transformers.models.blip_2.modeling_blip_2 import (
38
+ Blip2Config,
39
+ Blip2QFormerModel,
40
+ Blip2PreTrainedModel,
41
+ BLIP_2_INPUTS_DOCSTRING,
42
+ )
43
+
44
+ from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10
45
+
46
+ # from .configuration_sealmm import SeaLMMConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+ # _CONFIG_FOR_DOC = "LlavaConfig"
51
+ _CONFIG_FOR_DOC = "SeaLMMConfig"
52
+
53
+
54
+ class SeaLMMConfig(LlavaConfig):
55
+ def __init__(self, *args, **kwargs):
56
+ self.projector_num_layers = kwargs.get("projector_num_layers", 1)
57
+ super().__init__(*args, **kwargs)
58
+
59
+ """
60
+ Llava
61
+
62
+ vision_config.num_hidden_layers = vision_config.num_hidden_layers + config.vision_feature_layer + 1
63
+ # "num_hidden_layers": 24,
64
+
65
+ """
66
+
67
+ IMAGE_TOKEN = "<|image|>"
68
+ DEBUG = bool(int(os.environ.get("DEBUG", "0")))
69
+
70
+
71
+ def by_sample_merge_input_ids_with_image_features(
72
+ self, image_features, inputs_embeds, input_ids, attention_mask=None, position_ids=None
73
+ ):
74
+ """
75
+ input_ids: [tlen]
76
+ input_embeds: [tlen, dt]
77
+ img_embeds: [ilen, ifeat, di]
78
+
79
+ e.g:
80
+ input_ids: [
81
+ a b c d e f X g h i j k X l m
82
+ ]
83
+ img_embeds: [3, ifeat, id] # img_embeds has padding
84
+ """
85
+ num_images, num_image_patches, embed_dim = image_features.shape
86
+ sequence_length = input_ids.size(0)
87
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
88
+ assert not left_padding, f'should only use right padding'
89
+ # 1. Create a mask to know where special image tokens are
90
+ special_image_token_mask = input_ids == self.config.image_token_index
91
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
92
+ # Compute the maximum embed dimension
93
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
94
+
95
+ from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
96
+ from transformers.models.clip.modeling_clip import (
97
+ contrastive_loss,
98
+ clip_loss,
99
+ CLIPVisionModelOutput,
100
+ CLIPTextModelOutput,
101
+ CLIPOutput,
102
+ CLIPTextEmbeddings,
103
+ CLIPVisionEmbeddings,
104
+ CLIPAttention,
105
+ CLIPMLP,
106
+ CLIPEncoderLayer,
107
+ CLIPPreTrainedModel,
108
+ CLIPTextTransformer,
109
+ CLIPTextModel,
110
+ CLIPVisionTransformer,
111
+ CLIPVisionModel,
112
+ CLIPModel,
113
+ CLIPEncoder,
114
+ CLIPTextModelWithProjection,
115
+ CLIPVisionModelWithProjection,
116
+ CLIP_START_DOCSTRING,
117
+ CLIP_TEXT_INPUTS_DOCSTRING,
118
+ CLIP_VISION_INPUTS_DOCSTRING,
119
+ CLIP_INPUTS_DOCSTRING,
120
+ )
121
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
122
+
123
+
124
+
125
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
126
+ def _get_unpad_data(attention_mask):
127
+ import torch.nn.functional as F
128
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
129
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
130
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
131
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
132
+ return (
133
+ indices,
134
+ cu_seqlens,
135
+ max_seqlen_in_batch,
136
+ )
137
+
138
+ class CLIPFlashAttention2(CLIPAttention):
139
+ """
140
+ CLIP flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays
141
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
142
+ flash attention and deal with padding tokens in case the input contains any of them.
143
+ """
144
+ def __init__(self, config, is_causal=False):
145
+ super().__init__(config)
146
+ self.is_causal = is_causal
147
+
148
+ def forward(
149
+ self,
150
+ hidden_states: torch.Tensor,
151
+ attention_mask: Optional[torch.Tensor] = None,
152
+ causal_attention_mask: Optional[torch.Tensor] = None,
153
+ output_attentions: Optional[bool] = False,
154
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
155
+ """Input shape: Batch x Time x Channel"""
156
+ if output_attentions:
157
+ raise ValueError("CLIPFlashAttention2 does not support output_attentions")
158
+
159
+ if self.is_causal and causal_attention_mask is None:
160
+ raise ValueError("CLIPFlashAttention2 has causal=True but no causal_attention_mask provided")
161
+
162
+ bsz, tgt_len, embed_dim = hidden_states.size()
163
+
164
+ # [batch_size, tgt_len, embed_dim]
165
+ query_states = self.q_proj(hidden_states)
166
+ key_states = self.k_proj(hidden_states)
167
+ value_states = self.v_proj(hidden_states)
168
+
169
+ # [batch_size, tgt_len, embed_dim] -> [batch_size, tgt_len, num_heads, head_dim]
170
+ query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
171
+ key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
172
+ value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
173
+
174
+ attn_output = self._flash_attention_forward(
175
+ query_states=query_states,
176
+ key_states=key_states,
177
+ value_states=value_states,
178
+ attention_mask=attention_mask,
179
+ query_length=tgt_len,
180
+ dropout=self.dropout,
181
+ softmax_scale=self.scale,
182
+ )
183
+ # [batch_size, tgt_len, num_heads, head_dim] -> [batch_size, tgt_len, embed_dim]
184
+ attn_output = attn_output.view(bsz, tgt_len, embed_dim)
185
+ attn_output = self.out_proj(attn_output)
186
+
187
+ return attn_output, None
188
+
189
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
190
+ def _flash_attention_forward(
191
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
192
+ ) -> torch.Tensor:
193
+ """
194
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
195
+ first unpad the input, then computes the attention scores and pad the final attention scores.
196
+
197
+ Args:
198
+ query_states (`torch.Tensor`):
199
+ Input query states to be passed to Flash Attention API
200
+ key_states (`torch.Tensor`):
201
+ Input key states to be passed to Flash Attention API
202
+ value_states (`torch.Tensor`):
203
+ Input value states to be passed to Flash Attention API
204
+ attention_mask (`torch.Tensor`):
205
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
206
+ position of padding tokens and 1 for the position of non-padding tokens.
207
+ dropout (`int`, *optional*):
208
+ Attention dropout
209
+ softmax_scale (`float`, *optional*):
210
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
211
+ """
212
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
213
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
214
+ # Contains at least one padding token in the sequence
215
+ if attention_mask is not None:
216
+ batch_size = query_states.shape[0]
217
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
218
+ query_states, key_states, value_states, attention_mask, query_length
219
+ )
220
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
221
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
222
+
223
+ attn_output_unpad = flash_attn_varlen_func(
224
+ query_states,
225
+ key_states,
226
+ value_states,
227
+ cu_seqlens_q=cu_seqlens_q,
228
+ cu_seqlens_k=cu_seqlens_k,
229
+ max_seqlen_q=max_seqlen_in_batch_q,
230
+ max_seqlen_k=max_seqlen_in_batch_k,
231
+ dropout_p=dropout,
232
+ softmax_scale=softmax_scale,
233
+ causal=self.is_causal,
234
+ )
235
+
236
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
237
+ else:
238
+ attn_output = flash_attn_func(
239
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
240
+ )
241
+
242
+ return attn_output
243
+
244
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
245
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
246
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
247
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
248
+
249
+ key_layer = index_first_axis(
250
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
251
+ )
252
+ value_layer = index_first_axis(
253
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
254
+ )
255
+ if query_length == kv_seq_len:
256
+ query_layer = index_first_axis(
257
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
258
+ )
259
+ cu_seqlens_q = cu_seqlens_k
260
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
261
+ indices_q = indices_k
262
+ elif query_length == 1:
263
+ max_seqlen_in_batch_q = 1
264
+ # There is a memcpy here, that is very bad.
265
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
266
+ indices_q = cu_seqlens_q[:-1]
267
+ query_layer = query_layer.squeeze(1)
268
+ else:
269
+ # The :q_len slice assumes right padding.
270
+ attention_mask = attention_mask[:, :query_length]
271
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
272
+
273
+ return (
274
+ query_layer,
275
+ key_layer,
276
+ value_layer,
277
+ indices_q,
278
+ (cu_seqlens_q, cu_seqlens_k),
279
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
280
+ )
281
+
282
+
283
+ class SeaLMMCLIPEncoderLayer(CLIPEncoderLayer):
284
+ def __init__(self, config: CLIPConfig):
285
+ super(CLIPEncoderLayer, self).__init__()
286
+ self.embed_dim = config.hidden_size
287
+ # self.self_attn = LlavaCLIPFlashAttention(config)
288
+ if is_flash_attn_greater_or_equal_2_10():
289
+ self.self_attn = CLIPFlashAttention2(config)
290
+ else:
291
+ self.self_attn = CLIPAttention(config)
292
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
293
+ self.mlp = CLIPMLP(config)
294
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
295
+
296
+
297
+ class SeaLMMCLIPEncoder(CLIPEncoder):
298
+ """
299
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
300
+ [`CLIPEncoderLayer`].
301
+
302
+ Args:
303
+ config: CLIPConfig
304
+ """
305
+
306
+ def __init__(self, config: CLIPConfig):
307
+ super(CLIPEncoder, self).__init__()
308
+ self.config = config
309
+ self.layers = nn.ModuleList([SeaLMMCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
310
+ self.gradient_checkpointing = False
311
+
312
+ def forward(
313
+ self,
314
+ inputs_embeds,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ causal_attention_mask: Optional[torch.Tensor] = None,
317
+ output_attentions: Optional[bool] = None,
318
+ output_hidden_states: Optional[bool] = None,
319
+ return_dict: Optional[bool] = None,
320
+ ) -> Union[Tuple, BaseModelOutput]:
321
+ r"""
322
+ Args:
323
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
324
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
325
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
326
+ than the model's internal embedding lookup matrix.
327
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
328
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
329
+
330
+ - 1 for tokens that are **not masked**,
331
+ - 0 for tokens that are **masked**.
332
+
333
+ [What are attention masks?](../glossary#attention-mask)
334
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
335
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
336
+
337
+ - 1 for tokens that are **not masked**,
338
+ - 0 for tokens that are **masked**.
339
+
340
+ [What are attention masks?](../glossary#attention-mask)
341
+ output_attentions (`bool`, *optional*):
342
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
343
+ returned tensors for more detail.
344
+ output_hidden_states (`bool`, *optional*):
345
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
346
+ for more detail.
347
+ return_dict (`bool`, *optional*):
348
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
349
+ """
350
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
351
+ output_hidden_states = (
352
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
353
+ )
354
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
355
+ output_hidden_states = False
356
+ output_attentions = False
357
+ # return_dict = False
358
+
359
+ encoder_states = () if output_hidden_states else None
360
+ all_attentions = () if output_attentions else None
361
+
362
+ hidden_states = inputs_embeds
363
+ for idx, encoder_layer in enumerate(self.layers):
364
+ if output_hidden_states:
365
+ encoder_states = encoder_states + (hidden_states,)
366
+ # if self.gradient_checkpointing and self.training:
367
+ # layer_outputs = self._gradient_checkpointing_func(
368
+ # encoder_layer.__call__,
369
+ # hidden_states,
370
+ # attention_mask,
371
+ # causal_attention_mask,
372
+ # output_attentions,
373
+ # )
374
+ # else:
375
+ # ! enforce no checkpointing here
376
+ layer_outputs = encoder_layer(
377
+ hidden_states,
378
+ attention_mask,
379
+ causal_attention_mask,
380
+ output_attentions=output_attentions,
381
+ )
382
+
383
+ hidden_states = layer_outputs[0]
384
+
385
+ if output_attentions:
386
+ all_attentions = all_attentions + (layer_outputs[1],)
387
+
388
+ if output_hidden_states:
389
+ encoder_states = encoder_states + (hidden_states,)
390
+
391
+ if not return_dict:
392
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
393
+ return BaseModelOutput(
394
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
395
+ )
396
+
397
+
398
+ class SeaLMMVisionTransformer(nn.Module):
399
+ def __init__(self, config: CLIPVisionConfig):
400
+ super().__init__()
401
+ self.config = config
402
+ embed_dim = config.hidden_size
403
+
404
+ self.embeddings = CLIPVisionEmbeddings(config)
405
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
406
+ # self.encoder = CLIPEncoder(config)
407
+ self.encoder = SeaLMMCLIPEncoder(config)
408
+ # self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
409
+
410
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
411
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
412
+ def forward(
413
+ self,
414
+ pixel_values: Optional[torch.FloatTensor] = None,
415
+ output_attentions: Optional[bool] = None,
416
+ output_hidden_states: Optional[bool] = None,
417
+ return_dict: Optional[bool] = None,
418
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
419
+ r"""
420
+ Returns:
421
+
422
+ """
423
+ assert output_attentions is None
424
+ assert output_hidden_states is None
425
+ # assert return_dict is None
426
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
427
+ output_hidden_states = (
428
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
429
+ )
430
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
431
+
432
+ if pixel_values is None:
433
+ raise ValueError("You have to specify pixel_values")
434
+
435
+ hidden_states = self.embeddings(pixel_values)
436
+ hidden_states = self.pre_layrnorm(hidden_states)
437
+
438
+ encoder_outputs = self.encoder(
439
+ inputs_embeds=hidden_states,
440
+ output_attentions=output_attentions,
441
+ output_hidden_states=output_hidden_states,
442
+ return_dict=return_dict,
443
+ )
444
+
445
+ last_hidden_state = encoder_outputs[0]
446
+
447
+ if not return_dict:
448
+ raise ValueError(f'Not support return_dict')
449
+
450
+ return BaseModelOutputWithPooling(
451
+ last_hidden_state=last_hidden_state,
452
+ # pooler_output=pooled_output,
453
+ pooler_output=None,
454
+ hidden_states=encoder_outputs.hidden_states,
455
+ attentions=encoder_outputs.attentions,
456
+ )
457
+
458
+
459
+ @add_start_docstrings(
460
+ """The vision model from CLIP without any head or projection on top.""",
461
+ CLIP_START_DOCSTRING,
462
+ )
463
+ class SeaLMMCLIPVisionModel(CLIPPreTrainedModel):
464
+ config_class = CLIPVisionConfig
465
+ main_input_name = "pixel_values"
466
+ _no_split_modules = ["SeaLMMCLIPEncoderLayer"]
467
+
468
+ def __init__(self, config: CLIPVisionConfig):
469
+ super().__init__(config)
470
+ self.vision_model = SeaLMMVisionTransformer(config)
471
+ # Initialize weights and apply final processing
472
+ self.post_init()
473
+
474
+ def get_input_embeddings(self) -> nn.Module:
475
+ return self.vision_model.embeddings.patch_embedding
476
+
477
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
478
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
479
+ def forward(
480
+ self,
481
+ pixel_values: Optional[torch.FloatTensor] = None,
482
+ output_attentions: Optional[bool] = None,
483
+ output_hidden_states: Optional[bool] = None,
484
+ return_dict: Optional[bool] = None,
485
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
486
+ r"""
487
+ Returns:
488
+
489
+ Examples:
490
+
491
+ ```python
492
+ >>> from PIL import Image
493
+ >>> import requests
494
+ >>> from transformers import AutoProcessor, CLIPVisionModel
495
+
496
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
497
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
498
+
499
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
500
+ >>> image = Image.open(requests.get(url, stream=True).raw)
501
+
502
+ >>> inputs = processor(images=image, return_tensors="pt")
503
+
504
+ >>> outputs = model(**inputs)
505
+ >>> last_hidden_state = outputs.last_hidden_state
506
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
507
+ ```"""
508
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
509
+
510
+ return self.vision_model(
511
+ pixel_values=pixel_values,
512
+ output_attentions=output_attentions,
513
+ output_hidden_states=output_hidden_states,
514
+ return_dict=return_dict,
515
+ )
516
+
517
+
518
+ class SeaLMMMultiModalProjector(SeaLMMCLIPEncoder):
519
+ def __init__(self, config: SeaLMMConfig):
520
+ super(CLIPEncoder, self).__init__()
521
+ self.config = config
522
+ self.projector_num_layers = getattr(config, "projector_num_layers", 2)
523
+ self.vision_config = config.vision_config
524
+ self.num_vision_feature_layer = int(0 - config.vision_feature_layer) - 1
525
+
526
+ assert self.num_vision_feature_layer > 0
527
+
528
+ self.layers = nn.ModuleList([
529
+ # LlavaCLIPFasterEncoderLayer(self.vision_config)
530
+ SeaLMMCLIPEncoderLayer(self.vision_config)
531
+ for _ in range(self.projector_num_layers)]
532
+ )
533
+
534
+ projector_layernorm_eps = getattr(config, "projector_layernorm_eps", 1e-05)
535
+ self.projector_layernorm = nn.LayerNorm(
536
+ # len(config.vision_feature_layers) * config.vision_config.hidden_size, eps=projector_layernorm_eps
537
+ config.vision_config.hidden_size, eps=projector_layernorm_eps
538
+ )
539
+
540
+ self.linear_1 = nn.Linear(
541
+ # len(config.vision_feature_layers) * config.vision_config.hidden_size,
542
+ config.vision_config.hidden_size,
543
+ config.text_config.hidden_size,
544
+ bias=True,
545
+ )
546
+ # self.act = ACT2FN[config.projector_hidden_act]
547
+ # self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
548
+
549
+ self.gradient_checkpointing = False
550
+
551
+ def forward(self, hidden_states, attention_mask=None, causal_attention_mask=None):
552
+ """
553
+ hidden_states must not be striped
554
+ """
555
+ output_attentions = False
556
+
557
+ for idx, encoder_layer in enumerate(self.layers):
558
+ # if output_hidden_states:
559
+ # encoder_states = encoder_states + (hidden_states,)
560
+ # if self.gradient_checkpointing and self.training:
561
+ # layer_outputs = self._gradient_checkpointing_func(
562
+ # encoder_layer.__call__,
563
+ # hidden_states,
564
+ # attention_mask,
565
+ # causal_attention_mask,
566
+ # output_attentions,
567
+ # )
568
+ # else:
569
+ # ! turn off checkpointing
570
+ layer_outputs = encoder_layer(
571
+ hidden_states,
572
+ attention_mask,
573
+ causal_attention_mask,
574
+ output_attentions=output_attentions,
575
+ )
576
+
577
+ hidden_states = layer_outputs[0]
578
+
579
+ hidden_states = hidden_states[:, 1:]
580
+
581
+ hidden_states = self.projector_layernorm(hidden_states)
582
+ hidden_states = self.linear_1(hidden_states)
583
+ # hidden_states = self.act(hidden_states)
584
+ # hidden_states = self.linear_2(hidden_states)
585
+ return hidden_states
586
+
587
+
588
+
589
+ @add_start_docstrings(
590
+ """The CLip- LLAVA model which consists of a vision backbone and a language model.""",
591
+ LLAVA_START_DOCSTRING,
592
+ )
593
+ class SeaLMMForCausalLM(LlavaPreTrainedModel):
594
+ def __init__(self, config: SeaLMMConfig, vision_tower=None, language_model=None):
595
+ super().__init__(config)
596
+ # self.vision_tower = AutoModel.from_config(config.vision_config)
597
+ # self.vision_tower = vision_tower or LlavaCLIPVisionModel(config=config.vision_config)
598
+ self.vision_tower = vision_tower or SeaLMMCLIPVisionModel(config=config.vision_config)
599
+ self.multi_modal_projector = SeaLMMMultiModalProjector(config)
600
+ self.vocab_size = config.vocab_size
601
+ self.language_model = language_model or AutoModelForCausalLM.from_config(
602
+ config.text_config, attn_implementation=config._attn_implementation
603
+ )
604
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
605
+ self.post_init()
606
+
607
+ self.freeze_vision_tower = True
608
+
609
+ def unfreeze_vision_tower(self):
610
+ logger.info(f'UNFREEZE {self.freeze_vision_tower=}')
611
+ self.freeze_vision_tower = False
612
+
613
+ def freeze_vision_tower(self):
614
+ logger.info(f'FREEZE {self.freeze_vision_tower=}')
615
+ self.freeze_vision_tower = True
616
+
617
+ @classmethod
618
+ def create_model_config_from_components(
619
+ cls,
620
+ lm_config=None,
621
+ vision_config=None,
622
+ tokenizer=None,
623
+ vision_feature_layer=None,
624
+ projector_num_layers=1,
625
+ **kwargs,
626
+ ) -> SeaLMMConfig:
627
+ # self.projector_num_layers = kwargs.get("projector_num_layers", 1)
628
+ config = SeaLMMConfig(vision_config, lm_config, projector_num_layers=projector_num_layers, **kwargs)
629
+ config.vision_feature_layer = config.vision_feature_layer if vision_feature_layer is None else vision_feature_layer
630
+
631
+ if config.vision_feature_layer < 0:
632
+ config.vision_config.num_hidden_layers = config.vision_config.num_hidden_layers + config.vision_feature_layer + 1
633
+ else:
634
+ config.vision_config.num_hidden_layers = config.vision_feature_layer + 1
635
+
636
+ if IMAGE_TOKEN not in tokenizer.get_vocab():
637
+ tokenizer.add_special_tokens({"cls_token": IMAGE_TOKEN})
638
+
639
+ config.image_token_index = tokenizer.cls_token_id
640
+ config.vocab_size = config.text_config.vocab_size
641
+ config.architectures = ["SeaLMMForCausalLM"]
642
+ return config
643
+
644
+ def get_input_embeddings(self):
645
+ return self.language_model.get_input_embeddings()
646
+
647
+ def set_input_embeddings(self, value):
648
+ self.language_model.set_input_embeddings(value)
649
+
650
+ def get_output_embeddings(self):
651
+ return self.language_model.get_output_embeddings()
652
+
653
+ def set_output_embeddings(self, new_embeddings):
654
+ self.language_model.set_output_embeddings(new_embeddings)
655
+
656
+ def set_decoder(self, decoder):
657
+ self.language_model.set_decoder(decoder)
658
+
659
+ def get_decoder(self):
660
+ return self.language_model.get_decoder()
661
+
662
+ def tie_weights(self):
663
+ return self.language_model.tie_weights()
664
+
665
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
666
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
667
+ # update vocab size
668
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
669
+ self.config.vocab_size = model_embeds.num_embeddings
670
+ self.vocab_size = model_embeds.num_embeddings
671
+ return model_embeds
672
+
673
+ # @torch.no_grad
674
+ def _merge_input_ids_with_image_features(
675
+ self, image_features, inputs_embeds, input_ids, attention_mask, position_ids, labels=None
676
+ ):
677
+ """
678
+ input_ids: [b, tlen]
679
+ input_embeds: [b, tlen, dt]
680
+ image_features: [b, ilen, ifeat, di]
681
+ labels: None or [b, tlen] --> must extend labels to input_ids,
682
+
683
+ # in input_ids, there may be image_token_index, number of image_token_index <= ilen
684
+ input_ids: [
685
+ a b c d e f X g h i j k X l m
686
+ o p q r X s t u v _ _ _ _ _ _
687
+ ]
688
+ input_ids should be: [
689
+ a b c d e f X X X X X g h i j k X X X X X l m
690
+ o p q r X X X X X s t u v _ _ _ _ _ _ _ _ _ _
691
+ ]
692
+ labels should be: [
693
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ _ _ l m
694
+ o p q r _ _ _ _ _ s t u v _ _ _ _ _ _ _ _ _ _
695
+ ]
696
+ # mask replace image onto it
697
+
698
+ # Use torch.vmap for simplicy
699
+ def sample_merge():
700
+ input_ids: [tlen]
701
+ input_embeds: [tlen, dt]
702
+ img_embeds: [ilen, ifeat, di]
703
+ e.g:
704
+ input_ids: [
705
+ a b c d e f X g h i j k X l m
706
+ ]
707
+ img_embeds: [3, ifeat, id] # img_embeds has padding
708
+
709
+
710
+ """
711
+ with torch.no_grad():
712
+ num_images, num_image_patches, embed_dim = image_features.shape
713
+ batch_size, sequence_length = input_ids.shape
714
+ # left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
715
+ left_padding = torch.any(attention_mask[:, 0] == 0)
716
+ # assert not left_padding or batch_size == 1
717
+ # 1. Create a mask to know where special image tokens are
718
+ special_image_token_mask = input_ids == self.config.image_token_index
719
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
720
+ # Reserve for padding of num_images
721
+ total_num_special_image_tokens = torch.sum(special_image_token_mask)
722
+ assert total_num_special_image_tokens == num_images, f'{total_num_special_image_tokens=} != {num_images=} | {image_features.shape} {input_ids}'
723
+ # Compute the maximum embed dimension
724
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
725
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
726
+
727
+ # 2. Compute the positions where text should be written
728
+ # Calculate new positions for text tokens in merged image-text sequence.
729
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
730
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
731
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
732
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
733
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
734
+ if left_padding:
735
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
736
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
737
+
738
+ # 3. Create the full embedding, already padded to the maximum position
739
+ final_embedding = torch.zeros(
740
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
741
+ )
742
+ final_attention_mask = torch.zeros(
743
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
744
+ )
745
+ final_labels = None
746
+ if labels is not None:
747
+ final_labels = torch.full_like(final_attention_mask, self.config.ignore_index).to(torch.long)
748
+
749
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
750
+ # set the corresponding tensors into their correct target device.
751
+ target_device = inputs_embeds.device
752
+ batch_indices, non_image_indices, text_to_overwrite = (
753
+ batch_indices.to(target_device),
754
+ non_image_indices.to(target_device),
755
+ text_to_overwrite.to(target_device),
756
+ )
757
+ attention_mask = attention_mask.to(target_device)
758
+
759
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
760
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
761
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
762
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
763
+ if labels is not None:
764
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
765
+
766
+ # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
767
+ image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
768
+ # image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
769
+ if left_padding:
770
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
771
+ else:
772
+ val = torch.arange(max_embed_dim).unsqueeze(0).to(target_device).expand(batch_size, max_embed_dim) < new_token_positions[:, -1:].to(target_device)
773
+ image_to_overwrite &= val
774
+
775
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
776
+ raise ValueError(
777
+ f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
778
+ f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
779
+ )
780
+
781
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
782
+ final_attention_mask |= image_to_overwrite
783
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
784
+
785
+ if not left_padding:
786
+ # Making sure its the same
787
+ seq_lens = final_attention_mask.sum(-1)
788
+ for i, (mask, seq_len) in enumerate(zip(final_attention_mask, seq_lens)):
789
+ # seq_len = mask.sum(-1)
790
+ assert torch.all(mask[:seq_len] == 1), f'final 1 mask[{i}]: {seq_len} {final_attention_mask.tolist()=}'
791
+ assert torch.all(mask[seq_len:] == 0), f'final 0 mask[{i}]: {seq_len} {final_attention_mask.tolist()=}'
792
+
793
+
794
+ # if DEBUG:
795
+ # print(f'final_attention_mask=\n{final_attention_mask.tolist()}')
796
+ # print(f'text_to_overwrite=\n{text_to_overwrite.int().tolist()}')
797
+ # print(f'image_to_overwrite=\n{image_to_overwrite.int().tolist()}')
798
+ # print(f'position_ids=\n{position_ids.tolist()}')
799
+ # print(f'labels=\n{labels.tolist()}')
800
+ # print(f'final_labels=\n{final_labels.tolist()}')
801
+
802
+ return final_embedding, final_attention_mask, position_ids, final_labels
803
+
804
+ def extract_image_features(self, pixel_values, vision_feature_select_strategy=None):
805
+ vision_feature_select_strategy = (
806
+ vision_feature_select_strategy
807
+ if vision_feature_select_strategy is not None
808
+ else self.config.vision_feature_select_strategy
809
+ )
810
+ with (torch.no_grad() if self.freeze_vision_tower else nullcontext()):
811
+ image_outputs = self.vision_tower(pixel_values)
812
+ hiddent_states = image_outputs.last_hidden_state
813
+ image_features = self.multi_modal_projector(hiddent_states)
814
+ return image_features
815
+
816
+ @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
817
+ @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
818
+ def forward(
819
+ self,
820
+ input_ids: torch.LongTensor = None,
821
+ pixel_values: torch.FloatTensor = None,
822
+ attention_mask: Optional[torch.Tensor] = None,
823
+ position_ids: Optional[torch.LongTensor] = None,
824
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
825
+ inputs_embeds: Optional[torch.FloatTensor] = None,
826
+ vision_feature_layer: Optional[int] = None,
827
+ vision_feature_select_strategy: Optional[str] = None,
828
+ labels: Optional[torch.LongTensor] = None,
829
+ use_cache: Optional[bool] = None,
830
+ output_attentions: Optional[bool] = None,
831
+ output_hidden_states: Optional[bool] = None,
832
+ return_dict: Optional[bool] = None,
833
+ ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
834
+ r"""
835
+ Args:
836
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
837
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
838
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
839
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
840
+
841
+ Returns:
842
+
843
+ Example:
844
+
845
+ ```python
846
+ >>> from PIL import Image
847
+ >>> import requests
848
+ >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
849
+
850
+ >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
851
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
852
+
853
+ >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
854
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
855
+ >>> image = Image.open(requests.get(url, stream=True).raw)
856
+
857
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
858
+
859
+ >>> # Generate
860
+ >>> generate_ids = model.generate(**inputs, max_length=30)
861
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
862
+ "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
863
+ ```"""
864
+
865
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
866
+ output_hidden_states = (
867
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
868
+ )
869
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
870
+ vision_feature_layer = (
871
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
872
+ )
873
+ vision_feature_select_strategy = (
874
+ vision_feature_select_strategy
875
+ if vision_feature_select_strategy is not None
876
+ else self.config.vision_feature_select_strategy
877
+ )
878
+
879
+ if inputs_embeds is None:
880
+ # 1. Extra the input embeddings
881
+ for_inputs_embeds_ids = input_ids.clone()
882
+ for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
883
+ # inputs_embeds = self.get_input_embeddings()(input_ids)
884
+ inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
885
+
886
+ # 2. Merge text and images
887
+ if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
888
+ num_images = pixel_values.size(0)
889
+ batch_size, sequence_length = input_ids.shape
890
+ special_image_token_mask = input_ids == self.config.image_token_index
891
+ # Reserve for padding of num_images
892
+ total_num_special_image_tokens = torch.sum(special_image_token_mask)
893
+ assert num_images == total_num_special_image_tokens, (
894
+ f'{num_images} < {total_num_special_image_tokens} | {special_image_token_mask}'
895
+ )
896
+ # pixel_values = pixel_values[:total_num_special_image_tokens]
897
+
898
+ # image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
899
+ # with (torch.no_grad() if self.freeze_vision_tower else nullcontext()):
900
+ # image_outputs = self.vision_tower(pixel_values)
901
+ # # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
902
+ # # selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
903
+ # selected_image_feature = image_outputs.last_hidden_state
904
+
905
+ # if vision_feature_select_strategy == "default":
906
+ # selected_image_feature = selected_image_feature[:, 1:]
907
+ # elif vision_feature_select_strategy == "full":
908
+ # selected_image_feature = selected_image_feature
909
+ # else:
910
+ # raise ValueError(
911
+ # f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
912
+ # )
913
+
914
+ # image_features = self.multi_modal_projector(selected_image_feature)
915
+ # print(f"{pixel_values.size()=}")
916
+ # ! extract_image_features will handle all image features extraction
917
+ image_features = self.extract_image_features(pixel_values)
918
+ # if DEBUG:
919
+ # image_features = image_features[:, :3]
920
+
921
+ inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
922
+ image_features, inputs_embeds, input_ids, attention_mask, position_ids,
923
+ labels=labels
924
+ )
925
+ # if labels is None:
926
+ # # ! this is wrong!
927
+ # labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
928
+ # print(inputs_embeds.size())
929
+
930
+ elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
931
+ # there is no images
932
+ pass
933
+ else:
934
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
935
+ # generation with cache
936
+ # ! (phi) why do we need to do this?
937
+ # if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
938
+ # # ! it can possible the bug because if mistral, from the first layer_key like this
939
+ # # ! MUST UNDERSTAND and fix error
940
+ # # Retrieve the first layer to inspect the logits and mask out the hidden states
941
+ # # that are set to 0
942
+ # first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
943
+ # batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
944
+ # # Get the target length
945
+ # target_seqlen = first_layer_past_key_value.shape[-1] + 1
946
+
947
+ # extended_attention_mask = torch.ones(
948
+ # (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
949
+ # dtype=attention_mask.dtype,
950
+ # device=attention_mask.device,
951
+ # )
952
+ # # print(f'{extended_attention_mask.shape} | {batch_index=} | {non_attended_tokens=}')
953
+
954
+ # # Zero-out the places where we don't need to attend
955
+ # extended_attention_mask[batch_index, non_attended_tokens] = 0
956
+
957
+ # attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
958
+ # position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
959
+
960
+ # ! fix: https://github.com/huggingface/transformers/blob/c90268de7560c3fef21a927e0bfcf2b611a8711e/src/transformers/models/llava/modeling_llava.py
961
+ # https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
962
+ if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
963
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
964
+ # that are set to 0
965
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
966
+
967
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
968
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
969
+
970
+ # Get the target length
971
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
972
+
973
+ extended_attention_mask = torch.ones(
974
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
975
+ dtype=attention_mask.dtype,
976
+ device=attention_mask.device,
977
+ )
978
+
979
+ # Filter out only the tokens that can be un-attended, this can happen
980
+ # in the case one uses Llava + Fused modules where the cache on the
981
+ # first iteration is already big enough, or if one passes custom cache
982
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
983
+ new_batch_index = batch_index[valid_indices]
984
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
985
+
986
+ # Zero-out the places where we don't need to attend
987
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
988
+
989
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
990
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
991
+
992
+
993
+ outputs = self.language_model(
994
+ attention_mask=attention_mask,
995
+ position_ids=position_ids,
996
+ past_key_values=past_key_values,
997
+ inputs_embeds=inputs_embeds,
998
+ use_cache=use_cache,
999
+ output_attentions=output_attentions,
1000
+ output_hidden_states=output_hidden_states,
1001
+ return_dict=return_dict,
1002
+ )
1003
+
1004
+ logits = outputs[0]
1005
+
1006
+ loss = None
1007
+ if labels is not None:
1008
+ # Shift so that tokens < n predict n
1009
+ if attention_mask is not None:
1010
+ shift_attention_mask = attention_mask[..., 1:]
1011
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
1012
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
1013
+ else:
1014
+ shift_logits = logits[..., :-1, :].contiguous()
1015
+ shift_labels = labels[..., 1:].contiguous()
1016
+ # Flatten the tokens
1017
+ loss_fct = nn.CrossEntropyLoss()
1018
+ loss = loss_fct(
1019
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
1020
+ )
1021
+
1022
+ if not return_dict:
1023
+ output = (logits,) + outputs[1:]
1024
+ return (loss,) + output if loss is not None else output
1025
+
1026
+ return LlavaCausalLMOutputWithPast(
1027
+ loss=loss,
1028
+ logits=logits,
1029
+ past_key_values=outputs.past_key_values,
1030
+ hidden_states=outputs.hidden_states,
1031
+ attentions=outputs.attentions,
1032
+ )
1033
+
1034
+ def prepare_inputs_for_generation(
1035
+ self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
1036
+ ):
1037
+ if past_key_values is not None:
1038
+ if isinstance(past_key_values, Cache):
1039
+ cache_length = past_key_values.get_seq_length()
1040
+ past_length = past_key_values.seen_tokens
1041
+ else:
1042
+ cache_length = past_length = past_key_values[0][0].shape[2]
1043
+
1044
+ # Keep only the unprocessed tokens:
1045
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1046
+ # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1047
+ # input)
1048
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1049
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1050
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1051
+ # input_ids based on the past_length.
1052
+ elif past_length < input_ids.shape[1]:
1053
+ input_ids = input_ids[:, past_length:]
1054
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1055
+ elif self.config.image_token_index in input_ids:
1056
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
1057
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
1058
+ # older attention values, as their corresponding values are not part of the input.
1059
+ if cache_length < past_length and attention_mask is not None:
1060
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
1061
+
1062
+ position_ids = kwargs.get("position_ids", None)
1063
+ if attention_mask is not None and position_ids is None:
1064
+ # create position_ids on the fly for batch generation
1065
+ position_ids = attention_mask.long().cumsum(-1) - 1
1066
+ position_ids.masked_fill_(attention_mask == 0, 1)
1067
+ if past_key_values:
1068
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1069
+
1070
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1071
+ if inputs_embeds is not None and past_key_values is None:
1072
+ model_inputs = {"inputs_embeds": inputs_embeds}
1073
+ else:
1074
+ model_inputs = {"input_ids": input_ids}
1075
+
1076
+ model_inputs.update(
1077
+ {
1078
+ "position_ids": position_ids,
1079
+ "past_key_values": past_key_values,
1080
+ "use_cache": kwargs.get("use_cache"),
1081
+ "attention_mask": attention_mask,
1082
+ "pixel_values": pixel_values,
1083
+ }
1084
+ )
1085
+ return model_inputs
1086
+
1087
+ def _reorder_cache(self, *args, **kwargs):
1088
+ return self.language_model._reorder_cache(*args, **kwargs)
1089
+
1090
+
1091
+
multipurpose_chatbot/engines/sealmmm_engine.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from transformers_stream_generator import init_stream_support
2
+ # init_stream_support()
3
+
4
+ import os
5
+ import numpy as np
6
+ import argparse
7
+ import torch
8
+ import gradio as gr
9
+ from typing import Any, Iterator
10
+ from typing import Iterator, List, Optional, Tuple
11
+ import filelock
12
+ import glob
13
+ import json
14
+ import time
15
+ from gradio.routes import Request
16
+ from gradio.utils import SyncToAsyncIterator, async_iteration
17
+ from gradio.helpers import special_args
18
+ import anyio
19
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
20
+
21
+ from gradio_client.documentation import document, set_documentation_group
22
+
23
+ from typing import List, Optional, Union, Dict, Tuple
24
+ from tqdm.auto import tqdm
25
+ from huggingface_hub import snapshot_download
26
+
27
+ from gradio.components import Button
28
+ from gradio.events import Dependency, EventListenerMethod
29
+ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
30
+ import types
31
+ import sys
32
+ from .base_engine import BaseEngine
33
+ from .transformers_engine import TransformersEngine, NewGenerationMixin
34
+
35
+ from ..configs import (
36
+ STREAM_CHECK_MULTIPLE,
37
+ STREAM_YIELD_MULTIPLE,
38
+ )
39
+
40
+ CODE_PATH = os.environ.get("CODE_PATH", "")
41
+ MODEL_PATH = os.environ.get("MODEL_PATH", "")
42
+
43
+ IMAGE_TOKEN = "[IMAGE]<|image|>[/IMAGE]"
44
+
45
+ IMAGE_LENGTH = 576
46
+ MAX_PACHES = 1
47
+
48
+
49
+ BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
50
+ BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
51
+ LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
52
+ KEYWORDS = os.environ.get("KEYWORDS", "").strip()
53
+ KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
54
+ KEYWORDS = [x.lower() for x in KEYWORDS]
55
+
56
+ LANG_BLOCK_MESSAGE = """Unsupported language."""
57
+
58
+ KEYWORD_BLOCK_MESSAGE = "Invalid request."
59
+
60
+
61
+ def _detect_lang(text):
62
+ # Disable language that may have safety risk
63
+ from langdetect import detect as detect_lang
64
+ dlang = None
65
+ try:
66
+ dlang = detect_lang(text)
67
+ except Exception as e:
68
+ if "No features in text." in str(e):
69
+ return "en"
70
+ else:
71
+ return "zh"
72
+ return dlang
73
+
74
+
75
+ def block_lang(
76
+ message: str,
77
+ history: List[Tuple[str, str]] = None,
78
+ ) -> str:
79
+ # relieve history base block
80
+ if len(BLOCK_LANGS) == 0:
81
+ return False
82
+
83
+ if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
84
+ return True
85
+ else:
86
+ _lang = _detect_lang(message)
87
+ if _lang in BLOCK_LANGS:
88
+ # print(f'Detect blocked {_lang}: {message}')
89
+ return True
90
+ else:
91
+ return False
92
+
93
+ def safety_check(text, history=None, ) -> Optional[str]:
94
+ """
95
+ Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
96
+ This provides an additional security measure to enhance safety and compliance with local regulations.
97
+ """
98
+ if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
99
+ return KEYWORD_BLOCK_MESSAGE
100
+
101
+ if len(BLOCK_LANGS) > 0:
102
+ if block_lang(text, history):
103
+ return LANG_BLOCK_MESSAGE
104
+
105
+ return None
106
+
107
+
108
+ def safety_check_conversation_string(text, delimiter=None) -> Optional[str]:
109
+ if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
110
+ return KEYWORD_BLOCK_MESSAGE
111
+ if len(BLOCK_LANGS) > 0:
112
+ import re
113
+ delimiter = delimiter or (r"</s><\|im_start\|>user\n", r"</s><\|im_start\|>assistant\n", r"<\|im_start\|>system\n")
114
+ turns = re.split(r"|".join(delimiter), text)
115
+ turns = [t for t in turns if t.strip() != '']
116
+ for t in turns:
117
+ if block_lang(t):
118
+ return LANG_BLOCK_MESSAGE
119
+ return None
120
+
121
+
122
+ def is_check_safety():
123
+ return len(KEYWORDS) > 0 or len(BLOCK_LANGS) > 0
124
+
125
+
126
+ def safety_check_conversation(conversation) -> Optional[str]:
127
+ """
128
+ Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
129
+ This provides an additional security measure to enhance safety and compliance with local regulations.
130
+ """
131
+ texts = [c['content'] for c in conversation]
132
+ for text in texts:
133
+ if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
134
+ return KEYWORD_BLOCK_MESSAGE
135
+
136
+ if len(BLOCK_LANGS) > 0:
137
+ if block_lang(text):
138
+ return LANG_BLOCK_MESSAGE
139
+ return None
140
+
141
+
142
+ class SeaLMMMv0Engine(TransformersEngine):
143
+
144
+ @property
145
+ def image_token(self):
146
+ return IMAGE_TOKEN
147
+
148
+ @property
149
+ def max_position_embeddings(self) -> int:
150
+ return self._model.config.max_position_embeddings
151
+
152
+ @property
153
+ def tokenizer(self):
154
+ return self._tokenizer
155
+
156
+ @property
157
+ def processor(self):
158
+ return self._processor
159
+
160
+ def load_model(self):
161
+ from transformers import AutoProcessor
162
+ import sys
163
+ # caution: path[0] is reserved for script path (or '' in REPL)
164
+ # sys.path.append(CODE_PATH)
165
+
166
+ # from examples.llm.src.models.sealmm.modeling_sealmm import (
167
+ # SeaLMMForCausalLM
168
+ # )
169
+ from modeling_sealmm import (SeaLMMForCausalLM, )
170
+ model_path = MODEL_PATH
171
+ print(f'Loading model from {model_path}')
172
+
173
+ print(f'model_path={model_path}')
174
+ if os.path.exists(f"{model_path}/pytorch_model_fsdp.bin") and not os.path.exists(f"{model_path}/pytorch_model.bin"):
175
+ os.symlink("pytorch_model_fsdp.bin", f"{model_path}/pytorch_model.bin")
176
+
177
+ self._processor = AutoProcessor.from_pretrained(model_path)
178
+ self._model = SeaLMMForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="cuda").eval()
179
+
180
+ self._model.sample_old = self._model.sample
181
+ self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
182
+
183
+ self._tokenizer = self._processor.tokenizer
184
+ print(self._model)
185
+ print(f"{self.max_position_embeddings=}")
186
+
187
+ def get_multimodal_tokens(self, full_prompt, image_paths=None):
188
+ num_tokens = len(self.tokenizer.encode(full_prompt))
189
+ for image_path in image_paths:
190
+ num_tokens += IMAGE_LENGTH * MAX_PACHES
191
+ return num_tokens
192
+
193
+ def maybe_raise_safety(self, message, gen_index=-1):
194
+ if is_check_safety():
195
+ if gen_index < 0:
196
+ message_safety = safety_check_conversation_string(message)
197
+ if message_safety is not None:
198
+ raise gr.Error(message_safety)
199
+ else:
200
+ if STREAM_CHECK_MULTIPLE > 0 and gen_index % STREAM_CHECK_MULTIPLE == 0:
201
+ message_safety = safety_check_conversation_string(message)
202
+ if message_safety is not None:
203
+ raise gr.Error(message_safety)
204
+
205
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
206
+ from transformers.generation.utils import GenerationConfig
207
+ from PIL import Image
208
+ image_paths = kwargs.get("image_paths", None)
209
+ image_paths = image_paths or []
210
+
211
+ images = [Image.open(x) for x in image_paths] if len(image_paths) > 0 else None
212
+
213
+ with torch.no_grad():
214
+ inputs = self.processor(prompt, images, return_tensors='pt')
215
+ # inputs = {k: v.to("cuda", torch.bfloat16) for k, v in inputs.items() if v is not None}
216
+ inputs = {k: v.to("cuda") for k, v in inputs.items() if v is not None}
217
+ num_tokens = self.get_multimodal_tokens(prompt, image_paths)
218
+ # non-streaming generation
219
+ # output = self._model.generate(
220
+ # **inputs,
221
+ # do_sample=True,
222
+ # temperature=temperature,
223
+ # max_new_tokens=max_tokens,
224
+ # pad_token_id=self.processor.tokenizer.pad_token_id,
225
+ # )
226
+ # # response = self.processor.tokenizer.decode(output[0][-inputs.input_ids.size(-1):], skip_special_tokens=True)
227
+ # full_output_text = self.processor.decode(output[0], skip_special_tokens=True)
228
+ # response = full_output_text.split("<|im_start|>assistant\n")[-1]
229
+ # num_tokens = self.get_multimodal_tokens(prompt + response, image_paths)
230
+ # print(prompt)
231
+ # print(response)
232
+ # print(num_tokens)
233
+ # yield response, num_tokens
234
+
235
+ # if i % 4 == 0 and i > 1:
236
+ # message_safety = safety_check(response)
237
+ # if message_safety is not None:
238
+ # history = undo_history(history)
239
+ # yield history, "", None
240
+ # raise gr.Error(message_safety)
241
+ self.maybe_raise_safety(prompt)
242
+
243
+ # # ! streaming
244
+ generator = self._model.generate(
245
+ **inputs,
246
+ do_sample=True,
247
+ temperature=temperature,
248
+ max_new_tokens=max_tokens,
249
+ pad_token_id=self.processor.tokenizer.pad_token_id,
250
+ )
251
+
252
+ out_tokens = []
253
+ response = None
254
+ for index, token in enumerate(generator):
255
+ out_tokens.append(token.item())
256
+ response = self.processor.tokenizer.decode(out_tokens)
257
+
258
+ self.maybe_raise_safety(response, gen_index=index)
259
+ yield response, num_tokens
260
+
261
+ del generator
262
+
263
+ if response is not None:
264
+ self.maybe_raise_safety(prompt)
265
+
266
+ full_text = prompt + response
267
+ num_tokens = self.get_multimodal_tokens(full_text, image_paths)
268
+ yield response, num_tokens
269
+
multipurpose_chatbot/engines/transformers_engine.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import numpy as np
4
+ import argparse
5
+ import torch
6
+ import gradio as gr
7
+ from typing import Any, Iterator
8
+ from typing import Iterator, List, Optional, Tuple
9
+ import filelock
10
+ import glob
11
+ import json
12
+ import time
13
+ from gradio.routes import Request
14
+ from gradio.utils import SyncToAsyncIterator, async_iteration
15
+ from gradio.helpers import special_args
16
+ import anyio
17
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
18
+
19
+ from gradio_client.documentation import document, set_documentation_group
20
+
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+ from tqdm.auto import tqdm
23
+ from huggingface_hub import snapshot_download
24
+ import types
25
+
26
+ from gradio.components import Button
27
+ from gradio.events import Dependency, EventListenerMethod
28
+
29
+ from .base_engine import BaseEngine
30
+
31
+ # ! Remember to use static cache
32
+
33
+ from transformers import (
34
+ GenerationConfig,
35
+ GenerationMixin,
36
+ LogitsProcessorList,
37
+ StoppingCriteriaList,
38
+ DisjunctiveConstraint,
39
+ BeamSearchScorer,
40
+ PhrasalConstraint,
41
+ ConstrainedBeamSearchScorer,
42
+ PreTrainedModel,
43
+ )
44
+ import numpy as np
45
+ import random
46
+ import warnings
47
+ import inspect
48
+ from transformers.generation.utils import GenerateOutput, SampleOutput, logger
49
+ import torch
50
+ from typing import Callable, List, Optional, Union
51
+ from torch import nn
52
+ import torch.distributed as dist
53
+ import copy
54
+
55
+ from ..configs import (
56
+ MODEL_PATH,
57
+ DTYPE,
58
+ DEVICE,
59
+ )
60
+
61
+
62
+ def setup_seed(seed):
63
+ if seed == -1:
64
+ return
65
+ torch.manual_seed(seed)
66
+ if torch.cuda.is_available():
67
+ torch.cuda.manual_seed_all(seed)
68
+ np.random.seed(seed)
69
+ random.seed(seed)
70
+ torch.backends.cudnn.deterministic = True
71
+
72
+
73
+ class NewGenerationMixin(GenerationMixin):
74
+ """
75
+ Allow generator sampling
76
+
77
+ """
78
+
79
+ # ! Copy from transformers.generation.utils -> GenerationMixin
80
+ # Change sample function to sample_stream
81
+ @torch.no_grad()
82
+ def sample_stream(
83
+ self,
84
+ input_ids: torch.LongTensor,
85
+ logits_processor: Optional[LogitsProcessorList] = None,
86
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
87
+ logits_warper: Optional[LogitsProcessorList] = None,
88
+ max_length: Optional[int] = None,
89
+ pad_token_id: Optional[int] = None,
90
+ eos_token_id: Optional[Union[int, List[int]]] = None,
91
+ output_attentions: Optional[bool] = None,
92
+ output_hidden_states: Optional[bool] = None,
93
+ output_scores: Optional[bool] = None,
94
+ output_logits: Optional[bool] = None,
95
+ return_dict_in_generate: Optional[bool] = None,
96
+ synced_gpus: bool = False,
97
+ streamer: Optional["BaseStreamer"] = None,
98
+ **model_kwargs,
99
+ ):
100
+ r"""
101
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
102
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
103
+
104
+ <Tip warning={true}>
105
+
106
+ In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
107
+ For an overview of generation strategies and code examples, check the [following
108
+ guide](../generation_strategies).
109
+
110
+ </Tip>
111
+
112
+ Parameters:
113
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
114
+ The sequence used as a prompt for the generation.
115
+ logits_processor (`LogitsProcessorList`, *optional*):
116
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
117
+ used to modify the prediction scores of the language modeling head applied at each generation step.
118
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
119
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
120
+ used to tell if the generation loop should stop.
121
+ logits_warper (`LogitsProcessorList`, *optional*):
122
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
123
+ to warp the prediction score distribution of the language modeling head applied before multinomial
124
+ sampling at each generation step.
125
+ max_length (`int`, *optional*, defaults to 20):
126
+ **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
127
+ tokens. The maximum length of the sequence to be generated.
128
+ pad_token_id (`int`, *optional*):
129
+ The id of the *padding* token.
130
+ eos_token_id (`Union[int, List[int]]`, *optional*):
131
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
132
+ output_attentions (`bool`, *optional*, defaults to `False`):
133
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
134
+ returned tensors for more details.
135
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
136
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
137
+ for more details.
138
+ output_scores (`bool`, *optional*, defaults to `False`):
139
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
140
+ output_logits (`bool`, *optional*, defaults to `False`):
141
+ Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for
142
+ more details.
143
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
144
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
145
+ synced_gpus (`bool`, *optional*, defaults to `False`):
146
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
147
+ streamer (`BaseStreamer`, *optional*):
148
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
149
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
150
+ model_kwargs:
151
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
152
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
153
+
154
+ Return:
155
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
156
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
157
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
158
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
159
+ `model.config.is_encoder_decoder=True`.
160
+
161
+ Examples:
162
+
163
+ ```python
164
+ >>> from transformers import (
165
+ ... AutoTokenizer,
166
+ ... AutoModelForCausalLM,
167
+ ... LogitsProcessorList,
168
+ ... MinLengthLogitsProcessor,
169
+ ... TopKLogitsWarper,
170
+ ... TemperatureLogitsWarper,
171
+ ... StoppingCriteriaList,
172
+ ... MaxLengthCriteria,
173
+ ... )
174
+ >>> import torch
175
+
176
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
177
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
178
+
179
+ >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
180
+ >>> model.config.pad_token_id = model.config.eos_token_id
181
+ >>> model.generation_config.pad_token_id = model.config.eos_token_id
182
+
183
+ >>> input_prompt = "Today is a beautiful day, and"
184
+ >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
185
+
186
+ >>> # instantiate logits processors
187
+ >>> logits_processor = LogitsProcessorList(
188
+ ... [
189
+ ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
190
+ ... ]
191
+ ... )
192
+ >>> # instantiate logits processors
193
+ >>> logits_warper = LogitsProcessorList(
194
+ ... [
195
+ ... TopKLogitsWarper(50),
196
+ ... TemperatureLogitsWarper(0.7),
197
+ ... ]
198
+ ... )
199
+
200
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
201
+
202
+ >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
203
+ >>> outputs = model.sample(
204
+ ... input_ids,
205
+ ... logits_processor=logits_processor,
206
+ ... logits_warper=logits_warper,
207
+ ... stopping_criteria=stopping_criteria,
208
+ ... )
209
+
210
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
211
+ ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.']
212
+ ```"""
213
+ # init values
214
+ from transformers.generation.utils import (
215
+ validate_stopping_criteria, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
216
+ )
217
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
218
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
219
+ if max_length is not None:
220
+ warnings.warn(
221
+ "`max_length` is deprecated in this function, use"
222
+ " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
223
+ UserWarning,
224
+ )
225
+ stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
226
+ logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
227
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
228
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
229
+ if isinstance(eos_token_id, int):
230
+ eos_token_id = [eos_token_id]
231
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
232
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
233
+ output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
234
+ output_attentions = (
235
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
236
+ )
237
+ output_hidden_states = (
238
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
239
+ )
240
+ return_dict_in_generate = (
241
+ return_dict_in_generate
242
+ if return_dict_in_generate is not None
243
+ else self.generation_config.return_dict_in_generate
244
+ )
245
+
246
+ # init attention / hidden states / scores tuples
247
+ scores = () if (return_dict_in_generate and output_scores) else None
248
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
249
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
250
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
251
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
252
+
253
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
254
+ if return_dict_in_generate and self.config.is_encoder_decoder:
255
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
256
+ encoder_hidden_states = (
257
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
258
+ )
259
+
260
+ # keep track of which sequences are already finished
261
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
262
+
263
+ this_peer_finished = False # used by synced_gpus only
264
+ # auto-regressive generation
265
+ while True:
266
+ if synced_gpus:
267
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
268
+ # The following logic allows an early break if all peers finished generating their sequence
269
+ this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
270
+ # send 0.0 if we finished, 1.0 otherwise
271
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
272
+ # did all peers finish? the reduced sum will be 0.0 then
273
+ if this_peer_finished_flag.item() == 0.0:
274
+ break
275
+
276
+ # prepare model inputs
277
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
278
+
279
+ # forward pass to get next token
280
+ outputs = self(
281
+ **model_inputs,
282
+ return_dict=True,
283
+ output_attentions=output_attentions,
284
+ output_hidden_states=output_hidden_states,
285
+ )
286
+
287
+ if synced_gpus and this_peer_finished:
288
+ continue # don't waste resources running the code we don't need
289
+
290
+ next_token_logits = outputs.logits[:, -1, :]
291
+
292
+ # pre-process distribution
293
+ next_token_scores = logits_processor(input_ids, next_token_logits)
294
+ next_token_scores = logits_warper(input_ids, next_token_scores)
295
+
296
+ # Store scores, attentions and hidden_states when required
297
+ if return_dict_in_generate:
298
+ if output_scores:
299
+ scores += (next_token_scores,)
300
+ if output_logits:
301
+ raw_logits += (next_token_logits,)
302
+ if output_attentions:
303
+ decoder_attentions += (
304
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
305
+ )
306
+ if self.config.is_encoder_decoder:
307
+ cross_attentions += (outputs.cross_attentions,)
308
+
309
+ if output_hidden_states:
310
+ decoder_hidden_states += (
311
+ (outputs.decoder_hidden_states,)
312
+ if self.config.is_encoder_decoder
313
+ else (outputs.hidden_states,)
314
+ )
315
+
316
+ # sample
317
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
318
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
319
+
320
+ # finished sentences should have their next token be a padding token
321
+ if eos_token_id is not None:
322
+ if pad_token_id is None:
323
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
324
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
325
+
326
+ yield next_tokens.cpu()
327
+
328
+ # update generated ids, model inputs, and length for next step
329
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
330
+ if streamer is not None:
331
+ streamer.put(next_tokens.cpu())
332
+
333
+ next_model_inputs = {}
334
+ if "cache_position" in model_inputs:
335
+ next_model_inputs['cache_position'] = model_inputs['cache_position']
336
+ try:
337
+ model_kwargs = self._update_model_kwargs_for_generation(
338
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
339
+ # model_inputs=model_inputs
340
+ model_inputs=next_model_inputs,
341
+ )
342
+ except Exception as e:
343
+ # ! some transformers version don't have model_inputs in generation
344
+ model_kwargs = self._update_model_kwargs_for_generation(
345
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
346
+ # model_inputs=model_inputs
347
+ # model_inputs=next_model_inputs,
348
+ )
349
+
350
+ # if eos_token was found in one sentence, set sentence to finished
351
+ if eos_token_id_tensor is not None:
352
+ unfinished_sequences = unfinished_sequences.mul(
353
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
354
+ )
355
+
356
+ # stop when each sentence is finished
357
+ if unfinished_sequences.max() == 0:
358
+ this_peer_finished = True
359
+
360
+ # stop if we exceed the maximum length
361
+ if stopping_criteria(input_ids, scores):
362
+ this_peer_finished = True
363
+
364
+ if this_peer_finished and not synced_gpus:
365
+ break
366
+
367
+ if streamer is not None:
368
+ streamer.end()
369
+
370
+ # if return_dict_in_generate:
371
+ # if self.config.is_encoder_decoder:
372
+ # return GenerateEncoderDecoderOutput(
373
+ # sequences=input_ids,
374
+ # scores=scores,
375
+ # logits=raw_logits,
376
+ # encoder_attentions=encoder_attentions,
377
+ # encoder_hidden_states=encoder_hidden_states,
378
+ # decoder_attentions=decoder_attentions,
379
+ # cross_attentions=cross_attentions,
380
+ # decoder_hidden_states=decoder_hidden_states,
381
+ # past_key_values=model_kwargs.get("past_key_values"),
382
+ # )
383
+ # else:
384
+ # return GenerateDecoderOnlyOutput(
385
+ # sequences=input_ids,
386
+ # scores=scores,
387
+ # logits=raw_logits,
388
+ # attentions=decoder_attentions,
389
+ # hidden_states=decoder_hidden_states,
390
+ # past_key_values=model_kwargs.get("past_key_values"),
391
+ # )
392
+ # else:
393
+ # return input_ids
394
+
395
+
396
+
397
+ class TransformersEngine(BaseEngine):
398
+ @property
399
+ def max_position_embeddings(self) -> int:
400
+ return self._model.config.max_position_embeddings
401
+
402
+ @property
403
+ def tokenizer(self):
404
+ return self._tokenizer
405
+
406
+ def load_model(self):
407
+ from transformers import AutoTokenizer, AutoModelForCausalLM
408
+ import sys
409
+ # caution: path[0] is reserved for script path (or '' in REPL)
410
+ # sys.path.append(CODE_PATH)
411
+ self.model_path = model_path = MODEL_PATH
412
+ self.torch_dtype = torch.bfloat16 if DTYPE == 'bfloat16' else torch.float16
413
+ self.device_map = DEVICE
414
+ print(f'Loading model from {model_path} on {self.device_map} with {self.torch_dtype}')
415
+
416
+ self._tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
417
+ assert self._tokenizer.chat_template is not None and self._tokenizer.chat_template != "", f"{self._tokenizer.chat_template=} not found!"
418
+ self._model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=self.torch_dtype, device_map=self.device_map, trust_remote_code=True).eval()
419
+ self._model.sample_old = self._model.sample
420
+ self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
421
+ print(self._model)
422
+ print(f"{self.max_position_embeddings=}")
423
+
424
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
425
+
426
+ # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
427
+ with torch.no_grad():
428
+ inputs = self.tokenizer(prompt, return_tensors='pt')
429
+ num_tokens = inputs.input_ids.size(1)
430
+
431
+ inputs = inputs.to(self.device_map)
432
+
433
+ generator = self._model.generate(
434
+ **inputs,
435
+ do_sample=True,
436
+ temperature=temperature,
437
+ max_new_tokens=max_tokens,
438
+ pad_token_id=self.processor.tokenizer.pad_token_id,
439
+ )
440
+
441
+ out_tokens = []
442
+ response = None
443
+ for token in generator:
444
+ out_tokens.append(token.item())
445
+ response = self.processor.tokenizer.decode(out_tokens)
446
+ num_tokens += 1
447
+ # print(f"{num_tokens=}", end='\r')
448
+ # sys.stdout.flush()
449
+ yield response, num_tokens
450
+
451
+ if response is not None:
452
+ full_text = prompt + response
453
+ num_tokens = len(self.tokenizer.encode(full_text))
454
+ yield response, num_tokens
multipurpose_chatbot/engines/vllm_engine.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import argparse
4
+ import gradio as gr
5
+ from typing import Any, Iterator
6
+ from typing import Iterator, List, Optional, Tuple
7
+ import filelock
8
+ import glob
9
+ import json
10
+ import time
11
+ from gradio.routes import Request
12
+ from gradio.utils import SyncToAsyncIterator, async_iteration
13
+ from gradio.helpers import special_args
14
+ import anyio
15
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
16
+
17
+ from gradio_client.documentation import document, set_documentation_group
18
+
19
+ from typing import List, Optional, Union, Dict, Tuple
20
+ from tqdm.auto import tqdm
21
+ from huggingface_hub import snapshot_download
22
+
23
+ from gradio.components import Button
24
+ from gradio.events import Dependency, EventListenerMethod
25
+
26
+ from .base_engine import BaseEngine
27
+ # @@ environments ================
28
+
29
+ from ..configs import (
30
+ DTYPE,
31
+ TENSOR_PARALLEL,
32
+ MODEL_PATH,
33
+ QUANTIZATION,
34
+ MAX_TOKENS,
35
+ TEMPERATURE,
36
+ FREQUENCE_PENALTY,
37
+ PRESENCE_PENALTY,
38
+ GPU_MEMORY_UTILIZATION,
39
+ STREAM_CHECK_MULTIPLE,
40
+ STREAM_YIELD_MULTIPLE,
41
+
42
+ )
43
+
44
+
45
+ llm = None
46
+ demo = None
47
+
48
+
49
+
50
+ def vllm_abort(self):
51
+ sh = self.llm_engine.scheduler
52
+ for g in (sh.waiting + sh.running + sh.swapped):
53
+ sh.abort_seq_group(g.request_id)
54
+ from vllm.sequence import SequenceStatus
55
+ scheduler = self.llm_engine.scheduler
56
+ for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
57
+ for seq_group in state_queue:
58
+ # if seq_group.request_id == request_id:
59
+ # Remove the sequence group from the state queue.
60
+ state_queue.remove(seq_group)
61
+ for seq in seq_group.seqs:
62
+ if seq.is_finished():
63
+ continue
64
+ scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
65
+
66
+
67
+ def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
68
+ from vllm.outputs import RequestOutput
69
+ # Initialize tqdm.
70
+ if use_tqdm:
71
+ num_requests = self.llm_engine.get_num_unfinished_requests()
72
+ pbar = tqdm(total=num_requests, desc="Processed prompts")
73
+ # Run the engine.
74
+ outputs: Dict[str, RequestOutput] = {}
75
+ while self.llm_engine.has_unfinished_requests():
76
+ step_outputs = self.llm_engine.step()
77
+ for output in step_outputs:
78
+ outputs[output.request_id] = output
79
+ if len(outputs) > 0:
80
+ yield outputs
81
+
82
+
83
+ def vllm_generate_stream(
84
+ self: Any,
85
+ prompts: Optional[Union[str, List[str]]] = None,
86
+ sampling_params: Optional[Any] = None,
87
+ prompt_token_ids: Optional[List[List[int]]] = None,
88
+ use_tqdm: bool = False,
89
+ ) -> Dict[str, Any]:
90
+ """Generates the completions for the input prompts.
91
+
92
+ NOTE: This class automatically batches the given prompts, considering
93
+ the memory constraint. For the best performance, put all of your prompts
94
+ into a single list and pass it to this method.
95
+
96
+ Args:
97
+ prompts: A list of prompts to generate completions for.
98
+ sampling_params: The sampling parameters for text generation. If
99
+ None, we use the default sampling parameters.
100
+ prompt_token_ids: A list of token IDs for the prompts. If None, we
101
+ use the tokenizer to convert the prompts to token IDs.
102
+ use_tqdm: Whether to use tqdm to display the progress bar.
103
+
104
+ Returns:
105
+ A list of `RequestOutput` objects containing the generated
106
+ completions in the same order as the input prompts.
107
+ """
108
+ from vllm import LLM, SamplingParams
109
+ if prompts is None and prompt_token_ids is None:
110
+ raise ValueError("Either prompts or prompt_token_ids must be "
111
+ "provided.")
112
+ if isinstance(prompts, str):
113
+ # Convert a single prompt to a list.
114
+ prompts = [prompts]
115
+ if prompts is not None and prompt_token_ids is not None:
116
+ if len(prompts) != len(prompt_token_ids):
117
+ raise ValueError("The lengths of prompts and prompt_token_ids "
118
+ "must be the same.")
119
+ if sampling_params is None:
120
+ # Use default sampling params.
121
+ sampling_params = SamplingParams()
122
+ # Add requests to the engine.
123
+ if prompts is not None:
124
+ num_requests = len(prompts)
125
+ else:
126
+ num_requests = len(prompt_token_ids)
127
+ for i in range(num_requests):
128
+ prompt = prompts[i] if prompts is not None else None
129
+ if prompt_token_ids is None:
130
+ token_ids = None
131
+ else:
132
+ token_ids = prompt_token_ids[i]
133
+ self._add_request(prompt, sampling_params, token_ids)
134
+ # return self._run_engine(use_tqdm)
135
+ yield from _vllm_run_engine(self, use_tqdm)
136
+
137
+
138
+
139
+ class VllmEngine(BaseEngine):
140
+ def __init__(self, **kwargs) -> None:
141
+ super().__init__(**kwargs)
142
+
143
+ @property
144
+ def tokenizer(self):
145
+ return self._model.get_tokenizer()
146
+
147
+ def load_model(self, ):
148
+ import torch
149
+ try:
150
+ compute_capability = torch.cuda.get_device_capability()
151
+ print(f'Torch CUDA compute_capability: {compute_capability}')
152
+ except Exception as e:
153
+ print(f'Failed to print compute_capability version: {e}')
154
+
155
+ import vllm
156
+ from vllm import LLM
157
+
158
+ print(f'VLLM: {vllm.__version__=}')
159
+
160
+ if QUANTIZATION == 'awq':
161
+ print(F'Load model in int4 quantization')
162
+ llm = LLM(
163
+ model=MODEL_PATH,
164
+ dtype="float16",
165
+ tensor_parallel_size=TENSOR_PARALLEL,
166
+ gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
167
+ quantization="awq",
168
+ max_model_len=MAX_TOKENS
169
+ )
170
+ else:
171
+ llm = LLM(
172
+ model=MODEL_PATH,
173
+ dtype=DTYPE,
174
+ tensor_parallel_size=TENSOR_PARALLEL,
175
+ gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
176
+ max_model_len=MAX_TOKENS
177
+ )
178
+
179
+ try:
180
+ print(llm.llm_engine.workers[0].model)
181
+ except Exception as e:
182
+ print(f'Cannot print model worker: {e}')
183
+
184
+ try:
185
+ llm.llm_engine.scheduler_config.max_model_len = MAX_TOKENS
186
+ llm.llm_engine.scheduler_config.max_num_batched_tokens = MAX_TOKENS
187
+ except Exception as e:
188
+ print(f'Cannot set parameters: {e}')
189
+
190
+ self._model = llm
191
+
192
+ def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
193
+ from vllm import SamplingParams
194
+ # ! must abort previous ones
195
+ vllm_abort(llm)
196
+ sampling_params = SamplingParams(
197
+ temperature=temperature,
198
+ max_tokens=max_tokens,
199
+ # frequency_penalty=frequency_penalty,
200
+ # presence_penalty=presence_penalty,
201
+ stop=stop_strings,
202
+ )
203
+ cur_out = None
204
+ num_tokens = len(self.tokenizer.encode(prompt))
205
+ for j, gen in enumerate(vllm_generate_stream(llm, prompt, sampling_params)):
206
+ if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
207
+ yield cur_out, num_tokens
208
+ assert len(gen) == 1, f'{gen}'
209
+ item = next(iter(gen.values()))
210
+ cur_out = item.outputs[0].text
211
+
212
+ if cur_out is not None:
213
+ full_text = prompt + cur_out
214
+ num_tokens = len(self.tokenizer.encode(full_text))
215
+ yield cur_out, num_tokens
216
+
217
+ def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
218
+ """
219
+ Only vllm should support this, the other engines is only batch=1 only
220
+ """
221
+ from vllm import SamplingParams
222
+ # ! must abort previous ones
223
+ vllm_abort(llm)
224
+ sampling_params = SamplingParams(
225
+ temperature=temperature,
226
+ max_tokens=max_tokens,
227
+ # frequency_penalty=frequency_penalty,
228
+ # presence_penalty=presence_penalty,
229
+ stop=stop_strings,
230
+ )
231
+ generated = llm.generate(prompts, sampling_params, use_tqdm=False)
232
+ responses = [g.outputs[0].text for g in generated]
233
+ return responses
multipurpose_chatbot/globals.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ global MODEL_ENGINE
4
+
5
+ from multipurpose_chatbot.engines import load_multipurpose_chatbot_engine
6
+ from multipurpose_chatbot.demos import get_demo_class
7
+
8
+ from .configs import (
9
+ BACKEND,
10
+ RAG_EMBED_MODEL_NAME,
11
+ )
12
+
13
+ MODEL_ENGINE = load_multipurpose_chatbot_engine(BACKEND)
14
+
15
+
16
+ RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
17
+
18
+
19
+ def load_embeddings():
20
+ global RAG_EMBED
21
+ if RAG_EMBED is None:
22
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
23
+ print(f'LOading embeddings: {RAG_EMBED_MODEL_NAME}')
24
+ RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code':True, "device": "cpu"})
25
+ else:
26
+ print(f'RAG_EMBED ALREADY EXIST: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
27
+ return RAG_EMBED
28
+
29
+
30
+ def get_rag_embeddings():
31
+ return load_embeddings()
32
+
33
+
pyproject.toml ADDED
File without changes
requirements.txt CHANGED
@@ -1,3 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
1
  sentencepiece
2
  accelerate
3
  evaluate
@@ -10,21 +21,8 @@ jiwer
10
  tenacity
11
  pynvml
12
  ninja
13
- ray
14
- psutil
15
  fastapi
16
  geomloss
17
  einops
18
  langdetect
19
- transformers
20
- transformers_stream_generator
21
  plotly
22
- vllm
23
- langchain
24
- langchain-community
25
- langchain-core
26
- sentence-transformers
27
- faiss-cpu
28
- pypdf
29
- sentencepiece
30
- docx2txt
 
1
+ torch
2
+ gradio
3
+ tiktoken
4
+ openai
5
+ transformers
6
+ langchain
7
+ langchain-community
8
+ langchain-core
9
+ chromadb
10
+ pypdf
11
+ docx2txt
12
  sentencepiece
13
  accelerate
14
  evaluate
 
21
  tenacity
22
  pynvml
23
  ninja
 
 
24
  fastapi
25
  geomloss
26
  einops
27
  langdetect
 
 
28
  plotly
 
 
 
 
 
 
 
 
 
seallm_app.py ADDED
@@ -0,0 +1,1787 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright: DAMO Academy, Alibaba Group
2
+ # By Xuan Phi Nguyen at DAMO Academy, Alibaba Group
3
+
4
+ # Description:
5
+ """
6
+ VLLM-based demo script to launch Language chat model for Southeast Asian Languages
7
+ """
8
+
9
+
10
+ import os
11
+ import numpy as np
12
+ import argparse
13
+ import torch
14
+ import gradio as gr
15
+ from typing import Any, Iterator
16
+ from typing import Iterator, List, Optional, Tuple
17
+ import filelock
18
+ import glob
19
+ import json
20
+ import time
21
+ from gradio.routes import Request
22
+ from gradio.utils import SyncToAsyncIterator, async_iteration
23
+ from gradio.helpers import special_args
24
+ import anyio
25
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
26
+
27
+ from gradio_client.documentation import document, set_documentation_group
28
+
29
+ from typing import List, Optional, Union, Dict, Tuple
30
+ from tqdm.auto import tqdm
31
+ from huggingface_hub import snapshot_download
32
+
33
+
34
+ # @@ environments ================
35
+
36
+ DEBUG = bool(int(os.environ.get("DEBUG", "1")))
37
+
38
+ # List of languages to block
39
+ BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
40
+ BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
41
+
42
+ # for lang block, wether to block in history too
43
+ LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
44
+ TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
45
+ DTYPE = os.environ.get("DTYPE", "bfloat16")
46
+
47
+ # ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
48
+ DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
49
+ LOG_RESPONSE = bool(int(os.environ.get("LOG_RESPONSE", "0")))
50
+ # ! show model path in the demo page, only for internal
51
+ DISPLAY_MODEL_PATH = bool(int(os.environ.get("DISPLAY_MODEL_PATH", "1")))
52
+
53
+ # ! uploaded model path, will be downloaded to MODEL_PATH
54
+ HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
55
+ # ! if model is private, need HF_TOKEN to access the model
56
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
57
+ # ! path where the model is downloaded, either on ./ or persistent disc
58
+ MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
59
+
60
+ # ! log path
61
+ LOG_PATH = os.environ.get("LOG_PATH", "").strip()
62
+ LOG_FILE = None
63
+ SAVE_LOGS = LOG_PATH is not None and LOG_PATH != ''
64
+ if SAVE_LOGS:
65
+ if os.path.exists(LOG_PATH):
66
+ print(f'LOG_PATH exist: {LOG_PATH}')
67
+ else:
68
+ LOG_DIR = os.path.dirname(LOG_PATH)
69
+ os.makedirs(LOG_DIR, exist_ok=True)
70
+
71
+ # ! get LOG_PATH as aggregated outputs in log
72
+ GET_LOG_CMD = os.environ.get("GET_LOG_CMD", "").strip()
73
+
74
+ print(f'SAVE_LOGS: {SAVE_LOGS} | {LOG_PATH}')
75
+ # print(f'GET_LOG_CMD: {GET_LOG_CMD}')
76
+
77
+ # ! !! Whether to delete the folder, ONLY SET THIS IF YOU WANT TO DELETE SAVED MODEL ON PERSISTENT DISC
78
+ DELETE_FOLDER = os.environ.get("DELETE_FOLDER", "")
79
+ IS_DELETE_FOLDER = DELETE_FOLDER is not None and os.path.exists(DELETE_FOLDER)
80
+ print(f'DELETE_FOLDER: {DELETE_FOLDER} | {DOWNLOAD_SNAPSHOT=}')
81
+
82
+ # ! list of keywords to disabled as security measures to comply with local regulation
83
+ KEYWORDS = os.environ.get("KEYWORDS", "").strip()
84
+ KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
85
+ KEYWORDS = [x.lower() for x in KEYWORDS]
86
+
87
+ # bypass
88
+ BYPASS_USERS = os.environ.get("BYPASS_USERS", "").strip()
89
+ BYPASS_USERS = BYPASS_USERS.split(";") if len(BYPASS_USERS) > 0 else []
90
+
91
+ # gradio config
92
+ PORT = int(os.environ.get("PORT", "7860"))
93
+ # how many iterations to yield response
94
+ STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
95
+ # how many iterations to perform safety check on response
96
+ STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
97
+
98
+ # whether to enable to popup accept user
99
+ ENABLE_AGREE_POPUP = bool(int(os.environ.get("ENABLE_AGREE_POPUP", "0")))
100
+
101
+ # self explanatory
102
+ MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
103
+ TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
104
+ FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.1"))
105
+ PRESENCE_PENALTY = float(os.environ.get("PRESENCE_PENALTY", "0.0"))
106
+ gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9"))
107
+
108
+ # whether to enable quantization, currently not in use
109
+ QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
110
+
111
+
112
+ # Batch inference file upload
113
+ ENABLE_BATCH_INFER = bool(int(os.environ.get("ENABLE_BATCH_INFER", "1")))
114
+ BATCH_INFER_MAX_ITEMS = int(os.environ.get("BATCH_INFER_MAX_ITEMS", "100"))
115
+ BATCH_INFER_MAX_FILE_SIZE = int(os.environ.get("BATCH_INFER_MAX_FILE_SIZE", "500"))
116
+ BATCH_INFER_MAX_PROMPT_TOKENS = int(os.environ.get("BATCH_INFER_MAX_PROMPT_TOKENS", "4000"))
117
+ BATCH_INFER_SAVE_TMP_FILE = os.environ.get("BATCH_INFER_SAVE_TMP_FILE", "./tmp/pred.json")
118
+
119
+ #
120
+ DATA_SET_REPO_PATH = str(os.environ.get("DATA_SET_REPO_PATH", ""))
121
+ DATA_SET_REPO = None
122
+
123
+ """
124
+ Internal instructions of how to configure the DEMO
125
+
126
+ 1. Upload SFT model as a model to huggingface: hugginface/models/seal_13b_a
127
+ 2. If the model weights is private, set HF_TOKEN=<your private hf token> in https://huggingface.co/spaces/????/?????/settings
128
+ 3. space config env: `HF_MODEL_NAME=SeaLLMs/seal-13b-chat-a` or the underlining model
129
+ 4. If enable persistent storage: set
130
+ HF_HOME=/data/.huggingface
131
+ MODEL_PATH=/data/.huggingface/seal-13b-chat-a
132
+ if not:
133
+ MODEL_PATH=./seal-13b-chat-a
134
+
135
+
136
+ HF_HOME=/data/.huggingface
137
+ MODEL_PATH=/data/ckpt/seal-13b-chat-a
138
+ DELETE_FOLDER=/data/
139
+
140
+ """
141
+
142
+ # ==============================
143
+ print(f'DEBUG mode: {DEBUG}')
144
+ print(f'Torch version: {torch.__version__}')
145
+ try:
146
+ print(f'Torch CUDA version: {torch.version.cuda}')
147
+ except Exception as e:
148
+ print(f'Failed to print cuda version: {e}')
149
+
150
+ try:
151
+ compute_capability = torch.cuda.get_device_capability()
152
+ print(f'Torch CUDA compute_capability: {compute_capability}')
153
+ except Exception as e:
154
+ print(f'Failed to print compute_capability version: {e}')
155
+
156
+
157
+ # @@ constants ================
158
+
159
+ DTYPES = {
160
+ 'float16': torch.float16,
161
+ 'bfloat16': torch.bfloat16
162
+ }
163
+
164
+ llm = None
165
+ demo = None
166
+
167
+
168
+ BOS_TOKEN = '<s>'
169
+ EOS_TOKEN = '</s>'
170
+
171
+
172
+ SYSTEM_PROMPT_1 = """You are a helpful, respectful, honest and safe AI assistant built by Alibaba Group."""
173
+
174
+
175
+
176
+ # ######### RAG PREPARE
177
+ RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
178
+
179
+ # RAG_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
180
+ RAG_EMBED_MODEL_NAME = "sentence-transformers/LaBSE"
181
+
182
+
183
+ def load_embeddings():
184
+ global RAG_EMBED
185
+ if RAG_EMBED is None:
186
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
187
+ print(f'LOading embeddings: {RAG_EMBED_MODEL_NAME}')
188
+ RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code':True, "device": "cpu"})
189
+ else:
190
+ print(f'RAG_EMBED ALREADY EXIST: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
191
+ return RAG_EMBED
192
+
193
+
194
+ def get_rag_embeddings():
195
+ return load_embeddings()
196
+
197
+ _ = get_rag_embeddings()
198
+
199
+ RAG_CURRENT_VECTORSTORE = None
200
+
201
+ def load_document_split_vectorstore(file_path):
202
+ global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
203
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
204
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
205
+ from langchain_community.vectorstores import Chroma, FAISS
206
+ from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
207
+ # assert RAG_EMBED is not None
208
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
209
+ if file_path.endswith('.pdf'):
210
+ loader = PyPDFLoader(file_path)
211
+ elif file_path.endswith('.docx'):
212
+ loader = Docx2txtLoader(file_path)
213
+ elif file_path.endswith('.txt'):
214
+ loader = TextLoader(file_path)
215
+ splits = loader.load_and_split(splitter)
216
+ RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
217
+ return RAG_CURRENT_VECTORSTORE
218
+
219
+
220
+ def docs_to_rag_context(docs: List[str]):
221
+ contexts = "\n".join([d.page_content for d in docs])
222
+ context = f"""Answer the following query exclusively based on the information provided in the document above. \
223
+ If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
224
+ ###
225
+ {contexts}
226
+ ###
227
+
228
+
229
+ """
230
+ return context
231
+
232
+ def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
233
+ global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
234
+ doc_context = None
235
+ if file_input is not None:
236
+ assert os.path.exists(file_input), f"not found: {file_input}"
237
+ if file_input == RAG_CURRENT_FILE:
238
+ # reuse
239
+ vectorstore = RAG_CURRENT_VECTORSTORE
240
+ print(f'Reuse vectorstore: {file_input}')
241
+ else:
242
+ vectorstore = load_document_split_vectorstore(file_input)
243
+ print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
244
+ RAG_CURRENT_FILE = file_input
245
+ docs = vectorstore.similarity_search(message, k=rag_num_docs)
246
+ doc_context = docs_to_rag_context(docs)
247
+ return doc_context
248
+
249
+ # ######### RAG PREPARE
250
+
251
+
252
+ # ============ CONSTANT ============
253
+ # https://github.com/gradio-app/gradio/issues/884
254
+ MODEL_NAME = "SeaLLM-7B"
255
+ MODEL_NAME = str(os.environ.get("MODEL_NAME", "SeaLLM-7B"))
256
+
257
+ MODEL_TITLE = """
258
+ <div class="container" style="
259
+ align-items: center;
260
+ justify-content: center;
261
+ display: flex;
262
+ ">
263
+ <div class="image" >
264
+ <img src="file/seal_logo.png" style="
265
+ max-width: 10em;
266
+ max-height: 5%;
267
+ height: 3em;
268
+ width: 3em;
269
+ float: left;
270
+ margin-left: auto;
271
+ ">
272
+ </div>
273
+ <div class="text" style="
274
+ padding-left: 20px;
275
+ padding-top: 1%;
276
+ float: left;
277
+ ">
278
+ <h1 style="font-size: xx-large">SeaLLMs - Large Language Models for Southeast Asia</h1>
279
+ </div>
280
+ </div>
281
+ """
282
+
283
+ MODEL_TITLE = """
284
+ <img src="file/seal_logo.png" style="
285
+ max-width: 10em;
286
+ max-height: 5%;
287
+ height: 3em;
288
+ width: 3em;
289
+ ">
290
+ <div class="text" style="
291
+ loat: left;
292
+ padding-bottom: 2%;
293
+ ">
294
+ SeaLLMs - Large Language Models for Southeast Asia
295
+ </div>
296
+ """
297
+
298
+ """
299
+ Somehow cannot add image here
300
+ <div class="image" >
301
+ <img src="file/seal_logo.png" style="
302
+ max-width: 10em;
303
+ max-height: 5%;
304
+ height: 3em;
305
+ width: 3em;
306
+ float: left;
307
+ margin-left: auto;
308
+ ">
309
+ </div>
310
+ """
311
+
312
+ MODEL_DESC = f"""
313
+ <div style='display:flex; gap: 0.25rem; '>
314
+ <a href='https://github.com/damo-nlp-sg/seallms'><img src='https://img.shields.io/badge/Github-Code-success'></a>
315
+ <a href='https://huggingface.co/spaces/SeaLLMs/SeaLLM-7B'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
316
+ <a href='https://huggingface.co/SeaLLMs/SeaLLM-7B-v2'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
317
+ <a href='https://arxiv.org/pdf/2312.00738.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
318
+ </div>
319
+ <span style="font-size: larger">
320
+ <a href="https://huggingface.co/SeaLLMs/SeaLLM-7B-v2" target="_blank">{MODEL_NAME}-v2</a> - a helpful assistant for Southeast Asian Languages 🇬🇧 🇻🇳 🇮🇩 🇹🇭 🇲🇾 🇰🇭 🇱🇦 🇵🇭 🇲🇲.
321
+ Explore <a href="https://huggingface.co/SeaLLMs/SeaLLM-7B-v2" target="_blank">our article</a> for more.
322
+ </span>
323
+ <br>
324
+ <span>
325
+ <span style="color: red">NOTE: The chatbot may produce false and harmful content and does not have up-to-date knowledge.</span>
326
+ By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">Terms Of Use</a>, which includes
327
+ not to use our service to generate any harmful, inappropriate or illegal content.
328
+ The service collects user dialogue data for testing and improvement under
329
+ <a href="https://creativecommons.org/licenses/by/4.0/">(CC-BY)</a> or similar license. So do not enter any personal information!
330
+ </span>
331
+ """.strip()
332
+
333
+
334
+ cite_markdown = """
335
+ ## Citation
336
+ If you find our project useful, hope you can star our repo and cite our paper as follows:
337
+ ```
338
+ @article{damonlpsg2023seallm,
339
+ author = {Xuan-Phi Nguyen*, Wenxuan Zhang*, Xin Li*, Mahani Aljunied*, Zhiqiang Hu, Chenhui Shen^, Yew Ken Chia^, Xingxuan Li, Jianyu Wang, Qingyu Tan, Liying Cheng, Guanzheng Chen, Yue Deng, Sen Yang, Chaoqun Liu, Hang Zhang, Lidong Bing},
340
+ title = {SeaLLMs - Large Language Models for Southeast Asia},
341
+ year = 2023,
342
+ }
343
+ ```
344
+ """
345
+
346
+ path_markdown = """
347
+ #### Model path:
348
+ {model_path}
349
+ """
350
+
351
+
352
+
353
+ # ! ==================================================================
354
+
355
+ set_documentation_group("component")
356
+
357
+
358
+ RES_PRINTED = False
359
+
360
+
361
+ @document()
362
+ class ChatBot(gr.Chatbot):
363
+ def _postprocess_chat_messages(
364
+ self, chat_message
365
+ ):
366
+ x = super()._postprocess_chat_messages(chat_message)
367
+ # if isinstance(x, str):
368
+ # x = x.strip().replace("\n", "<br>")
369
+ return x
370
+
371
+
372
+ from gradio.components import Button
373
+ from gradio.events import Dependency, EventListenerMethod
374
+
375
+ # replace events so that submit button is disabled during generation, if stop_btn not found
376
+ # this prevent weird behavior
377
+ def _setup_stop_events(
378
+ self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
379
+ ) -> None:
380
+ from gradio.components import State
381
+ event_triggers = event_triggers if isinstance(event_triggers, (list, tuple)) else [event_triggers]
382
+ if self.stop_btn and self.is_generator:
383
+ if self.submit_btn:
384
+ for event_trigger in event_triggers:
385
+ event_trigger(
386
+ lambda: (
387
+ Button(visible=False),
388
+ Button(visible=True),
389
+ ),
390
+ None,
391
+ [self.submit_btn, self.stop_btn],
392
+ api_name=False,
393
+ queue=False,
394
+ )
395
+ event_to_cancel.then(
396
+ lambda: (Button(visible=True), Button(visible=False)),
397
+ None,
398
+ [self.submit_btn, self.stop_btn],
399
+ api_name=False,
400
+ queue=False,
401
+ )
402
+ else:
403
+ for event_trigger in event_triggers:
404
+ event_trigger(
405
+ lambda: Button(visible=True),
406
+ None,
407
+ [self.stop_btn],
408
+ api_name=False,
409
+ queue=False,
410
+ )
411
+ event_to_cancel.then(
412
+ lambda: Button(visible=False),
413
+ None,
414
+ [self.stop_btn],
415
+ api_name=False,
416
+ queue=False,
417
+ )
418
+ self.stop_btn.click(
419
+ None,
420
+ None,
421
+ None,
422
+ cancels=event_to_cancel,
423
+ api_name=False,
424
+ )
425
+ else:
426
+ if self.submit_btn:
427
+ for event_trigger in event_triggers:
428
+ event_trigger(
429
+ lambda: Button(interactive=False),
430
+ None,
431
+ [self.submit_btn],
432
+ api_name=False,
433
+ queue=False,
434
+ )
435
+ event_to_cancel.then(
436
+ lambda: Button(interactive=True),
437
+ None,
438
+ [self.submit_btn],
439
+ api_name=False,
440
+ queue=False,
441
+ )
442
+ # upon clear, cancel the submit event as well
443
+ if self.clear_btn:
444
+ self.clear_btn.click(
445
+ lambda: ([], [], None, Button(interactive=True)),
446
+ None,
447
+ [self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
448
+ queue=False,
449
+ api_name=False,
450
+ cancels=event_to_cancel,
451
+ )
452
+
453
+ # TODO: reconfigure clear button as stop and clear button
454
+ def _setup_events(self) -> None:
455
+ from gradio.components import State
456
+ has_on = False
457
+ try:
458
+ from gradio.events import Dependency, EventListenerMethod, on
459
+ has_on = True
460
+ except ImportError as ie:
461
+ has_on = False
462
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
463
+
464
+ def update_time(c_time, chatbot_state):
465
+ # if chatbot_state is empty, register a new conversaion with the current timestamp
466
+ # assert len(chatbot_state) > 0, f'empty chatbot state'
467
+ if len(chatbot_state) <= 1:
468
+ return gr.Number(value=time.time(), label='current_time', visible=False), chatbot_state
469
+ # elif len(chatbot_state) == 1:
470
+ # # assert chatbot_state[-1][-1] is None, f'invalid [[message, None]] , got {chatbot_state}'
471
+ # return gr.Number(value=time.time(), label='current_time', visible=False), chatbot_state
472
+ else:
473
+ return c_time, chatbot_state
474
+
475
+ if has_on:
476
+ # new version
477
+ submit_triggers = (
478
+ [self.textbox.submit, self.submit_btn.click]
479
+ if self.submit_btn
480
+ else [self.textbox.submit]
481
+ )
482
+ submit_event = (
483
+ on(
484
+ submit_triggers,
485
+ self._clear_and_save_textbox,
486
+ [self.textbox],
487
+ [self.textbox, self.saved_input],
488
+ api_name=False,
489
+ queue=False,
490
+ )
491
+ .then(
492
+ self._display_input,
493
+ [self.saved_input, self.chatbot_state],
494
+ [self.chatbot, self.chatbot_state],
495
+ api_name=False,
496
+ queue=False,
497
+ )
498
+ .then(
499
+ update_time,
500
+ [self.additional_inputs[-1], self.chatbot_state],
501
+ [self.additional_inputs[-1], self.chatbot_state],
502
+ api_name=False,
503
+ queue=False,
504
+ )
505
+ .then(
506
+ submit_fn,
507
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
508
+ [self.chatbot, self.chatbot_state],
509
+ api_name=False,
510
+ )
511
+ )
512
+ self._setup_stop_events(submit_triggers, submit_event)
513
+ else:
514
+ raise ValueError(f'Better install new gradio version than 3.44.0')
515
+
516
+ if self.retry_btn:
517
+ retry_event = (
518
+ self.retry_btn.click(
519
+ self._delete_prev_fn,
520
+ [self.chatbot_state],
521
+ [self.chatbot, self.saved_input, self.chatbot_state],
522
+ api_name=False,
523
+ queue=False,
524
+ )
525
+ .then(
526
+ self._display_input,
527
+ [self.saved_input, self.chatbot_state],
528
+ [self.chatbot, self.chatbot_state],
529
+ api_name=False,
530
+ queue=False,
531
+ )
532
+ .then(
533
+ submit_fn,
534
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
535
+ [self.chatbot, self.chatbot_state],
536
+ api_name=False,
537
+ )
538
+ )
539
+ self._setup_stop_events([self.retry_btn.click], retry_event)
540
+
541
+ if self.undo_btn:
542
+ self.undo_btn.click(
543
+ self._delete_prev_fn,
544
+ [self.chatbot_state],
545
+ [self.chatbot, self.saved_input, self.chatbot_state],
546
+ api_name=False,
547
+ queue=False,
548
+ ).then(
549
+ lambda x: x,
550
+ [self.saved_input],
551
+ [self.textbox],
552
+ api_name=False,
553
+ queue=False,
554
+ )
555
+
556
+ # Reconfigure clear_btn to stop and clear text box
557
+
558
+
559
+ def _display_input(
560
+ self, message: str, history: List[List[Union[str, None]]]
561
+ ) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
562
+ if message is not None and message.strip() != "":
563
+ history.append([message, None])
564
+ return history, history
565
+
566
+
567
+ async def _stream_fn(
568
+ self,
569
+ message: str,
570
+ history_with_input,
571
+ request: Request,
572
+ *args,
573
+ ) -> AsyncGenerator:
574
+ history = history_with_input[:-1]
575
+ inputs, _, _ = special_args(
576
+ self.fn, inputs=[message, history, *args], request=request
577
+ )
578
+
579
+ if self.is_async:
580
+ generator = self.fn(*inputs)
581
+ else:
582
+ generator = await anyio.to_thread.run_sync(
583
+ self.fn, *inputs, limiter=self.limiter
584
+ )
585
+ generator = SyncToAsyncIterator(generator, self.limiter)
586
+ try:
587
+ first_response = await async_iteration(generator)
588
+ update = history + [[message, first_response]]
589
+ yield update, update
590
+ except StopIteration:
591
+ update = history + [[message, None]]
592
+ yield update, update
593
+ except Exception as e:
594
+ yield history, history
595
+ raise e
596
+
597
+ try:
598
+ async for response in generator:
599
+ update = history + [[message, response]]
600
+ yield update, update
601
+ except Exception as e:
602
+ # if "invalid" in str(e):
603
+ # yield history, history
604
+ # raise e
605
+ # else:
606
+ # raise e
607
+ yield history, history
608
+ raise e
609
+
610
+
611
+
612
+
613
+ # replace
614
+ gr.ChatInterface._setup_stop_events = _setup_stop_events
615
+ gr.ChatInterface._setup_events = _setup_events
616
+ gr.ChatInterface._display_input = _display_input
617
+ gr.ChatInterface._stream_fn = _stream_fn
618
+
619
+
620
+ @document()
621
+ class CustomTabbedInterface(gr.Blocks):
622
+ def __init__(
623
+ self,
624
+ interface_list: list[gr.Interface],
625
+ tab_names: Optional[list[str]] = None,
626
+ title: Optional[str] = None,
627
+ description: Optional[str] = None,
628
+ theme: Optional[gr.Theme] = None,
629
+ analytics_enabled: Optional[bool] = None,
630
+ css: Optional[str] = None,
631
+ ):
632
+ """
633
+ Parameters:
634
+ interface_list: a list of interfaces to be rendered in tabs.
635
+ tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
636
+ title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
637
+ analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
638
+ css: custom css or path to custom css file to apply to entire Blocks
639
+ Returns:
640
+ a Gradio Tabbed Interface for the given interfaces
641
+ """
642
+ super().__init__(
643
+ title=title or "Gradio",
644
+ theme=theme,
645
+ analytics_enabled=analytics_enabled,
646
+ mode="tabbed_interface",
647
+ css=css,
648
+ )
649
+ self.description = description
650
+ if tab_names is None:
651
+ tab_names = [f"Tab {i}" for i in range(len(interface_list))]
652
+ with self:
653
+ if title:
654
+ gr.Markdown(
655
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
656
+ )
657
+ if description:
658
+ gr.Markdown(description)
659
+ with gr.Tabs():
660
+ for interface, tab_name in zip(interface_list, tab_names):
661
+ with gr.Tab(label=tab_name):
662
+ interface.render()
663
+
664
+
665
+ def vllm_abort(self):
666
+ sh = self.llm_engine.scheduler
667
+ for g in (sh.waiting + sh.running + sh.swapped):
668
+ sh.abort_seq_group(g.request_id)
669
+ from vllm.sequence import SequenceStatus
670
+ scheduler = self.llm_engine.scheduler
671
+ for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
672
+ for seq_group in state_queue:
673
+ # if seq_group.request_id == request_id:
674
+ # Remove the sequence group from the state queue.
675
+ state_queue.remove(seq_group)
676
+ for seq in seq_group.seqs:
677
+ if seq.is_finished():
678
+ continue
679
+ scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
680
+
681
+
682
+ def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
683
+ from vllm.outputs import RequestOutput
684
+ # Initialize tqdm.
685
+ if use_tqdm:
686
+ num_requests = self.llm_engine.get_num_unfinished_requests()
687
+ pbar = tqdm(total=num_requests, desc="Processed prompts")
688
+ # Run the engine.
689
+ outputs: Dict[str, RequestOutput] = {}
690
+ while self.llm_engine.has_unfinished_requests():
691
+ step_outputs = self.llm_engine.step()
692
+ for output in step_outputs:
693
+ outputs[output.request_id] = output
694
+ if len(outputs) > 0:
695
+ yield outputs
696
+
697
+
698
+
699
+ def vllm_generate_stream(
700
+ self: Any,
701
+ prompts: Optional[Union[str, List[str]]] = None,
702
+ sampling_params: Optional[Any] = None,
703
+ prompt_token_ids: Optional[List[List[int]]] = None,
704
+ use_tqdm: bool = False,
705
+ ) -> Dict[str, Any]:
706
+ """Generates the completions for the input prompts.
707
+
708
+ NOTE: This class automatically batches the given prompts, considering
709
+ the memory constraint. For the best performance, put all of your prompts
710
+ into a single list and pass it to this method.
711
+
712
+ Args:
713
+ prompts: A list of prompts to generate completions for.
714
+ sampling_params: The sampling parameters for text generation. If
715
+ None, we use the default sampling parameters.
716
+ prompt_token_ids: A list of token IDs for the prompts. If None, we
717
+ use the tokenizer to convert the prompts to token IDs.
718
+ use_tqdm: Whether to use tqdm to display the progress bar.
719
+
720
+ Returns:
721
+ A list of `RequestOutput` objects containing the generated
722
+ completions in the same order as the input prompts.
723
+ """
724
+ from vllm import LLM, SamplingParams
725
+ if prompts is None and prompt_token_ids is None:
726
+ raise ValueError("Either prompts or prompt_token_ids must be "
727
+ "provided.")
728
+ if isinstance(prompts, str):
729
+ # Convert a single prompt to a list.
730
+ prompts = [prompts]
731
+ if prompts is not None and prompt_token_ids is not None:
732
+ if len(prompts) != len(prompt_token_ids):
733
+ raise ValueError("The lengths of prompts and prompt_token_ids "
734
+ "must be the same.")
735
+ if sampling_params is None:
736
+ # Use default sampling params.
737
+ sampling_params = SamplingParams()
738
+
739
+ # Add requests to the engine.
740
+ if prompts is not None:
741
+ num_requests = len(prompts)
742
+ else:
743
+ num_requests = len(prompt_token_ids)
744
+ for i in range(num_requests):
745
+ prompt = prompts[i] if prompts is not None else None
746
+ if prompt_token_ids is None:
747
+ token_ids = None
748
+ else:
749
+ token_ids = prompt_token_ids[i]
750
+ self._add_request(prompt, sampling_params, token_ids)
751
+ # return self._run_engine(use_tqdm)
752
+ yield from _vllm_run_engine(self, use_tqdm)
753
+
754
+
755
+
756
+ # ! avoid saying
757
+ # LANG_BLOCK_MESSAGE = """Sorry, the language you have asked is currently not supported. If you have questions in other supported languages, I'll be glad to help. \
758
+ # Please also consider clearing the chat box for a better experience."""
759
+
760
+ # KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated question, I'll be glad to help."
761
+
762
+ LANG_BLOCK_MESSAGE = """Unsupported language."""
763
+
764
+ KEYWORD_BLOCK_MESSAGE = "Invalid request."
765
+
766
+
767
+ def _detect_lang(text):
768
+ # Disable language that may have safety risk
769
+ from langdetect import detect as detect_lang
770
+ dlang = None
771
+ try:
772
+ dlang = detect_lang(text)
773
+ except Exception as e:
774
+ if "No features in text." in str(e):
775
+ return "en"
776
+ else:
777
+ return "zh"
778
+ return dlang
779
+
780
+
781
+ def block_lang(
782
+ message: str,
783
+ history: List[Tuple[str, str]] = None,
784
+ ) -> str:
785
+ # relieve history base block
786
+ if len(BLOCK_LANGS) == 0:
787
+ return False
788
+
789
+ if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
790
+ return True
791
+ else:
792
+ _lang = _detect_lang(message)
793
+ if _lang in BLOCK_LANGS:
794
+ print(f'Detect blocked {_lang}: {message}')
795
+ return True
796
+ else:
797
+ return False
798
+
799
+
800
+ def safety_check(text, history=None, ) -> Optional[str]:
801
+ """
802
+ Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
803
+ This provides an additional security measure to enhance safety and compliance with local regulations.
804
+ """
805
+ if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
806
+ return KEYWORD_BLOCK_MESSAGE
807
+
808
+ if len(BLOCK_LANGS) > 0:
809
+ if block_lang(text, history):
810
+ return LANG_BLOCK_MESSAGE
811
+
812
+ return None
813
+
814
+
815
+
816
+ TURN_TEMPLATE = "<|im_start|>{role}\n{content}</s>"
817
+ TURN_PREFIX = "<|im_start|>{role}\n"
818
+
819
+
820
+ def chatml_chat_convo_format(conversations, add_assistant_prefix: bool, default_system=SYSTEM_PROMPT_1):
821
+ if conversations[0]['role'] != 'system':
822
+ conversations = [{"role": "system", "content": default_system}] + conversations
823
+ text = ''
824
+ for turn_id, turn in enumerate(conversations):
825
+ prompt = TURN_TEMPLATE.format(role=turn['role'], content=turn['content'])
826
+ text += prompt
827
+ if add_assistant_prefix:
828
+ prompt = TURN_PREFIX.format(role='assistant')
829
+ text += prompt
830
+ return text
831
+
832
+
833
+ def chatml_format(message, history=None, system_prompt=None):
834
+ conversations = []
835
+ system_prompt = system_prompt or "You are a helpful assistant."
836
+ if history is not None and len(history) > 0:
837
+ for i, (prompt, res) in enumerate(history):
838
+ conversations.append({"role": "user", "content": prompt.strip()})
839
+ conversations.append({"role": "assistant", "content": res.strip()})
840
+ conversations.append({"role": "user", "content": message.strip()})
841
+ return chatml_chat_convo_format(conversations, True, default_system=system_prompt)
842
+
843
+
844
+ def debug_chat_response_stream_multiturn(message, history):
845
+ message_safety = safety_check(message, history=history)
846
+ if message_safety is not None:
847
+ # yield message_safety
848
+ raise gr.Error(message_safety)
849
+
850
+ message = "This is a debugging message"
851
+ for i in range(len(message)):
852
+ time.sleep(0.05)
853
+ yield message[:i]
854
+
855
+
856
+
857
+ def chat_response_stream_multiturn(
858
+ message: str,
859
+ history: List[Tuple[str, str]],
860
+ temperature: float,
861
+ max_tokens: int,
862
+ frequency_penalty: float,
863
+ presence_penalty: float,
864
+ system_prompt: Optional[str] = SYSTEM_PROMPT_1,
865
+ current_time: Optional[float] = None,
866
+ # profile: Optional[gr.OAuthProfile] = None,
867
+ ) -> str:
868
+ """
869
+ gr.Number(value=temperature, label='Temperature (higher -> more random)'),
870
+ gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
871
+ gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
872
+ gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
873
+ gr.Textbox(value=sys_prompt, label='System prompt', lines=8, interactive=False),
874
+ gr.Number(value=0, label='current_time', visible=False),
875
+ """
876
+ global LOG_FILE, LOG_PATH
877
+ if DEBUG:
878
+ yield from debug_chat_response_stream_multiturn(message, history)
879
+ return
880
+ from vllm import LLM, SamplingParams
881
+ """Build multi turn
882
+
883
+ message is incoming prompt
884
+ history don't have the current messauge
885
+ """
886
+ global llm, RES_PRINTED
887
+ assert llm is not None
888
+ assert system_prompt.strip() != '', f'system prompt is empty'
889
+ # is_by_pass = False if profile is None else profile.username in BYPASS_USERS
890
+ is_by_pass = False
891
+
892
+ tokenizer = llm.get_tokenizer()
893
+ # force removing all
894
+ vllm_abort(llm)
895
+
896
+ temperature = float(temperature)
897
+ frequency_penalty = float(frequency_penalty)
898
+ max_tokens = int(max_tokens)
899
+
900
+ message = message.strip()
901
+
902
+ if GET_LOG_CMD != "" and message.strip() == GET_LOG_CMD:
903
+ print_log_file()
904
+ yield "Finish printed log. Please clear the chatbox now."
905
+ return
906
+
907
+ if len(message) == 0:
908
+ raise gr.Error("The message cannot be empty!")
909
+
910
+ message_safety = safety_check(message, history=history)
911
+ if message_safety is not None and not is_by_pass:
912
+ # yield message_safety
913
+ raise gr.Error(message_safety)
914
+
915
+ # history will be appended with message later on
916
+
917
+ full_prompt = chatml_format(message.strip(), history=history, system_prompt=system_prompt)
918
+ print(full_prompt)
919
+
920
+ if len(tokenizer.encode(full_prompt)) >= 4050:
921
+ raise gr.Error(f"Conversation or prompt is too long, please clear the chatbox or try shorter input.")
922
+
923
+ sampling_params = SamplingParams(
924
+ temperature=temperature,
925
+ max_tokens=max_tokens,
926
+ frequency_penalty=frequency_penalty,
927
+ presence_penalty=presence_penalty,
928
+ # stop=['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'],
929
+ stop=['<s>', '</s>', '<|im_start|>', '<|im_end|>'],
930
+ )
931
+ cur_out = None
932
+
933
+ for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
934
+ if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
935
+ # cur_out = cur_out.replace("\\n", "\n")
936
+
937
+ # optionally check safety, and respond
938
+ if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
939
+ message_safety = safety_check(cur_out, history=None)
940
+ if message_safety is not None and not is_by_pass:
941
+ # yield message_safety
942
+ raise gr.Error(message_safety)
943
+ # return
944
+
945
+ yield cur_out
946
+ assert len(gen) == 1, f'{gen}'
947
+ item = next(iter(gen.values()))
948
+ cur_out = item.outputs[0].text
949
+ #cur_out = "Our system is under maintenance, will be back soon!"
950
+ if j >= max_tokens - 2:
951
+ gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
952
+
953
+ # TODO: use current_time to register conversations, accoriding history and cur_out
954
+ history_str = format_conversation(history + [[message, cur_out]])
955
+ print(f'@@@@@@@@@@\n{history_str}\n##########\n')
956
+
957
+ maybe_log_conv_file(current_time, history, message, cur_out, temperature=temperature, frequency_penalty=frequency_penalty)
958
+
959
+ if cur_out is not None and "\\n" in cur_out:
960
+ print(f'double slash-n in cur_out:\n{cur_out}')
961
+ cur_out = cur_out.replace("\\n", "\n")
962
+
963
+ if cur_out is not None:
964
+ yield cur_out
965
+
966
+ message_safety = safety_check(cur_out, history=None)
967
+ if message_safety is not None and not is_by_pass:
968
+ # yield message_safety
969
+ raise gr.Error(message_safety)
970
+ # return
971
+
972
+
973
+
974
+ def chat_response_stream_rag_multiturn(
975
+ message: str,
976
+ history: List[Tuple[str, str]],
977
+ file_input: str,
978
+ temperature: float,
979
+ max_tokens: int,
980
+ # frequency_penalty: float,
981
+ # presence_penalty: float,
982
+ system_prompt: Optional[str] = SYSTEM_PROMPT_1,
983
+ current_time: Optional[float] = None,
984
+ rag_num_docs: Optional[int] = 3,
985
+ ):
986
+ message = message.strip()
987
+ frequency_penalty = FREQUENCE_PENALTY
988
+ presence_penalty = PRESENCE_PENALTY
989
+ if len(message) == 0:
990
+ raise gr.Error("The message cannot be empty!")
991
+ doc_context = maybe_get_doc_context(message, file_input, rag_num_docs=rag_num_docs)
992
+ if doc_context is not None:
993
+ message = f"{doc_context}\n\n{message}"
994
+ yield from chat_response_stream_multiturn(
995
+ message, history, temperature, max_tokens, frequency_penalty,
996
+ presence_penalty, system_prompt, current_time
997
+ )
998
+
999
+
1000
+ def debug_generate_free_form_stream(message):
1001
+ output = " This is a debugging message...."
1002
+ for i in range(len(output)):
1003
+ time.sleep(0.05)
1004
+ yield message + output[:i]
1005
+
1006
+
1007
+ def generate_free_form_stream(
1008
+ message: str,
1009
+ temperature: float,
1010
+ max_tokens: int,
1011
+ frequency_penalty: float,
1012
+ presence_penalty: float,
1013
+ stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
1014
+ current_time: Optional[float] = None,
1015
+ ) -> str:
1016
+ global LOG_FILE, LOG_PATH
1017
+ if DEBUG:
1018
+ yield from debug_generate_free_form_stream(message)
1019
+ return
1020
+ from vllm import LLM, SamplingParams
1021
+ """Build multi turn
1022
+ """
1023
+ global llm, RES_PRINTED
1024
+ assert llm is not None
1025
+ tokenizer = llm.get_tokenizer()
1026
+ # force removing all
1027
+ vllm_abort(llm)
1028
+
1029
+ temperature = float(temperature)
1030
+ frequency_penalty = float(frequency_penalty)
1031
+ max_tokens = int(max_tokens)
1032
+
1033
+ stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
1034
+ stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>']))
1035
+
1036
+ sampling_params = SamplingParams(
1037
+ temperature=temperature,
1038
+ max_tokens=max_tokens,
1039
+ frequency_penalty=frequency_penalty,
1040
+ presence_penalty=presence_penalty,
1041
+ stop=stop_strings,
1042
+ # ignore_eos=True,
1043
+ )
1044
+
1045
+ # full_prompt = message
1046
+ if len(message) == 0:
1047
+ raise gr.Error("The message cannot be empty!")
1048
+
1049
+ message_safety = safety_check(message)
1050
+ if message_safety is not None:
1051
+ raise gr.Error(message_safety)
1052
+
1053
+ if len(tokenizer.encode(message)) >= 4050:
1054
+ raise gr.Error(f"Prompt is too long!")
1055
+
1056
+ cur_out = None
1057
+ for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
1058
+ if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
1059
+ # optionally check safety, and respond
1060
+ if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
1061
+ message_safety = safety_check(cur_out, history=None)
1062
+ if message_safety is not None:
1063
+ raise gr.Error(message_safety)
1064
+ yield message + cur_out
1065
+ assert len(gen) == 1, f'{gen}'
1066
+ item = next(iter(gen.values()))
1067
+ cur_out = item.outputs[0].text
1068
+ #cur_out = "Our system is under maintenance, will be back soon!"
1069
+ if j >= max_tokens - 2:
1070
+ gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
1071
+
1072
+ if cur_out is not None:
1073
+ yield message + cur_out
1074
+
1075
+ message_safety = safety_check(message + cur_out, history=None)
1076
+ if message_safety is not None:
1077
+ raise gr.Error(message_safety)
1078
+
1079
+
1080
+
1081
+
1082
+ def maybe_log_conv_file(current_time, history, message, response, **kwargs):
1083
+ global LOG_FILE
1084
+ if LOG_FILE is not None:
1085
+ my_history = history + [[message, response]]
1086
+ obj = {
1087
+ 'key': str(current_time),
1088
+ 'history': my_history
1089
+ }
1090
+ for k, v in kwargs.items():
1091
+ obj[k] = v
1092
+ log_ = json.dumps(obj, ensure_ascii=False)
1093
+ LOG_FILE.write(log_ + "\n")
1094
+ LOG_FILE.flush()
1095
+ print(f'Wrote {obj["key"]} to {LOG_PATH}')
1096
+
1097
+
1098
+ def format_conversation(history):
1099
+ _str = '\n'.join([
1100
+ (
1101
+ f'<<<User>>> {h[0]}\n'
1102
+ f'<<<Asst>>> {h[1]}'
1103
+ )
1104
+ for h in history
1105
+ ])
1106
+ return _str
1107
+
1108
+
1109
+ def aggregate_convos():
1110
+ from datetime import datetime
1111
+ global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1112
+ assert os.path.exists(LOG_PATH), f'{LOG_PATH} not found'
1113
+ convos = None
1114
+ irregular_count = 1
1115
+ with open(LOG_PATH, 'r', encoding='utf-8') as f:
1116
+ convos = {}
1117
+ for i, l in enumerate(f):
1118
+ if l:
1119
+ item = json.loads(l)
1120
+ key = item['key']
1121
+ try:
1122
+ key = float(key)
1123
+ except Exception as e:
1124
+ key = -1
1125
+ if key > 0.0:
1126
+ item_key = datetime.fromtimestamp(key).strftime("%Y-%m-%d %H:%M:%S")
1127
+ else:
1128
+ key = item_key = f'e{irregular_count}'
1129
+ irregular_count += 1
1130
+ item['key'] = item_key
1131
+ convos[key] = item
1132
+ return convos
1133
+
1134
+ def maybe_upload_to_dataset():
1135
+ from datetime import datetime
1136
+ global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1137
+ if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH != "":
1138
+ convos = aggregate_convos()
1139
+ AGG_LOG_PATH = LOG_PATH + ".agg.json"
1140
+ with open(AGG_LOG_PATH, 'w', encoding='utf-8') as fo:
1141
+ json.dump(convos, fo, indent=4, ensure_ascii=False)
1142
+ print(f'Saved aggregated json to {AGG_LOG_PATH}')
1143
+ try:
1144
+ from huggingface_hub import upload_file
1145
+ print(f'upload {AGG_LOG_PATH} to {DATA_SET_REPO_PATH}')
1146
+ upload_file(
1147
+ path_or_fileobj=AGG_LOG_PATH,
1148
+ path_in_repo=os.path.basename(AGG_LOG_PATH),
1149
+ repo_id=DATA_SET_REPO_PATH,
1150
+ token=HF_TOKEN,
1151
+ repo_type="dataset",
1152
+ create_pr=True
1153
+ )
1154
+ except Exception as e:
1155
+ print(f'Failed to save to repo: {DATA_SET_REPO_PATH}|{str(e)}')
1156
+
1157
+
1158
+ def print_log_file():
1159
+ global LOG_FILE, LOG_PATH
1160
+ if SAVE_LOGS and os.path.exists(LOG_PATH):
1161
+ with open(LOG_PATH, 'r', encoding='utf-8') as f:
1162
+ convos = aggregate_convos()
1163
+ print(f'Printing log from {LOG_PATH}')
1164
+ items = list(convos.items())
1165
+ for k, v in items[-10:]:
1166
+ history = v.pop('history')
1167
+ print(f'######--{v}--#####')
1168
+ _str = format_conversation(history)
1169
+ print(_str)
1170
+ maybe_upload_to_dataset()
1171
+
1172
+
1173
+ def debug_chat_response_echo(
1174
+ message: str,
1175
+ history: List[Tuple[str, str]],
1176
+ temperature: float = 0.0,
1177
+ max_tokens: int = 4096,
1178
+ frequency_penalty: float = 0.4,
1179
+ presence_penalty: float = 0.0,
1180
+ current_time: Optional[float] = None,
1181
+ system_prompt: str = SYSTEM_PROMPT_1,
1182
+ ) -> str:
1183
+ global LOG_FILE
1184
+ import time
1185
+ time.sleep(0.5)
1186
+
1187
+ if message.strip() == GET_LOG_CMD:
1188
+ print_log_file()
1189
+ yield "Finish printed log."
1190
+ return
1191
+
1192
+ for i in range(len(message)):
1193
+ yield f"repeat: {current_time} {message[:i + 1]}"
1194
+
1195
+ cur_out = f"repeat: {current_time} {message}"
1196
+ maybe_log_conv_file(current_time, history, message, cur_out, temperature=temperature, frequency_penalty=frequency_penalty)
1197
+
1198
+
1199
+ def check_model_path(model_path) -> str:
1200
+ assert os.path.exists(model_path), f'{model_path} not found'
1201
+ ckpt_info = "None"
1202
+ if os.path.isdir(model_path):
1203
+ if os.path.exists(f'{model_path}/info.txt'):
1204
+ with open(f'{model_path}/info.txt', 'r') as f:
1205
+ ckpt_info = f.read()
1206
+ print(f'Checkpoint info:\n{ckpt_info}\n-----')
1207
+ else:
1208
+ print(f'info.txt not found in {model_path}')
1209
+ print(f'model path dir: {list(os.listdir(model_path))}')
1210
+
1211
+ return ckpt_info
1212
+
1213
+
1214
+ def maybe_delete_folder():
1215
+ if IS_DELETE_FOLDER and DOWNLOAD_SNAPSHOT:
1216
+ import shutil
1217
+ print(f'DELETE ALL FILES IN {DELETE_FOLDER}')
1218
+ for filename in os.listdir(DELETE_FOLDER):
1219
+ file_path = os.path.join(DELETE_FOLDER, filename)
1220
+ try:
1221
+ if os.path.isfile(file_path) or os.path.islink(file_path):
1222
+ os.unlink(file_path)
1223
+ elif os.path.isdir(file_path):
1224
+ shutil.rmtree(file_path)
1225
+ except Exception as e:
1226
+ print('Failed to delete %s. Reason: %s' % (file_path, e))
1227
+
1228
+
1229
+ AGREE_POP_SCRIPTS = """
1230
+ async () => {
1231
+ alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
1232
+ }
1233
+ """
1234
+
1235
+ def debug_file_function(
1236
+ files: Union[str, List[str]],
1237
+ prompt_mode: str,
1238
+ temperature: float,
1239
+ max_tokens: int,
1240
+ frequency_penalty: float,
1241
+ presence_penalty: float,
1242
+ stop_strings: str = "[STOP],<s>,</s>",
1243
+ current_time: Optional[float] = None,
1244
+ ):
1245
+ """This is only for debug purpose"""
1246
+ files = files if isinstance(files, list) else [files]
1247
+ print(files)
1248
+ filenames = [f.name for f in files]
1249
+ all_items = []
1250
+ for fname in filenames:
1251
+ print(f'Reading {fname}')
1252
+ with open(fname, 'r', encoding='utf-8') as f:
1253
+ items = json.load(f)
1254
+ assert isinstance(items, list), f'invalid items from {fname} not list'
1255
+ all_items.extend(items)
1256
+ print(all_items)
1257
+ print(f'{prompt_mode} / {temperature} / {max_tokens}, {frequency_penalty}, {presence_penalty}')
1258
+ save_path = "./test.json"
1259
+ with open(save_path, 'w', encoding='utf-8') as f:
1260
+ json.dump(all_items, f, indent=4, ensure_ascii=False)
1261
+
1262
+ for x in all_items:
1263
+ x['response'] = "Return response"
1264
+
1265
+ print_items = all_items[:1]
1266
+ # print_json = json.dumps(print_items, indent=4, ensure_ascii=False)
1267
+ return save_path, print_items
1268
+
1269
+
1270
+ def validate_file_item(filename, index, item: Dict[str, str]):
1271
+ """
1272
+ check safety for items in files
1273
+ """
1274
+ message = item['prompt'].strip()
1275
+
1276
+ if len(message) == 0:
1277
+ raise gr.Error(f'Prompt {index} empty')
1278
+
1279
+ message_safety = safety_check(message, history=None)
1280
+ if message_safety is not None:
1281
+ raise gr.Error(f'Prompt {index} invalid: {message_safety}')
1282
+
1283
+ tokenizer = llm.get_tokenizer() if llm is not None else None
1284
+ if tokenizer is None or len(tokenizer.encode(message)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
1285
+ raise gr.Error(f"Prompt {index} too long, should be less than {BATCH_INFER_MAX_PROMPT_TOKENS} tokens")
1286
+
1287
+
1288
+ def read_validate_json_files(files: Union[str, List[str]]):
1289
+ files = files if isinstance(files, list) else [files]
1290
+ filenames = [f.name for f in files]
1291
+ all_items = []
1292
+ for fname in filenames:
1293
+ # check each files
1294
+ print(f'Reading {fname}')
1295
+ with open(fname, 'r', encoding='utf-8') as f:
1296
+ items = json.load(f)
1297
+ assert isinstance(items, list), f'Data {fname} not list'
1298
+ assert all(isinstance(x, dict) for x in items), f'item in input file not list'
1299
+ assert all("prompt" in x for x in items), f'key prompt should be in dict item of input file'
1300
+
1301
+ for i, x in enumerate(items):
1302
+ validate_file_item(fname, i, x)
1303
+
1304
+ all_items.extend(items)
1305
+
1306
+ if len(all_items) > BATCH_INFER_MAX_ITEMS:
1307
+ raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
1308
+
1309
+ return all_items, filenames
1310
+
1311
+
1312
+ def remove_gradio_cache(exclude_names=None):
1313
+ """remove gradio cache to avoid flooding"""
1314
+ import shutil
1315
+ for root, dirs, files in os.walk('/tmp/gradio/'):
1316
+ for f in files:
1317
+ # if not any(f in ef for ef in except_files):
1318
+ if exclude_names is None or not any(ef in f for ef in exclude_names):
1319
+ print(f'Remove: {f}')
1320
+ os.unlink(os.path.join(root, f))
1321
+ # for d in dirs:
1322
+ # # if not any(d in ef for ef in except_files):
1323
+ # if exclude_names is None or not any(ef in d for ef in exclude_names):
1324
+ # print(f'Remove d: {d}')
1325
+ # shutil.rmtree(os.path.join(root, d))
1326
+
1327
+
1328
+ def maybe_upload_batch_set(pred_json_path):
1329
+ global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1330
+
1331
+ if SAVE_LOGS and DATA_SET_REPO_PATH != "":
1332
+ try:
1333
+ from huggingface_hub import upload_file
1334
+ path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f'.{time.time()}.json')
1335
+ print(f'upload {pred_json_path} to {DATA_SET_REPO_PATH}//{path_in_repo}')
1336
+ upload_file(
1337
+ path_or_fileobj=pred_json_path,
1338
+ path_in_repo=path_in_repo,
1339
+ repo_id=DATA_SET_REPO_PATH,
1340
+ token=HF_TOKEN,
1341
+ repo_type="dataset",
1342
+ create_pr=True
1343
+ )
1344
+ except Exception as e:
1345
+ print(f'Failed to save to repo: {DATA_SET_REPO_PATH}|{str(e)}')
1346
+
1347
+
1348
+ def free_form_prompt(prompt, history=None, system_prompt=None):
1349
+ return prompt
1350
+
1351
+ def batch_inference(
1352
+ files: Union[str, List[str]],
1353
+ prompt_mode: str,
1354
+ temperature: float,
1355
+ max_tokens: int,
1356
+ frequency_penalty: float,
1357
+ presence_penalty: float,
1358
+ stop_strings: str = "[STOP],<s>,</s>,<|im_start|>",
1359
+ current_time: Optional[float] = None,
1360
+ system_prompt: Optional[str] = SYSTEM_PROMPT_1
1361
+ ):
1362
+ """
1363
+ Handle file upload batch inference
1364
+
1365
+ """
1366
+ global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
1367
+ if DEBUG:
1368
+ return debug_file_function(
1369
+ files, prompt_mode, temperature, max_tokens,
1370
+ presence_penalty, stop_strings, current_time)
1371
+
1372
+ from vllm import LLM, SamplingParams
1373
+ assert llm is not None
1374
+ # assert system_prompt.strip() != '', f'system prompt is empty'
1375
+
1376
+ stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
1377
+ tokenizer = llm.get_tokenizer()
1378
+ # force removing all
1379
+ # NOTE: need to make sure all cached items are removed!!!!!!!!!
1380
+ vllm_abort(llm)
1381
+
1382
+ temperature = float(temperature)
1383
+ frequency_penalty = float(frequency_penalty)
1384
+ max_tokens = int(max_tokens)
1385
+
1386
+ all_items, filenames = read_validate_json_files(files)
1387
+
1388
+ # remove all items in /tmp/gradio/
1389
+ remove_gradio_cache(exclude_names=['upload_chat.json', 'upload_few_shot.json'])
1390
+
1391
+ if prompt_mode == 'chat':
1392
+ prompt_format_fn = chatml_format
1393
+ elif prompt_mode == 'few-shot':
1394
+ from functools import partial
1395
+ # prompt_format_fn = partial(
1396
+ # chatml_format, include_end_instruct=False
1397
+ # )
1398
+ prompt_format_fn = free_form_prompt
1399
+ else:
1400
+ raise gr.Error(f'Wrong mode {prompt_mode}')
1401
+
1402
+ full_prompts = [
1403
+ prompt_format_fn(
1404
+ x['prompt'], [], sys_prompt=system_prompt
1405
+ )
1406
+ for i, x in enumerate(all_items)
1407
+ ]
1408
+ print(f'{full_prompts[0]}\n')
1409
+
1410
+ if any(len(tokenizer.encode(x)) >= 4090 for x in full_prompts):
1411
+ raise gr.Error(f"Some prompt is too long!")
1412
+
1413
+ stop_seq = list(set(['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'] + stop_strings))
1414
+ sampling_params = SamplingParams(
1415
+ temperature=temperature,
1416
+ max_tokens=max_tokens,
1417
+ frequency_penalty=frequency_penalty,
1418
+ presence_penalty=presence_penalty,
1419
+ stop=stop_seq
1420
+ )
1421
+
1422
+ generated = llm.generate(full_prompts, sampling_params, use_tqdm=False)
1423
+ responses = [g.outputs[0].text for g in generated]
1424
+ #responses = ["Our system is under maintenance, will be back soon!" for g in generated]
1425
+ if len(responses) != len(all_items):
1426
+ raise gr.Error(f'inconsistent lengths {len(responses)} != {len(all_items)}')
1427
+
1428
+ for res, item in zip(responses, all_items):
1429
+ item['response'] = res
1430
+
1431
+ save_path = BATCH_INFER_SAVE_TMP_FILE
1432
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
1433
+ with open(save_path, 'w', encoding='utf-8') as f:
1434
+ json.dump(all_items, f, indent=4, ensure_ascii=False)
1435
+
1436
+ # You need to upload save_path as a new timestamp file.
1437
+ maybe_upload_batch_set(save_path)
1438
+
1439
+ print_items = all_items[:2]
1440
+ # print_json = json.dumps(print_items, indent=4, ensure_ascii=False)
1441
+ return save_path, print_items
1442
+
1443
+
1444
+ # BATCH_INFER_MAX_ITEMS
1445
+ FILE_UPLOAD_DESCRIPTION = f"""Upload JSON file as list of dict with < {BATCH_INFER_MAX_ITEMS} items, \
1446
+ each item has `prompt` key. We put guardrails to enhance safety, so do not input any harmful content or personal information! Re-upload the file after every submit. See the examples below.
1447
+ ```
1448
+ [ {{"id": 0, "prompt": "Hello world"}} , {{"id": 1, "prompt": "Hi there?"}}]
1449
+ ```
1450
+ """
1451
+
1452
+ CHAT_EXAMPLES = [
1453
+ ["Hãy giải thích thuyết tương đối rộng."],
1454
+ ["Tolong bantu saya menulis email ke lembaga pemerintah untuk mencari dukungan finansial untuk penelitian AI."],
1455
+ ["แนะนำ 10 จุดหมายปลายทางในกรุงเทพฯ"],
1456
+ ]
1457
+
1458
+
1459
+ # performance items
1460
+
1461
+ def create_free_form_generation_demo():
1462
+ global short_model_path
1463
+ max_tokens = MAX_TOKENS
1464
+ temperature = TEMPERATURE
1465
+ frequence_penalty = FREQUENCE_PENALTY
1466
+ presence_penalty = PRESENCE_PENALTY
1467
+
1468
+ introduction = """
1469
+ ### Free-form | Put any context string (like few-shot prompts)
1470
+ """
1471
+
1472
+ with gr.Blocks() as demo_free_form:
1473
+ gr.Markdown(introduction)
1474
+
1475
+ with gr.Row():
1476
+ txt = gr.Textbox(
1477
+ scale=4,
1478
+ lines=16,
1479
+ show_label=False,
1480
+ placeholder="Enter any free form text and submit",
1481
+ container=False,
1482
+ )
1483
+ with gr.Row():
1484
+ free_submit_button = gr.Button('Submit')
1485
+ with gr.Row():
1486
+ temp = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
1487
+ length = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
1488
+ freq_pen = gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens')
1489
+ pres_pen = gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens')
1490
+ stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
1491
+
1492
+ free_submit_button.click(
1493
+ generate_free_form_stream,
1494
+ [txt, temp, length, freq_pen, pres_pen, stop_strings],
1495
+ txt
1496
+ )
1497
+ return demo_free_form
1498
+
1499
+
1500
+
1501
+ def create_file_upload_demo():
1502
+ temperature = TEMPERATURE
1503
+ frequence_penalty = FREQUENCE_PENALTY
1504
+ presence_penalty = PRESENCE_PENALTY
1505
+ max_tokens = MAX_TOKENS
1506
+ demo_file_upload = gr.Interface(
1507
+ batch_inference,
1508
+ inputs=[
1509
+ gr.File(file_count='single', file_types=['json']),
1510
+ gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
1511
+ gr.Number(value=temperature, label='Temperature', info="Higher -> more random"),
1512
+ gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
1513
+ gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
1514
+ gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
1515
+ gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
1516
+ gr.Number(value=0, label='current_time', visible=False),
1517
+ ],
1518
+ outputs=[
1519
+ # "file",
1520
+ gr.File(label="Generated file"),
1521
+ # "json"
1522
+ gr.JSON(label='Example outputs (display 2 samples)')
1523
+ ],
1524
+ description=FILE_UPLOAD_DESCRIPTION,
1525
+ allow_flagging=False,
1526
+ examples=[
1527
+ ["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "<s>,</s>,<|im_start|>"],
1528
+ ["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "<s>,</s>,<|im_start|>,\\n"]
1529
+ ],
1530
+ cache_examples=False,
1531
+ )
1532
+ return demo_file_upload
1533
+
1534
+
1535
+ def create_chat_demo(title=None, description=None):
1536
+ sys_prompt = SYSTEM_PROMPT_1
1537
+ max_tokens = MAX_TOKENS
1538
+ temperature = TEMPERATURE
1539
+ frequence_penalty = FREQUENCE_PENALTY
1540
+ presence_penalty = PRESENCE_PENALTY
1541
+
1542
+ demo_chat = gr.ChatInterface(
1543
+ chat_response_stream_multiturn,
1544
+ chatbot=ChatBot(
1545
+ label=MODEL_NAME,
1546
+ bubble_full_width=False,
1547
+ latex_delimiters=[
1548
+ { "left": "$", "right": "$", "display": False},
1549
+ { "left": "$$", "right": "$$", "display": True},
1550
+ ],
1551
+ show_copy_button=True,
1552
+ ),
1553
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
1554
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1555
+ # ! consider preventing the stop button
1556
+ # stop_btn=None,
1557
+ title=title,
1558
+ description=description,
1559
+ additional_inputs=[
1560
+ gr.Number(value=temperature, label='Temperature (higher -> more random)'),
1561
+ gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
1562
+ gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
1563
+ gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
1564
+ gr.Textbox(value=sys_prompt, label='System prompt', lines=4, interactive=False),
1565
+ gr.Number(value=0, label='current_time', visible=False),
1566
+ # ! Remove the system prompt textbox to avoid jailbreaking
1567
+ ],
1568
+ examples=CHAT_EXAMPLES,
1569
+ cache_examples=False
1570
+ )
1571
+ return demo_chat
1572
+
1573
+
1574
+ def upload_file(file):
1575
+ # file_paths = [file.name for file in files]
1576
+ # return file_paths
1577
+ return file.name
1578
+
1579
+
1580
+ RAG_DESCRIPTION = """
1581
+ * Upload a doc below to answer question about it (RAG).
1582
+ * Every question must be explicit and self-contained! Because each prompt will invoke a new RAG retrieval without considering previous conversations.
1583
+ (E.g: Dont prompt "Answer my previous question in details.")
1584
+ """
1585
+
1586
+ def create_chat_demo_rag(title=None, description=None):
1587
+ sys_prompt = SYSTEM_PROMPT_1
1588
+ max_tokens = MAX_TOKENS
1589
+ temperature = TEMPERATURE
1590
+ frequence_penalty = FREQUENCE_PENALTY
1591
+ presence_penalty = PRESENCE_PENALTY
1592
+ description = description or RAG_DESCRIPTION
1593
+
1594
+ # with gr.Blocks(title="RAG") as rag_demo:
1595
+ additional_inputs = [
1596
+ gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt', 'json']),
1597
+ # gr.Textbox(value=None, label='Document path', lines=1, interactive=False),
1598
+ gr.Number(value=temperature, label='Temperature (higher -> more random)'),
1599
+ gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
1600
+ # gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
1601
+ # gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
1602
+ gr.Textbox(value=sys_prompt, label='System prompt', lines=1, interactive=False),
1603
+ gr.Number(value=0, label='current_time', visible=False),
1604
+ ]
1605
+
1606
+ demo_rag_chat = gr.ChatInterface(
1607
+ chat_response_stream_rag_multiturn,
1608
+ chatbot=gr.Chatbot(
1609
+ label=MODEL_NAME + "-RAG",
1610
+ bubble_full_width=False,
1611
+ latex_delimiters=[
1612
+ { "left": "$", "right": "$", "display": False},
1613
+ { "left": "$$", "right": "$$", "display": True},
1614
+ ],
1615
+ show_copy_button=True,
1616
+ ),
1617
+ textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
1618
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1619
+ # ! consider preventing the stop button
1620
+ # stop_btn=None,
1621
+ title=title,
1622
+ description=description,
1623
+ additional_inputs=additional_inputs,
1624
+ additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
1625
+ # examples=CHAT_EXAMPLES,
1626
+ cache_examples=False
1627
+ )
1628
+ # with demo_rag_chat:
1629
+ # upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt', 'json'], file_count="single")
1630
+ # upload_button.upload(upload_file, upload_button, additional_inputs[0])
1631
+
1632
+ # return demo_chat
1633
+ return demo_rag_chat
1634
+
1635
+
1636
+
1637
+ def launch_demo():
1638
+ global demo, llm, DEBUG, LOG_FILE
1639
+ model_desc = MODEL_DESC
1640
+ model_path = MODEL_PATH
1641
+ model_title = MODEL_TITLE
1642
+ hf_model_name = HF_MODEL_NAME
1643
+ tensor_parallel = TENSOR_PARALLEL
1644
+ assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
1645
+ dtype = DTYPE
1646
+ sys_prompt = SYSTEM_PROMPT_1
1647
+ max_tokens = MAX_TOKENS
1648
+ temperature = TEMPERATURE
1649
+ frequence_penalty = FREQUENCE_PENALTY
1650
+ presence_penalty = PRESENCE_PENALTY
1651
+ ckpt_info = "None"
1652
+
1653
+ print(
1654
+ f'Launch config: '
1655
+ f'\n| model_title=`{model_title}` '
1656
+ f'\n| max_tokens={max_tokens} '
1657
+ f'\n| dtype={dtype} '
1658
+ f'\n| tensor_parallel={tensor_parallel} '
1659
+ f'\n| IS_DELETE_FOLDER={IS_DELETE_FOLDER} '
1660
+ f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
1661
+ f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} '
1662
+ f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} '
1663
+ f'\n| LANG_BLOCK_HISTORY={LANG_BLOCK_HISTORY} '
1664
+ f'\n| frequence_penalty={frequence_penalty} '
1665
+ f'\n| presence_penalty={presence_penalty} '
1666
+ f'\n| temperature={temperature} '
1667
+ # f'\n| hf_model_name={hf_model_name} '
1668
+ f'\n| model_path={model_path} '
1669
+ f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
1670
+ f'\n| gpu_memory_utilization={gpu_memory_utilization} '
1671
+ f'\n| LOG_PATH={LOG_PATH} | SAVE_LOGS={SAVE_LOGS} '
1672
+ f'\n| Desc={model_desc}'
1673
+ )
1674
+
1675
+ if DEBUG:
1676
+ model_desc += "\n<br>!!!!! This is in debug mode, responses will copy original"
1677
+ # response_fn = debug_chat_response_echo
1678
+ response_fn = chat_response_stream_multiturn
1679
+ print(f'Creating in DEBUG MODE')
1680
+ if SAVE_LOGS:
1681
+ LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
1682
+ else:
1683
+ # ! load the model
1684
+ maybe_delete_folder()
1685
+
1686
+ if DOWNLOAD_SNAPSHOT:
1687
+ print(f'Downloading from HF_MODEL_NAME={hf_model_name} -> {model_path}')
1688
+ if HF_TOKEN is not None:
1689
+ print(f'Load with HF_TOKEN: {HF_TOKEN}')
1690
+ snapshot_download(hf_model_name, local_dir=model_path, use_auth_token=True, token=HF_TOKEN)
1691
+ else:
1692
+ snapshot_download(hf_model_name, local_dir=model_path)
1693
+
1694
+ import vllm
1695
+ from vllm import LLM
1696
+
1697
+ print(F'VLLM: {vllm.__version__}')
1698
+ ckpt_info = check_model_path(model_path)
1699
+
1700
+ print(f'Load path: {model_path} | {ckpt_info}')
1701
+
1702
+ if QUANTIZATION == 'awq':
1703
+ print(F'Load model in int4 quantization')
1704
+ llm = LLM(model=model_path, dtype="float16", tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization, quantization="awq", max_model_len=8192)
1705
+ else:
1706
+ llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization, max_model_len=8192)
1707
+
1708
+ try:
1709
+ print(llm.llm_engine.workers[0].model)
1710
+ except Exception as e:
1711
+ print(f'Cannot print model worker: {e}')
1712
+
1713
+ try:
1714
+ llm.llm_engine.scheduler_config.max_model_len = 8192
1715
+ llm.llm_engine.scheduler_config.max_num_batched_tokens = 8192
1716
+ # llm.llm_engine.tokenizer.add_special_tokens = False
1717
+ except Exception as e:
1718
+ print(f'Cannot set parameters: {e}')
1719
+
1720
+ print(f'Use system prompt:\n{sys_prompt}')
1721
+
1722
+ response_fn = chat_response_stream_multiturn
1723
+ print(F'respond: {response_fn}')
1724
+
1725
+ if SAVE_LOGS:
1726
+ LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
1727
+
1728
+ if ENABLE_BATCH_INFER:
1729
+
1730
+ # demo_file_upload = create_file_upload_demo()
1731
+
1732
+ demo_free_form = create_free_form_generation_demo()
1733
+
1734
+ demo_chat = create_chat_demo()
1735
+ demo_chat_rag = create_chat_demo_rag(description=RAG_DESCRIPTION)
1736
+ descriptions = model_desc
1737
+ if DISPLAY_MODEL_PATH:
1738
+ descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
1739
+
1740
+ demo = CustomTabbedInterface(
1741
+ interface_list=[
1742
+ demo_chat,
1743
+ demo_chat_rag,
1744
+ demo_free_form,
1745
+ # demo_file_upload,
1746
+ ],
1747
+ tab_names=[
1748
+ "Chat Interface",
1749
+ "RAG Chat Interface",
1750
+ "Text completion",
1751
+ # "Batch Inference",
1752
+ ],
1753
+ title=f"{model_title}",
1754
+ description=descriptions,
1755
+ )
1756
+ else:
1757
+ descriptions = model_desc
1758
+ if DISPLAY_MODEL_PATH:
1759
+ descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
1760
+
1761
+ demo = create_chat_demo(title=f"{model_title}", description=descriptions)
1762
+ demo.title = MODEL_NAME
1763
+
1764
+ with demo:
1765
+ if DATA_SET_REPO_PATH != "":
1766
+ try:
1767
+ from performance_plot import attach_plot_to_demo
1768
+ attach_plot_to_demo(demo)
1769
+ except Exception as e:
1770
+ print(f'Fail to load DEMO plot: {str(e)}')
1771
+
1772
+ gr.Markdown(cite_markdown)
1773
+ if DISPLAY_MODEL_PATH:
1774
+ gr.Markdown(path_markdown.format(model_path=model_path))
1775
+
1776
+ if ENABLE_AGREE_POPUP:
1777
+ demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
1778
+
1779
+ # login_btn = gr.LoginButton()
1780
+
1781
+ demo.queue(api_open=False)
1782
+ return demo
1783
+
1784
+
1785
+ if __name__ == "__main__":
1786
+ demo = launch_demo()
1787
+ demo.launch(show_api=False, allowed_paths=["seal_logo.png"])
seammm_2.png ADDED

Git LFS Details

  • SHA256: 9c3087b9a9bcc2835e80b540109a079825bcb1c74fa9c40b64efec488d6bce59
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
transformers_requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ transformers
vllm_requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ vllm