koonmania commited on
Commit
4df8249
·
1 Parent(s): 7854560

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. .gitignore +5 -0
  3. LICENSE +201 -0
  4. README.md +185 -7
  5. __init__.py +0 -0
  6. __pycache__/global_vars.cpython-39.pyc +0 -0
  7. __pycache__/utils.cpython-39.pyc +0 -0
  8. app.py +1109 -0
  9. assets/guimode_preview.gif +3 -0
  10. assets/preview.gif +3 -0
  11. assets/preview.png +3 -0
  12. channels.txt +10 -0
  13. chats/__init__.py +0 -0
  14. chats/__pycache__/__init__.cpython-39.pyc +0 -0
  15. chats/__pycache__/alpaca.cpython-39.pyc +0 -0
  16. chats/__pycache__/alpaca_gpt4.cpython-39.pyc +0 -0
  17. chats/__pycache__/alpacoom.cpython-39.pyc +0 -0
  18. chats/__pycache__/baize.cpython-39.pyc +0 -0
  19. chats/__pycache__/central.cpython-39.pyc +0 -0
  20. chats/__pycache__/custom.cpython-39.pyc +0 -0
  21. chats/__pycache__/falcon.cpython-39.pyc +0 -0
  22. chats/__pycache__/flan_alpaca.cpython-39.pyc +0 -0
  23. chats/__pycache__/freewilly.cpython-39.pyc +0 -0
  24. chats/__pycache__/guanaco.cpython-39.pyc +0 -0
  25. chats/__pycache__/koalpaca.cpython-39.pyc +0 -0
  26. chats/__pycache__/llama2.cpython-39.pyc +0 -0
  27. chats/__pycache__/mpt.cpython-39.pyc +0 -0
  28. chats/__pycache__/os_stablelm.cpython-39.pyc +0 -0
  29. chats/__pycache__/post.cpython-39.pyc +0 -0
  30. chats/__pycache__/pre.cpython-39.pyc +0 -0
  31. chats/__pycache__/redpajama.cpython-39.pyc +0 -0
  32. chats/__pycache__/stable_vicuna.cpython-39.pyc +0 -0
  33. chats/__pycache__/stablelm.cpython-39.pyc +0 -0
  34. chats/__pycache__/starchat.cpython-39.pyc +0 -0
  35. chats/__pycache__/utils.cpython-39.pyc +0 -0
  36. chats/__pycache__/vicuna.cpython-39.pyc +0 -0
  37. chats/__pycache__/wizard_coder.cpython-39.pyc +0 -0
  38. chats/__pycache__/wizard_falcon.cpython-39.pyc +0 -0
  39. chats/__pycache__/xgen.cpython-39.pyc +0 -0
  40. chats/alpaca.py +51 -0
  41. chats/alpaca_gpt4.py +51 -0
  42. chats/alpacoom.py +51 -0
  43. chats/baize.py +68 -0
  44. chats/central.py +380 -0
  45. chats/custom.py +65 -0
  46. chats/falcon.py +62 -0
  47. chats/flan_alpaca.py +51 -0
  48. chats/freewilly.py +74 -0
  49. chats/guanaco.py +63 -0
  50. chats/koalpaca.py +51 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/guimode_preview.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/preview.gif filter=lfs diff=lfs merge=lfs -text
38
+ assets/preview.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .ipynb_checkpoints
2
+ __pycache__
3
+ nohup.out
4
+ test.py
5
+ .dstack
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,190 @@
1
  ---
2
- title: LLM As Chatbot
3
- emoji: 💻
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.38.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
  ---
2
+ title: LLM-As-Chatbot
3
+ app_file: app.py
 
 
4
  sdk: gradio
5
  sdk_version: 3.38.0
 
 
6
  ---
7
+ ## UPDATE
8
+ - **Internet search support**: you can enable **internet search** capability in Gradio application and Discord bot. For gradio, there is a `internet mode` option in the control panel. For discord, you need to specify `--internet` option in your prompt. For both cases, you need a Serper API Key which you can get one from [serper.dev](https://serper.dev/). By signing up, you will get free 2,500 free google searches which is pretty much sufficient for a long-term test.
9
+ - **Discord Bot support**: you can serve any model from the model zoo as Discord Bot. Find how to do this in the instruction section below.
10
+
11
+ # 💬🚀 LLM as a Chatbot Service
12
+
13
+ The purpose of this repository is to let people to use lots of open sourced instruction-following fine-tuned LLM models as a Chatbot service. Because different models behave differently, and different models require differently formmated prompts, I made a very simple library [`Ping Pong`](https://github.com/deep-diver/PingPong) for model agnostic conversation and context managements.
14
+
15
+ Also, I made [`GradioChat`](https://github.com/deep-diver/gradio-chat) UI that has a similar shape to [HuggingChat](https://huggingface.co/chat/) but entirely built in Gradio. Those two projects are fully integrated to power this project.
16
+
17
+ ## Easiest way to try out ( ✅ Gradio, 🚧 Discord Bot )
18
+
19
+ ### Jarvislabs.ai
20
+
21
+ This project has become the one of the default framework at [jarvislabs.ai](https://jarvislabs.ai/). Jarvislabs.ai is one of the cloud GPU VM provider with the cheapest GPU prices. Furthermore, all the weights of the supported popular open source LLMs are pre-downloaded. You don't need to waste of your money and time to wait until download hundreds of GBs to try out a collection of LLMs. In less than 10 minutes, you can try out any model.
22
+ - for further instruction how to run Gradio application, please follow the [official documentation](https://jarvislabs.ai/docs/llmchat) on the `llmchat` framework.
23
+
24
+ ### dstack
25
+
26
+ [`dstack`](https://dstack.ai) is an open-source tool that allows to run LLM-based apps in a a cloud of your choice via single command. `dstack` supports AWS, GCP, Azure, Lambda Cloud, etc.
27
+
28
+ Use the `gradio.dstack.yml` and `discord.dstack.yml` configurations to run the Gradio app and Discord bot via `dstack`.
29
+ - for more details on how to run this repo with `dstack`, read the [official documentation](https://dstack.ai/examples/llmchat) by `dstack`.
30
+
31
+ ## Instructions
32
+
33
+ ### Standalone Gradio app
34
+
35
+ ![](https://i.ibb.co/gW7yKj9/2023-05-26-3-31-06.png)
36
+
37
+ 0. Prerequisites
38
+
39
+ Note that the code only works `Python >= 3.9` and `gradio >= 3.32.0`
40
+
41
+ ```console
42
+ $ conda create -n llm-serve python=3.9
43
+ $ conda activate llm-serve
44
+ ```
45
+
46
+ 1. Install dependencies.
47
+ ```console
48
+ $ cd LLM-As-Chatbot
49
+ $ pip install -r requirements.txt
50
+ ```
51
+
52
+ 2. Run Gradio application
53
+
54
+ There is no required parameter to run the Gradio application. However, there are some small details worth being noted. When `--local-files-only` is set, application won't try to look up the Hugging Face Hub(remote). Instead, it will only use the files already downloaded and cached.
55
+
56
+ Hugging Face libraries stores downloaded contents under `~/.cache` by default, and this application assumes so. However, if you downloaded weights in different location for some reasons, you can set `HF_HOME` environment variable. Find more about the [environment variables here](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables)
57
+
58
+ In order to leverage **internet search** capability, you need Serper API Key. You can set it manually in the control panel or in CLI. When specifying the Serper API Key in CLI, it will be injected into the corresponding UI control. If you don't have it yet, please get one from [serper.dev](https://serper.dev/). By signing up, you will get free 2,500 free google searches which is pretty much sufficient for a long-term test.
59
+
60
+ ```console
61
+ $ python app.py --root-path "" \
62
+ --local-files-only \
63
+ --share \
64
+ --debug \
65
+ --serper-api-key "YOUR SERPER API KEY"
66
+ ```
67
+
68
+ ### Discord Bot
69
+
70
+ ![](https://i.ibb.co/cJ3yDWh/2023-07-14-1-42-23.png)
71
+
72
+ 0. Prerequisites
73
+
74
+ Note that the code only works `Python >= 3.9`
75
+
76
+ ```console
77
+ $ conda create -n llm-serve python=3.9
78
+ $ conda activate llm-serve
79
+ ```
80
+
81
+ 1. Install dependencies.
82
+ ```console
83
+ $ cd LLM-As-Chatbot
84
+ $ pip install -r requirements.txt
85
+ ```
86
+
87
+ 2. Run Discord Bot application. Choose one of the modes in `--mode-[cpu|mps|8bit|4bit|full-gpu]`. `full-gpu` will be choseon by default(`full` means `half` - consider this as a typo to be fixed later).
88
+
89
+ The `--token` is a required parameter, and you can get it from [Discord Developer Portal](https://discord.com/developers/docs/intro). If you have not setup Discord Bot from the Discord Developer Portal yet, please follow [How to Create a Discord Bot Account](https://www.freecodecamp.org/news/create-a-discord-bot-with-python/) section of the tutorial from [freeCodeCamp](https://www.freecodecamp.org/) to get the token.
90
+
91
+ The `--model-name` is a required parameter, and you can look around the list of supported models from [`model_cards.json`](https://github.com/deep-diver/LLM-As-Chatbot/blob/main/model_cards.json).
92
+
93
+ `--max-workers` is a parameter to determine how many requests to be handled concurrently. This simply defines the value of the `ThreadPoolExecutor`.
94
+
95
+ When `--local-files-only` is set, application won't try to look up the Hugging Face Hub(remote). Instead, it will only use the files already downloaded and cached.
96
+
97
+ In order to leverage **internet search** capability, you need Serper API Key. If you don't have it yet, please get one from [serper.dev](https://serper.dev/). By signing up, you will get free 2,500 free google searches which is pretty much sufficient for a long-term test. Once you have the Serper API Key, you can specify it in `--serper-api-key` option.
98
+
99
+ - Hugging Face libraries stores downloaded contents under `~/.cache` by default, and this application assumes so. However, if you downloaded weights in different location for some reasons, you can set `HF_HOME` environment variable. Find more about the [environment variables here](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables)
100
+
101
+ ```console
102
+ $ python discord_app.py --token "DISCORD BOT TOKEN" \
103
+ --model-name "alpaca-lora-7b" \
104
+ --max-workers 1 \
105
+ --mode-[cpu|mps|8bit|4bit|full-gpu] \
106
+ --local_files_only \
107
+ --serper-api-key "YOUR SERPER API KEY"
108
+ ```
109
+
110
+ 4. Supported Discord Bot commands
111
+
112
+ There is no slash commands. The only way to interact with the deployed discord bot is to mention the bot. However, you can pass some special strings while mentioning the bot.
113
+
114
+ - **`@bot_name help`**: it will display a simple help message
115
+ - **`@bot_name model-info`**: it will display the information of the currently selected(deployed) model from the [`model_cards.json`](https://github.com/deep-diver/LLM-As-Chatbot/blob/main/model_cards.json).
116
+ - **`@bot_name default-params`**: it will display the default parameters to be used in model's `generate` method. That is `GenerationConfig`, and it holds parameters such as `temperature`, `top_p`, and so on.
117
+ - **`@bot_name user message --max-new-tokens 512 --temperature 0.9 --top-p 0.75 --do_sample --max-windows 5 --internet`**: all parameters are used to dynamically determine the text geneartion behaviour as in `GenerationConfig` except `max-windows`. The `max-windows` determines how many past conversations to look up as a reference. The default value is set to `3`, but as the conversation goes long, you can increase this value. `--internet` will try to answer to your prompt by aggregating information scraped from google search. To use `--internet` option, you need to specify `--serper-api-key` when booting up the program.
118
+
119
+ ### Context management
120
+
121
+ Different model might have different strategies to manage context, so if you want to know the exact strategies applied to each model, take a look at the [`chats`](https://github.com/deep-diver/LLM-As-Chatbot/tree/main/chats) directory. However, here are the basic ideas that I have come up with initially. I have found long prompts will slow down the generation process a lot eventually, so I thought the prompts should be kept as short as possible while as concise as possible at the same time. In the previous version, I have accumulated all the past conversations, and that didn't go well.
122
+
123
+ - In every turn of the conversation, the past `N` conversations will be kept. Think about the `N` as a hyper-parameter. As an experiment, currently the past 2-3 conversations are only kept for all models.
124
+
125
+ ### Currently supported models
126
+
127
+ <details><summary>Checkout the list of models</summary>
128
+
129
+ - [tloen/alpaca-lora-7b](https://huggingface.co/tloen/alpaca-lora-7b): the original 7B Alpaca-LoRA checkpoint by tloen (updated by 4/4/2022)
130
+ - [LLMs/Alpaca-LoRA-7B-elina](https://huggingface.co/LLMs/Alpaca-LoRA-7B-elina): the 7B Alpaca-LoRA checkpoint by Chansung (updated by 5/1/2022)
131
+ - [LLMs/Alpaca-LoRA-13B-elina](https://huggingface.co/LLMs/Alpaca-LoRA-13B-elina): the 13B Alpaca-LoRA checkpoint by Chansung (updated by 5/1/2022)
132
+ - [LLMs/Alpaca-LoRA-30B-elina](https://huggingface.co/LLMs/Alpaca-LoRA-30B-elina): the 30B Alpaca-LoRA checkpoint by Chansung (updated by 5/1/2022)
133
+ - [LLMs/Alpaca-LoRA-65B-elina](https://huggingface.co/LLMs/Alpaca-LoRA-65B-elina): the 65B Alpaca-LoRA checkpoint by Chansung (updated by 5/1/2022)
134
+ - [LLMs/AlpacaGPT4-LoRA-7B-elina](https://huggingface.co/LLMs/AlpacaGPT4-LoRA-7B-elina): the 7B Alpaca-LoRA checkpoint trained on GPT4 generated Alpaca style dataset by Chansung (updated by 5/1/2022)
135
+ - [LLMs/AlpacaGPT4-LoRA-13B-elina](https://huggingface.co/LLMs/AlpacaGPT4-LoRA-13B-elina): the 13B Alpaca-LoRA checkpoint trained on GPT4 generated Alpaca style dataset by Chansung (updated by 5/1/2022)
136
+ - [stabilityai/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b): StableLM based fine-tuned model
137
+ - [beomi/KoAlpaca-Polyglot-12.8B](https://huggingface.co/beomi/KoAlpaca-Polyglot-12.8B): [Polyglot](https://github.com/EleutherAI/polyglot) based Alpaca style instruction fine-tuned model
138
+ - [declare-lab/flan-alpaca-xl](https://huggingface.co/declare-lab/flan-alpaca-xl): Flan XL(3B) based Alpaca style instruction fine-tuned model.
139
+ - [declare-lab/flan-alpaca-xxl](https://huggingface.co/declare-lab/flan-alpaca-xxl): Flan XXL(11B) based Alpaca style instruction fine-tuned model.
140
+ - [OpenAssistant/stablelm-7b-sft-v7-epoch-3](https://huggingface.co/OpenAssistant/stablelm-7b-sft-v7-epoch-3): StableLM(7B) based OpenAssistant's oasst1 instruction fine-tuned model.
141
+ - [Writer/camel-5b-hf](https://huggingface.co/Writer/camel-5b-hf): Palmyra-base based instruction fine-tuned model. The foundation model and the data are from its creator, [Writer](https://dev.writer.com).
142
+ - [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5-3b-v1.0): T5(3B) based Vicuna style instruction fine-tuned model on SharedGPT by [lm-sys](https://github.com/lm-sys/FastChat)
143
+ - [LLMs/Stable-Vicuna-13B](https://huggingface.co/LLMs/Stable-Vicuna-13B): Stable Vicuna(13B) from Carpel AI and Stability AI. This is not a delta weight, so use it at your own risk. I will make this repo as private soon and add Hugging Face token field.
144
+ - [LLMs/Vicuna-7b-v1.1](https://huggingface.co/LLMs/Vicuna-7b-v1.1): Vicuna(7B) from FastChat. This is not a delta weight, so use it at your own risk. I will make this repo as private soon and add Hugging Face token field.
145
+ - [LLMs/Vicuna-7b-v1.3](https://huggingface.co/lmsys/vicuna-7b-v1.3)
146
+ - [LLMs/Vicuna-13b-v1.1](https://huggingface.co/LLMs/Vicuna-13b-v1.1): Vicuna(13B) from FastChat. This is not a delta weight, so use it at your own risk. I will make this repo as private soon and add Hugging Face token field.
147
+ - [LLMs/Vicuna-13b-v1.3](https://huggingface.co/lmsys/vicuna-13b-v1.3)
148
+ - [LLMs/Vicuna-33b-v1.3](https://huggingface.co/lmsys/vicuna-33b-v1.3)
149
+ - [togethercomputer/RedPajama-INCITE-Chat-7B-v0.1](https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-7B-v0.1): RedPajama INCITE Chat(7B) from Together.
150
+ - [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat): MPT-7B from MOSAIC ML.
151
+ - [mosaicml/mpt-30b-chat](https://huggingface.co/mosaicml/mpt-30b-chat): MPT-30B from MOSAIC ML.
152
+ - [teknium/llama-deus-7b-v3-lora](https://huggingface.co/teknium/llama-deus-7b-v3-lora): LLaMA 7B based Alpaca style instruction fine-tuned model. The only difference between Alpaca is that this model is fine-tuned on more data including Alpaca dataset, GPTeacher, General Instruct, Code Instruct, Roleplay Instruct, Roleplay V2 Instruct, GPT4-LLM Uncensored, Unnatural Instructions, WizardLM Uncensored, CamelAI's 20k Biology, 20k Physics, 20k Chemistry, 50k Math GPT4 Datasets, and CodeAlpaca
153
+ - [HuggingFaceH4/starchat-alpha](https://huggingface.co/HuggingFaceH4/starchat-alpha): Starcoder 15.5B based instruction fine-tuned model. This model is particularly good at answering questions about coding.
154
+ - [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta): Starcoder 15.5B based instruction fine-tuned model. This model is particularly good at answering questions about coding.
155
+ - [LLMs/Vicuna-LoRA-EvolInstruct-7B](https://huggingface.co/LLMs/Vicuna-LoRA-EvolInstruct-7B): LLaMA 7B based Vicuna style instruction fine-tuned model. The dataset to fine-tune this model is from WizardLM's Evol Instruction dataset.
156
+ - [LLMs/Vicuna-LoRA-EvolInstruct-13B](https://huggingface.co/LLMs/Vicuna-LoRA-EvolInstruct-13B): LLaMA 13B based Vicuna style instruction fine-tuned model. The dataset to fine-tune this model is from WizardLM's Evol Instruction dataset.
157
+ - [project-baize/baize-v2-7b](https://huggingface.co/project-baize/baize-v2-7b): LLaMA 7B based Baize
158
+ - [project-baize/baize-v2-13b](https://huggingface.co/project-baize/baize-v2-7b): LLaMA 13B based Baize
159
+ - [timdettmers/guanaco-7b](https://huggingface.co/timdettmers/guanaco-7b): LLaMA 7B based Guanaco which is fine-tuned on OASST1 dataset with QLoRA techniques introduced in "QLoRA: Efficient Finetuning of Quantized LLMs" paper.
160
+ - [timdettmers/guanaco-13b](https://huggingface.co/timdettmers/guanaco-13b): LLaMA 13B based Guanaco which is fine-tuned on OASST1 dataset with QLoRA techniques introduced in "QLoRA: Efficient Finetuning of Quantized LLMs" paper.
161
+ - [timdettmers/guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged): LLaMA 30B based Guanaco which is fine-tuned on OASST1 dataset with QLoRA techniques introduced in "QLoRA: Efficient Finetuning of Quantized LLMs" paper.
162
+ - [tiiuae/falcon-7b-instruct](https://huggingface.co/tiiuae/falcon-7b-instruct): Falcon 7B based instruction fine-tuned model on Baize, GPT4All, GPTeacher, and RefinedWeb-English datasets.
163
+ - [tiiuae/falcon-40b-instruct](https://huggingface.co/tiiuae/falcon-40b-instruct): Falcon 40B based instruction fine-tuned model on Baize and RefinedWeb-English datasets.
164
+ - [LLMs/WizardLM-13B-V1.0](https://huggingface.co/LLMs/WizardLM-13B-V1.0)
165
+ - [LLMs/WizardLM-30B-V1.0](https://huggingface.co/LLMs/WizardLM-30B-V1.0)
166
+ - [ehartford/Wizard-Vicuna-13B-Uncensored](https://huggingface.co/ehartford/Wizard-Vicuna-13B-Uncensored)
167
+ - [ehartford/Wizard-Vicuna-30B-Uncensored](https://huggingface.co/ehartford/Wizard-Vicuna-30B-Uncensored)
168
+ - [ehartford/samantha-7b](https://huggingface.co/ehartford/samantha-7b)
169
+ - [ehartford/samantha-13b](https://huggingface.co/ehartford/samantha-13b)
170
+ - [ehartford/samantha-33b](https://huggingface.co/ehartford/samantha-33b)
171
+ - [CalderaAI/30B-Lazarus](https://huggingface.co/CalderaAI/30B-Lazarus)
172
+ - [elinas/chronos-13b](https://huggingface.co/elinas/chronos-13b)
173
+ - [elinas/chronos-33b](https://huggingface.co/elinas/chronos-33b)
174
+ - [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
175
+ - [ehartford/WizardLM-Uncensored-Falcon-7b](https://huggingface.co/ehartford/WizardLM-Uncensored-Falcon-7b)
176
+ - [ehartford/WizardLM-Uncensored-Falcon-40b](https://huggingface.co/ehartford/WizardLM-Uncensored-Falcon-40b)
177
+
178
+ </details>
179
+
180
+ ## Todos
181
+
182
+ - [X] Gradio components to control the configurations of the generation
183
+ - [X] Multiple conversation management
184
+ - [X] Internet search capability (by integrating ChromaDB, `intfloat/e5-large-v2`)
185
+ - [ ] Implement server only option w/ FastAPI
186
+
187
+ ## Acknowledgements
188
 
189
+ - I am thankful to [Jarvislabs.ai](https://jarvislabs.ai/) who generously provided free GPU resources to experiment with Alpaca-LoRA deployment and share it to communities to try out.
190
+ - I am thankful to [AI Network](https://www.ainetwork.ai) who generously provided A100(40G) x 8 DGX workstation for fine-tuning and serving the models.
__init__.py ADDED
File without changes
__pycache__/global_vars.cpython-39.pyc ADDED
Binary file (6.04 kB). View file
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (22 kB). View file
 
app.py ADDED
@@ -0,0 +1,1109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import copy
5
+ import types
6
+ from os import listdir
7
+ from os.path import isfile, join
8
+ import argparse
9
+ import gradio as gr
10
+ import global_vars
11
+ from chats import central
12
+ from transformers import AutoModelForCausalLM
13
+ from miscs.styles import MODEL_SELECTION_CSS
14
+ from miscs.js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE
15
+ from utils import get_chat_manager, get_global_context
16
+
17
+ from pingpong.pingpong import PingPong
18
+ from pingpong.gradio import GradioAlpacaChatPPManager
19
+ from pingpong.gradio import GradioKoAlpacaChatPPManager
20
+ from pingpong.gradio import GradioStableLMChatPPManager
21
+ from pingpong.gradio import GradioFlanAlpacaChatPPManager
22
+ from pingpong.gradio import GradioOSStableLMChatPPManager
23
+ from pingpong.gradio import GradioVicunaChatPPManager
24
+ from pingpong.gradio import GradioStableVicunaChatPPManager
25
+ from pingpong.gradio import GradioStarChatPPManager
26
+ from pingpong.gradio import GradioMPTChatPPManager
27
+ from pingpong.gradio import GradioRedPajamaChatPPManager
28
+ from pingpong.gradio import GradioBaizeChatPPManager
29
+
30
+ # no cpu for
31
+ # - falcon families (too slow)
32
+
33
+ load_mode_list = ["cpu"]
34
+
35
+ ex_file = open("examples.txt", "r")
36
+ examples = ex_file.read().split("\n")
37
+ ex_btns = []
38
+
39
+ chl_file = open("channels.txt", "r")
40
+ channels = chl_file.read().split("\n")
41
+ channel_btns = []
42
+
43
+ default_ppm = GradioAlpacaChatPPManager()
44
+ default_ppm.ctx = "Context at top"
45
+ default_ppm.pingpongs = [
46
+ PingPong("user input #1...", "bot response #1..."),
47
+ PingPong("user input #2...", "bot response #2..."),
48
+ ]
49
+ chosen_ppm = copy.deepcopy(default_ppm)
50
+
51
+ prompt_styles = {
52
+ "Alpaca": default_ppm,
53
+ "Baize": GradioBaizeChatPPManager(),
54
+ "Koalpaca": GradioKoAlpacaChatPPManager(),
55
+ "MPT": GradioMPTChatPPManager(),
56
+ "OpenAssistant StableLM": GradioOSStableLMChatPPManager(),
57
+ "RedPajama": GradioRedPajamaChatPPManager(),
58
+ "StableVicuna": GradioVicunaChatPPManager(),
59
+ "StableLM": GradioStableLMChatPPManager(),
60
+ "StarChat": GradioStarChatPPManager(),
61
+ "Vicuna": GradioVicunaChatPPManager(),
62
+ }
63
+
64
+ response_configs = [
65
+ f"configs/response_configs/{f}"
66
+ for f in listdir("configs/response_configs")
67
+ if isfile(join("configs/response_configs", f))
68
+ ]
69
+
70
+ summarization_configs = [
71
+ f"configs/summarization_configs/{f}"
72
+ for f in listdir("configs/summarization_configs")
73
+ if isfile(join("configs/summarization_configs", f))
74
+ ]
75
+
76
+ model_info = json.load(open("model_cards.json"))
77
+
78
+ ###
79
+
80
+ def move_to_model_select_view():
81
+ return (
82
+ "move to model select view",
83
+ gr.update(visible=False),
84
+ gr.update(visible=True),
85
+ )
86
+
87
+ def use_chosen_model():
88
+ try:
89
+ test = global_vars.model
90
+ except AttributeError:
91
+ raise gr.Error("There is no previously chosen model")
92
+
93
+ gen_config = global_vars.gen_config
94
+ gen_sum_config = global_vars.gen_config_summarization
95
+
96
+ if global_vars.model_type == "custom":
97
+ ppmanager_type = chosen_ppm
98
+ else:
99
+ ppmanager_type = get_chat_manager(global_vars.model_type)
100
+
101
+ return (
102
+ "Preparation done!",
103
+ gr.update(visible=False),
104
+ gr.update(visible=True),
105
+ gr.update(label=global_vars.model_type),
106
+ {
107
+ "ppmanager_type": ppmanager_type,
108
+ "model_type": global_vars.model_type,
109
+ },
110
+ get_global_context(global_vars.model_type),
111
+ gen_config.temperature,
112
+ gen_config.top_p,
113
+ gen_config.top_k,
114
+ gen_config.repetition_penalty,
115
+ gen_config.max_new_tokens,
116
+ gen_config.num_beams,
117
+ gen_config.use_cache,
118
+ gen_config.do_sample,
119
+ gen_config.eos_token_id,
120
+ gen_config.pad_token_id,
121
+ gen_sum_config.temperature,
122
+ gen_sum_config.top_p,
123
+ gen_sum_config.top_k,
124
+ gen_sum_config.repetition_penalty,
125
+ gen_sum_config.max_new_tokens,
126
+ gen_sum_config.num_beams,
127
+ gen_sum_config.use_cache,
128
+ gen_sum_config.do_sample,
129
+ gen_sum_config.eos_token_id,
130
+ gen_sum_config.pad_token_id,
131
+ )
132
+
133
+ def move_to_byom_view():
134
+ load_mode_list = []
135
+ if global_vars.cuda_availability:
136
+ load_mode_list.extend(["gpu(half)", "gpu(load_in_8bit)", "gpu(load_in_4bit)"])
137
+
138
+ if global_vars.mps_availability:
139
+ load_mode_list.append("apple silicon")
140
+
141
+ load_mode_list.append("cpu")
142
+
143
+ return (
144
+ "move to the byom view",
145
+ gr.update(visible=False),
146
+ gr.update(visible=True),
147
+ gr.update(choices=load_mode_list, value=load_mode_list[0])
148
+ )
149
+
150
+ def prompt_style_change(key):
151
+ ppm = prompt_styles[key]
152
+ ppm.ctx = "Context at top"
153
+ ppm.pingpongs = [
154
+ PingPong("user input #1...", "bot response #1..."),
155
+ PingPong("user input #2...", "bot response #2..."),
156
+ ]
157
+ chosen_ppm = copy.deepcopy(ppm)
158
+ chosen_ppm.ctx = ""
159
+ chosen_ppm.pingpongs = []
160
+
161
+ return ppm.build_prompts()
162
+
163
+ def byom_load(
164
+ base, ckpt, model_cls, tokenizer_cls,
165
+ bos_token_id, eos_token_id, pad_token_id,
166
+ load_mode,
167
+ ):
168
+ # mode_cpu, model_mps, mode_8bit, mode_4bit, mode_full_gpu
169
+ global_vars.initialize_globals_byom(
170
+ base, ckpt, model_cls, tokenizer_cls,
171
+ bos_token_id, eos_token_id, pad_token_id,
172
+ True if load_mode == "cpu" else False,
173
+ True if load_mode == "apple silicon" else False,
174
+ True if load_mode == "8bit" else False,
175
+ True if load_mode == "4bit" else False,
176
+ True if load_mode == "gpu(half)" else False,
177
+ )
178
+
179
+ return (
180
+ ""
181
+ )
182
+
183
+ def channel_num(btn_title):
184
+ choice = 0
185
+
186
+ for idx, channel in enumerate(channels):
187
+ if channel == btn_title:
188
+ choice = idx
189
+
190
+ return choice
191
+
192
+
193
+ def set_chatbot(btn, ld, state):
194
+ choice = channel_num(btn)
195
+
196
+ res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
197
+ empty = len(res[choice].pingpongs) == 0
198
+ return (res[choice].build_uis(), choice, gr.update(visible=empty), gr.update(interactive=not empty))
199
+
200
+
201
+ def set_example(btn):
202
+ return btn, gr.update(visible=False)
203
+
204
+
205
+ def set_popup_visibility(ld, example_block):
206
+ return example_block
207
+
208
+
209
+ def move_to_second_view(btn):
210
+ info = model_info[btn]
211
+
212
+ guard_vram = 5 * 1024.
213
+ vram_req_full = int(info["vram(full)"]) + guard_vram
214
+ vram_req_8bit = int(info["vram(8bit)"]) + guard_vram
215
+ vram_req_4bit = int(info["vram(4bit)"]) + guard_vram
216
+
217
+ load_mode_list = []
218
+
219
+ if global_vars.cuda_availability:
220
+ print(f"total vram = {global_vars.available_vrams_mb}")
221
+ print(f"required vram(full={info['vram(full)']}, 8bit={info['vram(8bit)']}, 4bit={info['vram(4bit)']})")
222
+
223
+ if global_vars.available_vrams_mb >= vram_req_full:
224
+ load_mode_list.append("gpu(half)")
225
+
226
+ if global_vars.available_vrams_mb >= vram_req_8bit:
227
+ load_mode_list.append("gpu(load_in_8bit)")
228
+
229
+ if global_vars.available_vrams_mb >= vram_req_4bit:
230
+ load_mode_list.append("gpu(load_in_4bit)")
231
+
232
+ if global_vars.mps_availability:
233
+ load_mode_list.append("apple silicon")
234
+
235
+ load_mode_list.extend(["cpu"])
236
+
237
+ return (
238
+ gr.update(visible=False),
239
+ gr.update(visible=True),
240
+ info["thumb"],
241
+ f"## {btn}",
242
+ f"**Parameters**\n: Approx. {info['parameters']}",
243
+ f"**🤗 Hub(base)**\n: {info['hub(base)']}",
244
+ f"**🤗 Hub(LoRA)**\n: {info['hub(ckpt)']}",
245
+ info['desc'],
246
+ f"""**Min VRAM requirements** :
247
+ | half precision | load_in_8bit | load_in_4bit |
248
+ | ------------------------------------- | ---------------------------------- | ---------------------------------- |
249
+ | {round(vram_req_full/1024., 1)}GiB | {round(vram_req_8bit/1024., 1)}GiB | {round(vram_req_4bit/1024., 1)}GiB |
250
+ """,
251
+ info['default_gen_config'],
252
+ info['example1'],
253
+ info['example2'],
254
+ info['example3'],
255
+ info['example4'],
256
+ info['thumb-tiny'],
257
+ gr.update(choices=load_mode_list, value=load_mode_list[0]),
258
+ "",
259
+ )
260
+
261
+ def move_to_first_view():
262
+ return (gr.update(visible=True), gr.update(visible=False))
263
+
264
+ def download_completed(
265
+ model_name,
266
+ model_base,
267
+ model_ckpt,
268
+ gen_config_path,
269
+ gen_config_sum_path,
270
+ load_mode,
271
+ thumbnail_tiny,
272
+ force_download,
273
+ ):
274
+ global local_files_only
275
+
276
+ tmp_args = types.SimpleNamespace()
277
+ tmp_args.base_url = model_base.split(":")[-1].split("</p")[0].strip()
278
+ tmp_args.ft_ckpt_url = model_ckpt.split(":")[-1].split("</p")[0].strip()
279
+ tmp_args.gen_config_path = gen_config_path
280
+ tmp_args.gen_config_summarization_path = gen_config_sum_path
281
+ tmp_args.force_download_ckpt = force_download
282
+ tmp_args.thumbnail_tiny = thumbnail_tiny
283
+
284
+ tmp_args.mode_cpu = True if load_mode == "cpu" else False
285
+ tmp_args.mode_mps = True if load_mode == "apple silicon" else False
286
+ tmp_args.mode_8bit = True if load_mode == "gpu(load_in_8bit)" else False
287
+ tmp_args.mode_4bit = True if load_mode == "gpu(load_in_4bit)" else False
288
+ tmp_args.mode_full_gpu = True if load_mode == "gpu(half)" else False
289
+ tmp_args.local_files_only = local_files_only
290
+
291
+ try:
292
+ global_vars.initialize_globals(tmp_args)
293
+ except RuntimeError as e:
294
+ raise gr.Error("GPU memory is not enough to load this model.")
295
+
296
+ return "Download completed!"
297
+
298
+ def move_to_third_view():
299
+ gen_config = global_vars.gen_config
300
+ gen_sum_config = global_vars.gen_config_summarization
301
+
302
+ if global_vars.model_type == "custom":
303
+ ppmanager_type = chosen_ppm
304
+ else:
305
+ ppmanager_type = get_chat_manager(global_vars.model_type)
306
+
307
+ return (
308
+ "Preparation done!",
309
+ gr.update(visible=False),
310
+ gr.update(visible=True),
311
+ gr.update(label=global_vars.model_type),
312
+ {
313
+ "ppmanager_type": ppmanager_type,
314
+ "model_type": global_vars.model_type,
315
+ },
316
+ get_global_context(global_vars.model_type),
317
+ gen_config.temperature,
318
+ gen_config.top_p,
319
+ gen_config.top_k,
320
+ gen_config.repetition_penalty,
321
+ gen_config.max_new_tokens,
322
+ gen_config.num_beams,
323
+ gen_config.use_cache,
324
+ gen_config.do_sample,
325
+ gen_config.eos_token_id,
326
+ gen_config.pad_token_id,
327
+ gen_sum_config.temperature,
328
+ gen_sum_config.top_p,
329
+ gen_sum_config.top_k,
330
+ gen_sum_config.repetition_penalty,
331
+ gen_sum_config.max_new_tokens,
332
+ gen_sum_config.num_beams,
333
+ gen_sum_config.use_cache,
334
+ gen_sum_config.do_sample,
335
+ gen_sum_config.eos_token_id,
336
+ gen_sum_config.pad_token_id,
337
+ )
338
+
339
+
340
+ def toggle_inspector(view_selector):
341
+ if view_selector == "with context inspector":
342
+ return gr.update(visible=True)
343
+ else:
344
+ return gr.update(visible=False)
345
+
346
+
347
+ def reset_chat(idx, ld, state):
348
+ res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
349
+ res[idx].pingpongs = []
350
+
351
+ return (
352
+ "",
353
+ [],
354
+ str(res),
355
+ gr.update(visible=True),
356
+ gr.update(interactive=False),
357
+ )
358
+
359
+ def rollback_last(idx, ld, state):
360
+ res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
361
+ last_user_message = res[idx].pingpongs[-1].ping
362
+ res[idx].pingpongs = res[idx].pingpongs[:-1]
363
+
364
+ return (
365
+ last_user_message,
366
+ res[idx].build_uis(),
367
+ str(res),
368
+ gr.update(interactive=False)
369
+ )
370
+
371
+ def gradio_main(args):
372
+ global local_files_only
373
+ local_files_only = args.local_files_only
374
+
375
+ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
376
+ with gr.Column(visible=True, elem_id="landing-container") as landing_view:
377
+ gr.Markdown("# Chat with LLM", elem_classes=["center"])
378
+ with gr.Row(elem_id="landing-container-selection"):
379
+ with gr.Column():
380
+ gr.Markdown("""This is the landing page of the project, [LLM As Chatbot](https://github.com/deep-diver/LLM-As-Chatbot). This appliction is designed for personal use only. A single model will be selected at a time even if you open up a new browser or a tab. As an initial choice, please select one of the following menu""")
381
+
382
+ gr.Markdown("""
383
+ **Bring your own model**: You can chat with arbitrary models. If your own custom model is based on 🤗 Hugging Face's [transformers](https://huggingface.co/docs/transformers/index) library, you will propbably be able to bring it into this application with this menu
384
+
385
+ **Select a model from model pool**: You can chat with one of the popular open source Large Language Model
386
+
387
+ **Use currently selected model**: If you have already selected, but if you came back to this landing page accidently, you can directly go back to the chatting mode with this menu
388
+ """)
389
+
390
+ byom = gr.Button("🫵🏼 Bring your own model", elem_id="go-byom-select", elem_classes=["square", "landing-btn"])
391
+ select_model = gr.Button("🦙 Select a model from model pool", elem_id="go-model-select", elem_classes=["square", "landing-btn"])
392
+ chosen_model = gr.Button("↪️ Use currently selected model", elem_id="go-use-selected-model", elem_classes=["square", "landing-btn"])
393
+
394
+ with gr.Column(elem_id="landing-bottom"):
395
+ progress_view0 = gr.Textbox(label="Progress", elem_classes=["progress-view"])
396
+ gr.Markdown("""[project](https://github.com/deep-diver/LLM-As-Chatbot)
397
+ [developer](https://github.com/deep-diver)
398
+ """, elem_classes=["center"])
399
+
400
+ with gr.Column(visible=False) as model_choice_view:
401
+ gr.Markdown("# Choose a Model", elem_classes=["center"])
402
+ with gr.Row(elem_id="container"):
403
+ with gr.Column():
404
+ gr.Markdown("## ~ 10B Parameters")
405
+ with gr.Row(elem_classes=["sub-container"]):
406
+ with gr.Column(min_width=20):
407
+ t5_vicuna_3b = gr.Button("t5-vicuna-3b", elem_id="t5-vicuna-3b", elem_classes=["square"])
408
+ gr.Markdown("T5 Vicuna", elem_classes=["center"])
409
+
410
+ with gr.Column(min_width=20, visible=False):
411
+ flan3b = gr.Button("flan-3b", elem_id="flan-3b", elem_classes=["square"])
412
+ gr.Markdown("Flan-XL", elem_classes=["center"])
413
+
414
+ # with gr.Column(min_width=20):
415
+ # replit_3b = gr.Button("replit-3b", elem_id="replit-3b", elem_classes=["square"])
416
+ # gr.Markdown("Replit Instruct", elem_classes=["center"])
417
+
418
+ with gr.Column(min_width=20):
419
+ camel5b = gr.Button("camel-5b", elem_id="camel-5b", elem_classes=["square"])
420
+ gr.Markdown("Camel", elem_classes=["center"])
421
+
422
+ with gr.Column(min_width=20):
423
+ alpaca_lora7b = gr.Button("alpaca-lora-7b", elem_id="alpaca-lora-7b", elem_classes=["square"])
424
+ gr.Markdown("Alpaca-LoRA", elem_classes=["center"])
425
+
426
+ with gr.Column(min_width=20):
427
+ stablelm7b = gr.Button("stablelm-7b", elem_id="stablelm-7b", elem_classes=["square"])
428
+ gr.Markdown("StableLM", elem_classes=["center"])
429
+
430
+ with gr.Column(min_width=20, visible=False):
431
+ os_stablelm7b = gr.Button("os-stablelm-7b", elem_id="os-stablelm-7b", elem_classes=["square"])
432
+ gr.Markdown("OA+StableLM", elem_classes=["center"])
433
+
434
+ with gr.Column(min_width=20):
435
+ gpt4_alpaca_7b = gr.Button("gpt4-alpaca-7b", elem_id="gpt4-alpaca-7b", elem_classes=["square"])
436
+ gr.Markdown("GPT4-Alpaca-LoRA", elem_classes=["center"])
437
+
438
+ with gr.Column(min_width=20):
439
+ mpt_7b = gr.Button("mpt-7b", elem_id="mpt-7b", elem_classes=["square"])
440
+ gr.Markdown("MPT", elem_classes=["center"])
441
+
442
+ with gr.Column(min_width=20):
443
+ redpajama_7b = gr.Button("redpajama-7b", elem_id="redpajama-7b", elem_classes=["square"])
444
+ gr.Markdown("RedPajama", elem_classes=["center"])
445
+
446
+ with gr.Column(min_width=20, visible=False):
447
+ redpajama_instruct_7b = gr.Button("redpajama-instruct-7b", elem_id="redpajama-instruct-7b", elem_classes=["square"])
448
+ gr.Markdown("RedPajama Instruct", elem_classes=["center"])
449
+
450
+ with gr.Column(min_width=20):
451
+ vicuna_7b = gr.Button("vicuna-7b", elem_id="vicuna-7b", elem_classes=["square"])
452
+ gr.Markdown("Vicuna", elem_classes=["center"])
453
+
454
+ with gr.Column(min_width=20):
455
+ vicuna_7b_1_3 = gr.Button("vicuna-7b-1-3", elem_id="vicuna-7b-1-3", elem_classes=["square"])
456
+ gr.Markdown("Vicuna 1.3", elem_classes=["center"])
457
+
458
+ with gr.Column(min_width=20):
459
+ llama_deus_7b = gr.Button("llama-deus-7b", elem_id="llama-deus-7b",elem_classes=["square"])
460
+ gr.Markdown("LLaMA Deus", elem_classes=["center"])
461
+
462
+ with gr.Column(min_width=20):
463
+ evolinstruct_vicuna_7b = gr.Button("evolinstruct-vicuna-7b", elem_id="evolinstruct-vicuna-7b", elem_classes=["square"])
464
+ gr.Markdown("EvolInstruct Vicuna", elem_classes=["center"])
465
+
466
+ with gr.Column(min_width=20, visible=False):
467
+ alpacoom_7b = gr.Button("alpacoom-7b", elem_id="alpacoom-7b", elem_classes=["square"])
468
+ gr.Markdown("Alpacoom", elem_classes=["center"])
469
+
470
+ with gr.Column(min_width=20):
471
+ baize_7b = gr.Button("baize-7b", elem_id="baize-7b", elem_classes=["square"])
472
+ gr.Markdown("Baize", elem_classes=["center"])
473
+
474
+ with gr.Column(min_width=20):
475
+ guanaco_7b = gr.Button("guanaco-7b", elem_id="guanaco-7b", elem_classes=["square"])
476
+ gr.Markdown("Guanaco", elem_classes=["center"])
477
+
478
+ with gr.Column(min_width=20):
479
+ falcon_7b = gr.Button("falcon-7b", elem_id="falcon-7b", elem_classes=["square"])
480
+ gr.Markdown("Falcon", elem_classes=["center"])
481
+
482
+ with gr.Column(min_width=20):
483
+ wizard_falcon_7b = gr.Button("wizard-falcon-7b", elem_id="wizard-falcon-7b", elem_classes=["square"])
484
+ gr.Markdown("Wizard Falcon", elem_classes=["center"])
485
+
486
+ with gr.Column(min_width=20):
487
+ airoboros_7b = gr.Button("airoboros-7b", elem_id="airoboros-7b", elem_classes=["square"])
488
+ gr.Markdown("Airoboros", elem_classes=["center"])
489
+
490
+ with gr.Column(min_width=20):
491
+ samantha_7b = gr.Button("samantha-7b", elem_id="samantha-7b", elem_classes=["square"])
492
+ gr.Markdown("Samantha", elem_classes=["center"])
493
+
494
+ with gr.Column(min_width=20):
495
+ openllama_7b = gr.Button("openllama-7b", elem_id="openllama-7b", elem_classes=["square"])
496
+ gr.Markdown("OpenLLaMA", elem_classes=["center"])
497
+
498
+ with gr.Column(min_width=20):
499
+ orcamini_7b = gr.Button("orcamini-7b", elem_id="orcamini-7b", elem_classes=["square"])
500
+ gr.Markdown("Orca Mini", elem_classes=["center"])
501
+
502
+ with gr.Column(min_width=20):
503
+ xgen_7b = gr.Button("xgen-7b", elem_id="xgen-7b", elem_classes=["square"])
504
+ gr.Markdown("XGen", elem_classes=["center"])
505
+
506
+ with gr.Column(min_width=20):
507
+ llama2_7b = gr.Button("llama2-7b", elem_id="llama2-7b", elem_classes=["square"])
508
+ gr.Markdown("LLaMA 2", elem_classes=["center"])
509
+
510
+ gr.Markdown("## ~ 20B Parameters")
511
+ with gr.Row(elem_classes=["sub-container"]):
512
+ with gr.Column(min_width=20, visible=False):
513
+ flan11b = gr.Button("flan-11b", elem_id="flan-11b", elem_classes=["square"])
514
+ gr.Markdown("Flan-XXL", elem_classes=["center"])
515
+
516
+ with gr.Column(min_width=20):
517
+ koalpaca = gr.Button("koalpaca", elem_id="koalpaca", elem_classes=["square"])
518
+ gr.Markdown("koalpaca", elem_classes=["center"])
519
+
520
+ with gr.Column(min_width=20):
521
+ kullm = gr.Button("kullm", elem_id="kullm", elem_classes=["square"])
522
+ gr.Markdown("KULLM", elem_classes=["center"])
523
+
524
+ with gr.Column(min_width=20):
525
+ alpaca_lora13b = gr.Button("alpaca-lora-13b", elem_id="alpaca-lora-13b", elem_classes=["square"])
526
+ gr.Markdown("Alpaca-LoRA", elem_classes=["center"])
527
+
528
+ with gr.Column(min_width=20):
529
+ gpt4_alpaca_13b = gr.Button("gpt4-alpaca-13b", elem_id="gpt4-alpaca-13b", elem_classes=["square"])
530
+ gr.Markdown("GPT4-Alpaca-LoRA", elem_classes=["center"])
531
+
532
+ with gr.Column(min_width=20):
533
+ stable_vicuna_13b = gr.Button("stable-vicuna-13b", elem_id="stable-vicuna-13b", elem_classes=["square"])
534
+ gr.Markdown("Stable-Vicuna", elem_classes=["center"])
535
+
536
+ with gr.Column(min_width=20):
537
+ starchat_15b = gr.Button("starchat-15b", elem_id="starchat-15b", elem_classes=["square"])
538
+ gr.Markdown("StarChat", elem_classes=["center"])
539
+
540
+ with gr.Column(min_width=20):
541
+ starchat_beta_15b = gr.Button("starchat-beta-15b", elem_id="starchat-beta-15b", elem_classes=["square"])
542
+ gr.Markdown("StarChat β", elem_classes=["center"])
543
+
544
+ with gr.Column(min_width=20):
545
+ vicuna_13b = gr.Button("vicuna-13b", elem_id="vicuna-13b", elem_classes=["square"])
546
+ gr.Markdown("Vicuna", elem_classes=["center"])
547
+
548
+ with gr.Column(min_width=20):
549
+ vicuna_13b_1_3 = gr.Button("vicuna-13b-1-3", elem_id="vicuna-13b-1-3", elem_classes=["square"])
550
+ gr.Markdown("Vicuna 1.3", elem_classes=["center"])
551
+
552
+ with gr.Column(min_width=20):
553
+ evolinstruct_vicuna_13b = gr.Button("evolinstruct-vicuna-13b", elem_id="evolinstruct-vicuna-13b", elem_classes=["square"])
554
+ gr.Markdown("EvolInstruct Vicuna", elem_classes=["center"])
555
+
556
+ with gr.Column(min_width=20):
557
+ baize_13b = gr.Button("baize-13b", elem_id="baize-13b", elem_classes=["square"])
558
+ gr.Markdown("Baize", elem_classes=["center"])
559
+
560
+ with gr.Column(min_width=20):
561
+ guanaco_13b = gr.Button("guanaco-13b", elem_id="guanaco-13b", elem_classes=["square"])
562
+ gr.Markdown("Guanaco", elem_classes=["center"])
563
+
564
+ with gr.Column(min_width=20):
565
+ nous_hermes_13b = gr.Button("nous-hermes-13b", elem_id="nous-hermes-13b", elem_classes=["square"])
566
+ gr.Markdown("Nous Hermes", elem_classes=["center"])
567
+
568
+ with gr.Column(min_width=20):
569
+ airoboros_13b = gr.Button("airoboros-13b", elem_id="airoboros-13b", elem_classes=["square"])
570
+ gr.Markdown("Airoboros", elem_classes=["center"])
571
+
572
+ with gr.Column(min_width=20):
573
+ samantha_13b = gr.Button("samantha-13b", elem_id="samantha-13b", elem_classes=["square"])
574
+ gr.Markdown("Samantha", elem_classes=["center"])
575
+
576
+ with gr.Column(min_width=20):
577
+ chronos_13b = gr.Button("chronos-13b", elem_id="chronos-13b", elem_classes=["square"])
578
+ gr.Markdown("Chronos", elem_classes=["center"])
579
+
580
+ with gr.Column(min_width=20):
581
+ wizardlm_13b = gr.Button("wizardlm-13b", elem_id="wizardlm-13b", elem_classes=["square"])
582
+ gr.Markdown("WizardLM", elem_classes=["center"])
583
+
584
+ with gr.Column(min_width=20):
585
+ wizard_vicuna_13b = gr.Button("wizard-vicuna-13b", elem_id="wizard-vicuna-13b", elem_classes=["square"])
586
+ gr.Markdown("Wizard Vicuna (Uncensored)", elem_classes=["center"])
587
+
588
+ with gr.Column(min_width=20):
589
+ wizard_coder_15b = gr.Button("wizard-coder-15b", elem_id="wizard-coder-15b", elem_classes=["square"])
590
+ gr.Markdown("Wizard Coder", elem_classes=["center"])
591
+
592
+ with gr.Column(min_width=20):
593
+ openllama_13b = gr.Button("openllama-13b", elem_id="openllama-13b", elem_classes=["square"])
594
+ gr.Markdown("OpenLLaMA", elem_classes=["center"])
595
+
596
+ with gr.Column(min_width=20):
597
+ orcamini_13b = gr.Button("orcamini-13b", elem_id="orcamini-13b", elem_classes=["square"])
598
+ gr.Markdown("Orca Mini", elem_classes=["center"])
599
+
600
+ with gr.Column(min_width=20):
601
+ llama2_13b = gr.Button("llama2-13b", elem_id="llama2-13b", elem_classes=["square"])
602
+ gr.Markdown("LLaMA 2", elem_classes=["center"])
603
+
604
+ with gr.Column(min_width=20):
605
+ nous_hermes_13b_v2 = gr.Button("nous-hermes-13b-llama2", elem_id="nous-hermes-13b-llama2", elem_classes=["square"])
606
+ gr.Markdown("Nous Hermes v2", elem_classes=["center"])
607
+
608
+ gr.Markdown("## ~ 30B Parameters", visible=False)
609
+ with gr.Row(elem_classes=["sub-container"], visible=False):
610
+ with gr.Column(min_width=20):
611
+ camel20b = gr.Button("camel-20b", elem_id="camel-20b", elem_classes=["square"])
612
+ gr.Markdown("Camel", elem_classes=["center"])
613
+
614
+ gr.Markdown("## ~ 40B Parameters")
615
+ with gr.Row(elem_classes=["sub-container"]):
616
+ with gr.Column(min_width=20):
617
+ guanaco_33b = gr.Button("guanaco-33b", elem_id="guanaco-33b", elem_classes=["square"])
618
+ gr.Markdown("Guanaco", elem_classes=["center"])
619
+
620
+ with gr.Column(min_width=20):
621
+ falcon_40b = gr.Button("falcon-40b", elem_id="falcon-40b", elem_classes=["square"])
622
+ gr.Markdown("Falcon", elem_classes=["center"])
623
+
624
+ with gr.Column(min_width=20):
625
+ wizard_falcon_40b = gr.Button("wizard-falcon-40b", elem_id="wizard-falcon-40b", elem_classes=["square"])
626
+ gr.Markdown("Wizard Falcon", elem_classes=["center"])
627
+
628
+ with gr.Column(min_width=20):
629
+ samantha_33b = gr.Button("samantha-33b", elem_id="samantha-33b", elem_classes=["square"])
630
+ gr.Markdown("Samantha", elem_classes=["center"])
631
+
632
+ with gr.Column(min_width=20):
633
+ lazarus_30b = gr.Button("lazarus-30b", elem_id="lazarus-30b", elem_classes=["square"])
634
+ gr.Markdown("Lazarus", elem_classes=["center"])
635
+
636
+ with gr.Column(min_width=20):
637
+ chronos_33b = gr.Button("chronos-33b", elem_id="chronos-33b", elem_classes=["square"])
638
+ gr.Markdown("Chronos", elem_classes=["center"])
639
+
640
+ with gr.Column(min_width=20):
641
+ wizardlm_30b = gr.Button("wizardlm-30b", elem_id="wizardlm-30b", elem_classes=["square"])
642
+ gr.Markdown("WizardLM", elem_classes=["center"])
643
+
644
+ with gr.Column(min_width=20):
645
+ wizard_vicuna_30b = gr.Button("wizard-vicuna-30b", elem_id="wizard-vicuna-30b", elem_classes=["square"])
646
+ gr.Markdown("Wizard Vicuna (Uncensored)", elem_classes=["center"])
647
+
648
+ with gr.Column(min_width=20):
649
+ vicuna_33b_1_3 = gr.Button("vicuna-33b-1-3", elem_id="vicuna-33b-1-3", elem_classes=["square"])
650
+ gr.Markdown("Vicuna 1.3", elem_classes=["center"])
651
+
652
+ with gr.Column(min_width=20):
653
+ mpt_30b = gr.Button("mpt-30b", elem_id="mpt-30b", elem_classes=["square"])
654
+ gr.Markdown("MPT", elem_classes=["center"])
655
+
656
+ with gr.Column(min_width=20):
657
+ upstage_llama_30b = gr.Button("upstage-llama-30b", elem_id="upstage-llama-30b", elem_classes=["square"])
658
+ gr.Markdown("Upstage LLaMA", elem_classes=["center"])
659
+
660
+ gr.Markdown("## ~ 70B Parameters")
661
+ with gr.Row(elem_classes=["sub-container"]):
662
+ with gr.Column(min_width=20):
663
+ free_willy2_70b = gr.Button("free-willy2-70b", elem_id="free-willy2-70b", elem_classes=["square"])
664
+ gr.Markdown("Free Willy 2", elem_classes=["center"])
665
+
666
+ progress_view = gr.Textbox(label="Progress", elem_classes=["progress-view"])
667
+
668
+ with gr.Column(visible=False) as byom_input_view:
669
+ with gr.Column(elem_id="container3"):
670
+ gr.Markdown("# Bring Your Own Model", elem_classes=["center"])
671
+
672
+ gr.Markdown("### Model configuration")
673
+ byom_base = gr.Textbox(label="Base", placeholder="Enter path or 🤗 hub ID of the base model", interactive=True)
674
+ byom_ckpt = gr.Textbox(label="LoRA ckpt", placeholder="Enter path or 🤗 hub ID of the LoRA checkpoint", interactive=True)
675
+
676
+ with gr.Accordion("Advanced options", open=False):
677
+ gr.Markdown("If you leave the below textboxes empty, `transformers.AutoModelForCausalLM` and `transformers.AutoTokenizer` classes will be used by default. If you need any specific class, please type them below.")
678
+ byom_model_cls = gr.Textbox(label="Base model class", placeholder="Enter base model class", interactive=True)
679
+ byom_tokenizer_cls = gr.Textbox(label="Base tokenizer class", placeholder="Enter base tokenizer class", interactive=True)
680
+
681
+ with gr.Column():
682
+ gr.Markdown("If you leave the below textboxes empty, any token ids for bos, eos, and pad will not be specified in `GenerationConfig`. If you think that you need to specify them. please type them below in decimal format.")
683
+ with gr.Row():
684
+ byom_bos_token_id = gr.Textbox(label="bos_token_id", placeholder="for GenConfig")
685
+ byom_eos_token_id = gr.Textbox(label="eos_token_id", placeholder="for GenConfig")
686
+ byom_pad_token_id = gr.Textbox(label="pad_token_id", placeholder="for GenConfig")
687
+
688
+ with gr.Row():
689
+ byom_load_mode = gr.Radio(
690
+ load_mode_list,
691
+ value=load_mode_list[0],
692
+ label="load mode",
693
+ elem_classes=["load-mode-selector"]
694
+ )
695
+
696
+ gr.Markdown("### Prompt configuration")
697
+ prompt_style_selector = gr.Dropdown(
698
+ label="Prompt style",
699
+ interactive=True,
700
+ choices=list(prompt_styles.keys()),
701
+ value="Alpaca"
702
+ )
703
+ with gr.Accordion("Prompt style preview", open=False):
704
+ prompt_style_previewer = gr.Textbox(
705
+ label="How prompt is actually structured",
706
+ lines=16,
707
+ value=default_ppm.build_prompts())
708
+
709
+ with gr.Row():
710
+ byom_back_btn = gr.Button("Back")
711
+ byom_confirm_btn = gr.Button("Confirm")
712
+
713
+ with gr.Column(elem_classes=["progress-view"]):
714
+ txt_view3 = gr.Textbox(label="Status")
715
+ progress_view3 = gr.Textbox(label="Progress")
716
+
717
+ with gr.Column(visible=False) as model_review_view:
718
+ gr.Markdown("# Confirm the chosen model", elem_classes=["center"])
719
+
720
+ with gr.Column(elem_id="container2"):
721
+ gr.Markdown("Please expect loading time to be longer than expected. Depending on the size of models, it will probably take from 100 to 1000 seconds or so. Please be patient.")
722
+
723
+ with gr.Row():
724
+ model_image = gr.Image(None, interactive=False, show_label=False)
725
+ with gr.Column():
726
+ model_name = gr.Markdown("**Model name**")
727
+ model_desc = gr.Markdown("...")
728
+ model_params = gr.Markdown("Parameters\n: ...")
729
+ model_base = gr.Markdown("🤗 Hub(base)\n: ...")
730
+ model_ckpt = gr.Markdown("🤗 Hub(LoRA)\n: ...")
731
+ model_vram = gr.Markdown(f"""**Minimal VRAM requirement** :
732
+ | half precision | load_in_8bit | load_in_4bit |
733
+ | ------------------------------ | ------------------------- | ------------------------- |
734
+ | {round(7830/1024., 1)}GiB | {round(5224/1024., 1)}GiB | {round(4324/1024., 1)}GiB |
735
+ """)
736
+ model_thumbnail_tiny = gr.Textbox("", visible=False)
737
+
738
+ with gr.Column():
739
+ gen_config_path = gr.Dropdown(
740
+ response_configs,
741
+ value=response_configs[0],
742
+ interactive=True,
743
+ label="Gen Config(response)",
744
+ )
745
+ gen_config_sum_path = gr.Dropdown(
746
+ summarization_configs,
747
+ value=summarization_configs[0],
748
+ interactive=True,
749
+ label="Gen Config(summarization)",
750
+ visible=False,
751
+ )
752
+ with gr.Row():
753
+ load_mode = gr.Radio(
754
+ load_mode_list,
755
+ value=load_mode_list[0],
756
+ label="load mode",
757
+ elem_classes=["load-mode-selector"]
758
+ )
759
+ force_redownload = gr.Checkbox(label="Force Re-download", interactive=False, visible=False)
760
+
761
+ with gr.Accordion("Example showcases", open=False):
762
+ with gr.Tab("Ex1"):
763
+ example_showcase1 = gr.Chatbot(
764
+ [("hello", "world"), ("damn", "good")]
765
+ )
766
+ with gr.Tab("Ex2"):
767
+ example_showcase2 = gr.Chatbot(
768
+ [("hello", "world"), ("damn", "good")]
769
+ )
770
+ with gr.Tab("Ex3"):
771
+ example_showcase3 = gr.Chatbot(
772
+ [("hello", "world"), ("damn", "good")]
773
+ )
774
+ with gr.Tab("Ex4"):
775
+ example_showcase4 = gr.Chatbot(
776
+ [("hello", "world"), ("damn", "good")]
777
+ )
778
+
779
+ with gr.Row():
780
+ back_to_model_choose_btn = gr.Button("Back")
781
+ confirm_btn = gr.Button("Confirm")
782
+
783
+ with gr.Column(elem_classes=["progress-view"]):
784
+ txt_view = gr.Textbox(label="Status")
785
+ progress_view2 = gr.Textbox(label="Progress")
786
+
787
+ with gr.Column(visible=False) as chat_view:
788
+ idx = gr.State(0)
789
+ chat_state = gr.State()
790
+ local_data = gr.JSON({}, visible=False)
791
+
792
+ with gr.Row():
793
+ with gr.Column(scale=1, min_width=180):
794
+ gr.Markdown("GradioChat", elem_id="left-top")
795
+
796
+ with gr.Column(elem_id="left-pane"):
797
+ chat_back_btn = gr.Button("Back", elem_id="chat-back-btn")
798
+
799
+ with gr.Accordion("Histories", elem_id="chat-history-accordion", open=False):
800
+ channel_btns.append(gr.Button(channels[0], elem_classes=["custom-btn-highlight"]))
801
+
802
+ for channel in channels[1:]:
803
+ channel_btns.append(gr.Button(channel, elem_classes=["custom-btn"]))
804
+
805
+ with gr.Column(scale=8, elem_id="right-pane"):
806
+ with gr.Column(
807
+ elem_id="initial-popup", visible=False
808
+ ) as example_block:
809
+ with gr.Row(scale=1):
810
+ with gr.Column(elem_id="initial-popup-left-pane"):
811
+ gr.Markdown("GradioChat", elem_id="initial-popup-title")
812
+ gr.Markdown("Making the community's best AI chat models available to everyone.")
813
+ with gr.Column(elem_id="initial-popup-right-pane"):
814
+ gr.Markdown("Chat UI is now open sourced on Hugging Face Hub")
815
+ gr.Markdown("check out the [↗ repository](https://huggingface.co/spaces/chansung/test-multi-conv)")
816
+
817
+ with gr.Column(scale=1):
818
+ gr.Markdown("Examples")
819
+ with gr.Row():
820
+ for example in examples:
821
+ ex_btns.append(gr.Button(example, elem_classes=["example-btn"]))
822
+
823
+ with gr.Column(elem_id="aux-btns-popup", visible=True):
824
+ with gr.Row():
825
+ stop = gr.Button("Stop", elem_classes=["aux-btn"])
826
+ regenerate = gr.Button("Regen", interactive=False, elem_classes=["aux-btn"])
827
+ clean = gr.Button("Clean", elem_classes=["aux-btn"])
828
+
829
+ with gr.Accordion("Context Inspector", elem_id="aux-viewer", open=False):
830
+ context_inspector = gr.Textbox(
831
+ "",
832
+ elem_id="aux-viewer-inspector",
833
+ label="",
834
+ lines=30,
835
+ max_lines=50,
836
+ )
837
+
838
+ chatbot = gr.Chatbot(elem_id='chatbot')
839
+ instruction_txtbox = gr.Textbox(placeholder="Ask anything", label="", elem_id="prompt-txt")
840
+
841
+ with gr.Accordion("Control Panel", open=False) as control_panel:
842
+ with gr.Column():
843
+ with gr.Column():
844
+ gr.Markdown("#### Global context")
845
+ with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=False):
846
+ global_context = gr.Textbox(
847
+ "global context",
848
+ lines=5,
849
+ max_lines=10,
850
+ interactive=True,
851
+ elem_id="global-context"
852
+ )
853
+
854
+ gr.Markdown("#### Internet search")
855
+ with gr.Row():
856
+ internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
857
+ serper_api_key = gr.Textbox(
858
+ value= "" if args.serper_api_key is None else args.serper_api_key,
859
+ placeholder="Get one by visiting serper.dev",
860
+ label="Serper api key"
861
+ )
862
+
863
+ gr.Markdown("#### GenConfig for **response** text generation")
864
+ with gr.Row():
865
+ res_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
866
+ res_topp = gr.Slider(0.0, 2.0, 0, step=0.1, label="top_p", interactive=True)
867
+ res_topk = gr.Slider(20, 1000, 0, step=1, label="top_k", interactive=True)
868
+ res_rpen = gr.Slider(0.0, 2.0, 0, step=0.1, label="rep_penalty", interactive=True)
869
+ res_mnts = gr.Slider(64, 2048, 0, step=1, label="new_tokens", interactive=True)
870
+ res_beams = gr.Slider(1, 4, 0, step=1, label="beams")
871
+ res_cache = gr.Radio([True, False], value=0, label="cache", interactive=True)
872
+ res_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
873
+ res_eosid = gr.Number(value=0, visible=False, precision=0)
874
+ res_padid = gr.Number(value=0, visible=False, precision=0)
875
+
876
+ with gr.Column(visible=False):
877
+ gr.Markdown("#### GenConfig for **summary** text generation")
878
+ with gr.Row():
879
+ sum_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
880
+ sum_topp = gr.Slider(0.0, 2.0, 0, step=0.1, label="top_p", interactive=True)
881
+ sum_topk = gr.Slider(20, 1000, 0, step=1, label="top_k", interactive=True)
882
+ sum_rpen = gr.Slider(0.0, 2.0, 0, step=0.1, label="rep_penalty", interactive=True)
883
+ sum_mnts = gr.Slider(64, 2048, 0, step=1, label="new_tokens", interactive=True)
884
+ sum_beams = gr.Slider(1, 8, 0, step=1, label="beams", interactive=True)
885
+ sum_cache = gr.Radio([True, False], value=0, label="cache", interactive=True)
886
+ sum_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
887
+ sum_eosid = gr.Number(value=0, visible=False, precision=0)
888
+ sum_padid = gr.Number(value=0, visible=False, precision=0)
889
+
890
+ with gr.Column():
891
+ gr.Markdown("#### Context managements")
892
+ with gr.Row():
893
+ ctx_num_lconv = gr.Slider(2, 10, 3, step=1, label="number of recent talks to keep", interactive=True)
894
+ ctx_sum_prompt = gr.Textbox(
895
+ "summarize our conversations. what have we discussed about so far?",
896
+ label="design a prompt to summarize the conversations",
897
+ visible=False
898
+ )
899
+
900
+ btns = [
901
+ t5_vicuna_3b, flan3b, camel5b, alpaca_lora7b, stablelm7b,
902
+ gpt4_alpaca_7b, os_stablelm7b, mpt_7b, redpajama_7b, redpajama_instruct_7b, llama_deus_7b,
903
+ evolinstruct_vicuna_7b, alpacoom_7b, baize_7b, guanaco_7b, vicuna_7b_1_3,
904
+ falcon_7b, wizard_falcon_7b, airoboros_7b, samantha_7b, openllama_7b, orcamini_7b,
905
+ xgen_7b,llama2_7b,
906
+ flan11b, koalpaca, kullm, alpaca_lora13b, gpt4_alpaca_13b, stable_vicuna_13b,
907
+ starchat_15b, starchat_beta_15b, vicuna_7b, vicuna_13b, evolinstruct_vicuna_13b,
908
+ baize_13b, guanaco_13b, nous_hermes_13b, airoboros_13b, samantha_13b, chronos_13b,
909
+ wizardlm_13b, wizard_vicuna_13b, wizard_coder_15b, vicuna_13b_1_3, openllama_13b, orcamini_13b,
910
+ llama2_13b, nous_hermes_13b_v2, camel20b,
911
+ guanaco_33b, falcon_40b, wizard_falcon_40b, samantha_33b, lazarus_30b, chronos_33b,
912
+ wizardlm_30b, wizard_vicuna_30b, vicuna_33b_1_3, mpt_30b, upstage_llama_30b,
913
+ free_willy2_70b
914
+ ]
915
+ for btn in btns:
916
+ btn.click(
917
+ move_to_second_view,
918
+ btn,
919
+ [
920
+ model_choice_view, model_review_view,
921
+ model_image, model_name, model_params, model_base, model_ckpt,
922
+ model_desc, model_vram, gen_config_path,
923
+ example_showcase1, example_showcase2, example_showcase3, example_showcase4,
924
+ model_thumbnail_tiny, load_mode,
925
+ progress_view
926
+ ]
927
+ )
928
+
929
+ select_model.click(
930
+ move_to_model_select_view,
931
+ None,
932
+ [progress_view0, landing_view, model_choice_view]
933
+ )
934
+
935
+ chosen_model.click(
936
+ use_chosen_model,
937
+ None,
938
+ [progress_view0, landing_view, chat_view, chatbot, chat_state, global_context,
939
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
940
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid]
941
+ )
942
+
943
+ byom.click(
944
+ move_to_byom_view,
945
+ None,
946
+ [progress_view0, landing_view, byom_input_view, byom_load_mode]
947
+ )
948
+
949
+ byom_back_btn.click(
950
+ move_to_first_view,
951
+ None,
952
+ [landing_view, byom_input_view]
953
+ )
954
+
955
+ byom_confirm_btn.click(
956
+ lambda: "Start downloading/loading the model...", None, txt_view3
957
+ ).then(
958
+ byom_load,
959
+ [byom_base, byom_ckpt, byom_model_cls, byom_tokenizer_cls,
960
+ byom_bos_token_id, byom_eos_token_id, byom_pad_token_id,
961
+ byom_load_mode],
962
+ [progress_view3]
963
+ ).then(
964
+ lambda: "Model is fully loaded...", None, txt_view3
965
+ ).then(
966
+ move_to_third_view,
967
+ None,
968
+ [progress_view3, byom_input_view, chat_view, chatbot, chat_state, global_context,
969
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
970
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid]
971
+ )
972
+
973
+ prompt_style_selector.change(
974
+ prompt_style_change,
975
+ prompt_style_selector,
976
+ prompt_style_previewer
977
+ )
978
+
979
+ back_to_model_choose_btn.click(
980
+ move_to_first_view,
981
+ None,
982
+ [model_choice_view, model_review_view]
983
+ )
984
+
985
+ confirm_btn.click(
986
+ lambda: "Start downloading/loading the model...", None, txt_view
987
+ ).then(
988
+ download_completed,
989
+ [model_name, model_base, model_ckpt, gen_config_path, gen_config_sum_path, load_mode, model_thumbnail_tiny, force_redownload],
990
+ [progress_view2]
991
+ ).then(
992
+ lambda: "Model is fully loaded...", None, txt_view
993
+ ).then(
994
+ lambda: time.sleep(2), None, None
995
+ ).then(
996
+ move_to_third_view,
997
+ None,
998
+ [progress_view2, model_review_view, chat_view, chatbot, chat_state, global_context,
999
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
1000
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid]
1001
+ )
1002
+
1003
+ for btn in channel_btns:
1004
+ btn.click(
1005
+ set_chatbot,
1006
+ [btn, local_data, chat_state],
1007
+ [chatbot, idx, example_block, regenerate]
1008
+ ).then(
1009
+ None, btn, None,
1010
+ _js=UPDATE_LEFT_BTNS_STATE
1011
+ )
1012
+
1013
+ for btn in ex_btns:
1014
+ btn.click(
1015
+ set_example,
1016
+ [btn],
1017
+ [instruction_txtbox, example_block]
1018
+ )
1019
+
1020
+ instruction_txtbox.submit(
1021
+ lambda: [
1022
+ gr.update(visible=False),
1023
+ gr.update(interactive=True)
1024
+ ],
1025
+ None,
1026
+ [example_block, regenerate]
1027
+ )
1028
+
1029
+ send_event = instruction_txtbox.submit(
1030
+ central.chat_stream,
1031
+ [idx, local_data, instruction_txtbox, chat_state,
1032
+ global_context, ctx_num_lconv, ctx_sum_prompt,
1033
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
1034
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
1035
+ internet_option, serper_api_key],
1036
+ [instruction_txtbox, chatbot, context_inspector, local_data],
1037
+ )
1038
+
1039
+ instruction_txtbox.submit(
1040
+ None, local_data, None,
1041
+ _js="(v)=>{ setStorage('local_data',v) }"
1042
+ )
1043
+
1044
+ regenerate.click(
1045
+ rollback_last,
1046
+ [idx, local_data, chat_state],
1047
+ [instruction_txtbox, chatbot, local_data, regenerate]
1048
+ ).then(
1049
+ central.chat_stream,
1050
+ [idx, local_data, instruction_txtbox, chat_state,
1051
+ global_context, ctx_num_lconv, ctx_sum_prompt,
1052
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
1053
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
1054
+ internet_option, serper_api_key],
1055
+ [instruction_txtbox, chatbot, context_inspector, local_data],
1056
+ ).then(
1057
+ lambda: gr.update(interactive=True),
1058
+ None,
1059
+ regenerate
1060
+ ).then(
1061
+ None, local_data, None,
1062
+ _js="(v)=>{ setStorage('local_data',v) }"
1063
+ )
1064
+
1065
+ stop.click(
1066
+ None, None, None,
1067
+ cancels=[send_event]
1068
+ )
1069
+
1070
+ clean.click(
1071
+ reset_chat,
1072
+ [idx, local_data, chat_state],
1073
+ [instruction_txtbox, chatbot, local_data, example_block, regenerate]
1074
+ ).then(
1075
+ None, local_data, None,
1076
+ _js="(v)=>{ setStorage('local_data',v) }"
1077
+ )
1078
+
1079
+ chat_back_btn.click(
1080
+ lambda: [gr.update(visible=False), gr.update(visible=True)],
1081
+ None,
1082
+ [chat_view, landing_view]
1083
+ )
1084
+
1085
+ demo.load(
1086
+ None,
1087
+ inputs=None,
1088
+ outputs=[chatbot, local_data],
1089
+ _js=GET_LOCAL_STORAGE,
1090
+ )
1091
+
1092
+ demo.queue().launch(
1093
+ server_port=6006,
1094
+ server_name="0.0.0.0",
1095
+ debug=args.debug,
1096
+ share=args.share,
1097
+ root_path=f"{args.root_path}"
1098
+ )
1099
+
1100
+ if __name__ == "__main__":
1101
+ parser = argparse.ArgumentParser()
1102
+ parser.add_argument('--root-path', default="")
1103
+ parser.add_argument('--local-files-only', default=False, action=argparse.BooleanOptionalAction)
1104
+ parser.add_argument('--share', default=False, action=argparse.BooleanOptionalAction)
1105
+ parser.add_argument('--debug', default=False, action=argparse.BooleanOptionalAction)
1106
+ parser.add_argument('--serper-api-key', default=None, type=str)
1107
+ args = parser.parse_args()
1108
+
1109
+ gradio_main(args)
assets/guimode_preview.gif ADDED

Git LFS Details

  • SHA256: a8ed57c6d4bca465aaa8490d21ca45b9e4f82c17d36e40e219d7ac0236b7c9e0
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
assets/preview.gif ADDED

Git LFS Details

  • SHA256: c7df81c43bfe1327bf222473e06f492b15b3cf64054df27381951ba15a1172ff
  • Pointer size: 132 Bytes
  • Size of remote file: 9.76 MB
assets/preview.png ADDED

Git LFS Details

  • SHA256: 62e092ae338a5970423a3c4fd1e4caf2280d3083cd3a7eb2612d3bf3c5b80c67
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
channels.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ 1st Channel
2
+ 2nd Channel
3
+ 3rd Channel
4
+ 4th Channel
5
+ 5th Channel
6
+ 6th Channel
7
+ 7th Channel
8
+ 8th Channel
9
+ 9th Channel
10
+ 10th Channel
chats/__init__.py ADDED
File without changes
chats/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (129 Bytes). View file
 
chats/__pycache__/alpaca.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
chats/__pycache__/alpaca_gpt4.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
chats/__pycache__/alpacoom.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
chats/__pycache__/baize.cpython-39.pyc ADDED
Binary file (1.87 kB). View file
 
chats/__pycache__/central.cpython-39.pyc ADDED
Binary file (5.45 kB). View file
 
chats/__pycache__/custom.cpython-39.pyc ADDED
Binary file (1.88 kB). View file
 
chats/__pycache__/falcon.cpython-39.pyc ADDED
Binary file (2.09 kB). View file
 
chats/__pycache__/flan_alpaca.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
chats/__pycache__/freewilly.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
chats/__pycache__/guanaco.cpython-39.pyc ADDED
Binary file (2.09 kB). View file
 
chats/__pycache__/koalpaca.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
chats/__pycache__/llama2.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
chats/__pycache__/mpt.cpython-39.pyc ADDED
Binary file (2.34 kB). View file
 
chats/__pycache__/os_stablelm.cpython-39.pyc ADDED
Binary file (2.12 kB). View file
 
chats/__pycache__/post.cpython-39.pyc ADDED
Binary file (285 Bytes). View file
 
chats/__pycache__/pre.cpython-39.pyc ADDED
Binary file (2.19 kB). View file
 
chats/__pycache__/redpajama.cpython-39.pyc ADDED
Binary file (2.56 kB). View file
 
chats/__pycache__/stable_vicuna.cpython-39.pyc ADDED
Binary file (2.32 kB). View file
 
chats/__pycache__/stablelm.cpython-39.pyc ADDED
Binary file (2.11 kB). View file
 
chats/__pycache__/starchat.cpython-39.pyc ADDED
Binary file (2.1 kB). View file
 
chats/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.73 kB). View file
 
chats/__pycache__/vicuna.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
chats/__pycache__/wizard_coder.cpython-39.pyc ADDED
Binary file (2.09 kB). View file
 
chats/__pycache__/wizard_falcon.cpython-39.pyc ADDED
Binary file (2.1 kB). View file
 
chats/__pycache__/xgen.cpython-39.pyc ADDED
Binary file (2.4 kB). View file
 
chats/alpaca.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import global_vars
4
+ from chats import pre, post
5
+ from pingpong import PingPong
6
+ from gens.batch_gen import get_output_batch
7
+
8
+ from chats.utils import build_prompts, text_stream, internet_search
9
+
10
+ def chat_stream(
11
+ idx, local_data, user_message, state,
12
+ global_context, ctx_num_lconv, ctx_sum_prompt,
13
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
14
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
15
+ internet_option, serper_api_key
16
+ ):
17
+ res = [
18
+ state["ppmanager_type"].from_json(json.dumps(ppm))
19
+ for ppm in local_data
20
+ ]
21
+
22
+ ppm = res[idx]
23
+
24
+ # add_ping returns a prompt structured in Alpaca form
25
+ ppm.add_pingpong(
26
+ PingPong(user_message, "")
27
+ )
28
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
29
+
30
+ #######
31
+ if internet_option:
32
+ search_prompt = None
33
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
34
+ search_prompt = tmp_prompt
35
+ yield "", uis, prompt, str(res)
36
+
37
+ # prepare text generating streamer & start generating
38
+ gen_kwargs, streamer = pre.build(
39
+ search_prompt if internet_option else prompt,
40
+ res_temp, res_topp, res_topk, res_rpen, res_mnts,
41
+ res_beams, res_cache, res_sample, res_eosid, res_padid,
42
+ return_token_type_ids=False
43
+ )
44
+ pre.start_gen(gen_kwargs)
45
+
46
+ # handling stream
47
+ for ppmanager, uis in text_stream(ppm, streamer):
48
+ yield "", uis, prompt, str(res)
49
+
50
+ ppm = post.strip_pong(ppm)
51
+ yield "", ppm.build_uis(), prompt, str(res)
chats/alpaca_gpt4.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import global_vars
4
+ from chats import pre, post
5
+ from pingpong import PingPong
6
+ from gens.batch_gen import get_output_batch
7
+
8
+ from chats.utils import build_prompts, text_stream, internet_search
9
+
10
+ def chat_stream(
11
+ idx, local_data, user_message, state,
12
+ global_context, ctx_num_lconv, ctx_sum_prompt,
13
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
14
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
15
+ internet_option, serper_api_key
16
+ ):
17
+ res = [
18
+ state["ppmanager_type"].from_json(json.dumps(ppm))
19
+ for ppm in local_data
20
+ ]
21
+
22
+ ppm = res[idx]
23
+
24
+ # add_ping returns a prompt structured in Alpaca form
25
+ ppm.add_pingpong(
26
+ PingPong(user_message, "")
27
+ )
28
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
29
+
30
+ #######
31
+ if internet_option:
32
+ search_prompt = None
33
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
34
+ search_prompt = tmp_prompt
35
+ yield "", uis, prompt, str(res)
36
+
37
+ # prepare text generating streamer & start generating
38
+ gen_kwargs, streamer = pre.build(
39
+ search_prompt if internet_option else prompt,
40
+ res_temp, res_topp, res_topk, res_rpen, res_mnts,
41
+ res_beams, res_cache, res_sample, res_eosid, res_padid,
42
+ return_token_type_ids=False
43
+ )
44
+ pre.start_gen(gen_kwargs)
45
+
46
+ # handling stream
47
+ for ppmanager, uis in text_stream(ppm, streamer):
48
+ yield "", uis, prompt, str(res)
49
+
50
+ ppm = post.strip_pong(ppm)
51
+ yield "", ppm.build_uis(), prompt, str(res)
chats/alpacoom.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import global_vars
4
+ from chats import pre, post
5
+ from pingpong import PingPong
6
+ from gens.batch_gen import get_output_batch
7
+
8
+ from chats.utils import build_prompts, text_stream, internet_search
9
+
10
+ def chat_stream(
11
+ idx, local_data, user_message, state,
12
+ global_context, ctx_num_lconv, ctx_sum_prompt,
13
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
14
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
15
+ internet_option, serper_api_key
16
+ ):
17
+ res = [
18
+ state["ppmanager_type"].from_json(json.dumps(ppm))
19
+ for ppm in local_data
20
+ ]
21
+
22
+ ppm = res[idx]
23
+
24
+ # add_ping returns a prompt structured in Alpaca form
25
+ ppm.add_pingpong(
26
+ PingPong(user_message, "")
27
+ )
28
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
29
+
30
+ #######
31
+ if internet_option:
32
+ search_prompt = None
33
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
34
+ search_prompt = tmp_prompt
35
+ yield "", uis, prompt, str(res)
36
+
37
+ # prepare text generating streamer & start generating
38
+ gen_kwargs, streamer = pre.build(
39
+ search_prompt if internet_option else prompt,
40
+ res_temp, res_topp, res_topk, res_rpen, res_mnts,
41
+ res_beams, res_cache, res_sample, res_eosid, res_padid,
42
+ return_token_type_ids=False
43
+ )
44
+ pre.start_gen(gen_kwargs)
45
+
46
+ # handling stream
47
+ for ppmanager, uis in text_stream(ppm, streamer):
48
+ yield "", uis, prompt, str(res)
49
+
50
+ ppm = post.strip_pong(ppm)
51
+ yield "", ppm.build_uis(), prompt, str(res)
chats/baize.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import global_vars
4
+ from chats import pre, post
5
+ from pingpong import PingPong
6
+ from gens.batch_gen import get_output_batch
7
+
8
+ from chats.utils import build_prompts, internet_search
9
+
10
+ def text_stream(ppmanager, streamer):
11
+ count = 0
12
+
13
+ for new_text in streamer:
14
+ if "[|Human|]" in new_text or \
15
+ "[|AI|]" in new_text:
16
+ break
17
+
18
+ if count == 0:
19
+ ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
20
+ count = count + 1
21
+
22
+ ppmanager.append_pong(new_text)
23
+ yield ppmanager, ppmanager.build_uis()
24
+
25
+ yield ppmanager, ppmanager.build_uis()
26
+
27
+ def chat_stream(
28
+ idx, local_data, user_message, state,
29
+ global_context, ctx_num_lconv, ctx_sum_prompt,
30
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
31
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
32
+ internet_option, serper_api_key
33
+ ):
34
+ res = [
35
+ state["ppmanager_type"].from_json(json.dumps(ppm))
36
+ for ppm in local_data
37
+ ]
38
+
39
+ ppm = res[idx]
40
+
41
+ # add_ping returns a prompt structured in Alpaca form
42
+ ppm.add_pingpong(
43
+ PingPong(user_message, "")
44
+ )
45
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
46
+
47
+ #######
48
+ if internet_option:
49
+ search_prompt = None
50
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
51
+ search_prompt = tmp_prompt
52
+ yield "", uis, prompt, str(res)
53
+
54
+ # prepare text generating streamer & start generating
55
+ gen_kwargs, streamer = pre.build(
56
+ search_prompt if internet_option else prompt,
57
+ res_temp, res_topp, res_topk, res_rpen, res_mnts,
58
+ res_beams, res_cache, res_sample, res_eosid, res_padid,
59
+ return_token_type_ids=False
60
+ )
61
+ pre.start_gen(gen_kwargs)
62
+
63
+ # handling stream
64
+ for ppmanager, uis in text_stream(ppm, streamer):
65
+ yield "", uis, prompt, str(res)
66
+
67
+ ppm = post.strip_pong(ppm)
68
+ yield "", ppm.build_uis(), prompt, str(res)
chats/central.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from chats import stablelm
2
+ from chats import alpaca
3
+ from chats import koalpaca
4
+ from chats import flan_alpaca
5
+ from chats import os_stablelm
6
+ from chats import vicuna
7
+ from chats import stable_vicuna
8
+ from chats import starchat
9
+ from chats import wizard_coder
10
+ from chats import redpajama
11
+ from chats import mpt
12
+ from chats import alpacoom
13
+ from chats import baize
14
+ from chats import guanaco
15
+ from chats import falcon
16
+ from chats import wizard_falcon
17
+ from chats import xgen
18
+ from chats import llama2
19
+ from chats import freewilly
20
+ from chats import custom
21
+
22
+ def chat_stream(
23
+ idx, local_data, user_message, state,
24
+ global_context, ctx_num_lconv, ctx_sum_prompt,
25
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
26
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
27
+ internet_option, serper_api_key
28
+ ):
29
+ model_type = state["model_type"]
30
+
31
+ if internet_option == "on" and serper_api_key.strip() != "":
32
+ internet_option = True
33
+ else:
34
+ internet_option = False
35
+
36
+ if model_type == "custom":
37
+ cs = custom.chat_stream(
38
+ idx, local_data, user_message, state,
39
+ global_context, ctx_num_lconv, ctx_sum_prompt,
40
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
41
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
42
+ internet_option, serper_api_key
43
+ )
44
+
45
+ elif model_type == "free-willy":
46
+ cs = freewilly.chat_stream(
47
+ idx, local_data, user_message, state,
48
+ global_context, ctx_num_lconv, ctx_sum_prompt,
49
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
50
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
51
+ internet_option, serper_api_key
52
+ )
53
+
54
+ elif model_type == "upstage-llama":
55
+ cs = alpaca.chat_stream(
56
+ idx, local_data, user_message, state,
57
+ global_context, ctx_num_lconv, ctx_sum_prompt,
58
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
59
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
60
+ internet_option, serper_api_key
61
+ )
62
+
63
+ elif model_type == "llama2":
64
+ cs = llama2.chat_stream(
65
+ idx, local_data, user_message, state,
66
+ global_context, ctx_num_lconv, ctx_sum_prompt,
67
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
68
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
69
+ internet_option, serper_api_key
70
+ )
71
+
72
+ elif model_type == "xgen":
73
+ cs = xgen.chat_stream(
74
+ idx, local_data, user_message, state,
75
+ global_context, ctx_num_lconv, ctx_sum_prompt,
76
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
77
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
78
+ internet_option, serper_api_key
79
+ )
80
+
81
+ elif model_type == "stablelm":
82
+ cs = stablelm.chat_stream(
83
+ idx, local_data, user_message, state,
84
+ global_context, ctx_num_lconv, ctx_sum_prompt,
85
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
86
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
87
+ internet_option, serper_api_key
88
+ )
89
+
90
+ elif model_type == "falcon":
91
+ cs = falcon.chat_stream(
92
+ idx, local_data, user_message, state,
93
+ global_context, ctx_num_lconv, ctx_sum_prompt,
94
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
95
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
96
+ internet_option, serper_api_key
97
+ )
98
+
99
+ elif model_type == "wizard-falcon":
100
+ cs = wizard_falcon.chat_stream(
101
+ idx, local_data, user_message, state,
102
+ global_context, ctx_num_lconv, ctx_sum_prompt,
103
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
104
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
105
+ internet_option, serper_api_key
106
+ )
107
+
108
+ elif model_type == "baize":
109
+ cs = baize.chat_stream(
110
+ idx, local_data, user_message, state,
111
+ global_context, ctx_num_lconv, ctx_sum_prompt,
112
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
113
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
114
+ internet_option, serper_api_key
115
+ )
116
+
117
+ elif model_type == "alpaca":
118
+ cs = alpaca.chat_stream(
119
+ idx, local_data, user_message, state,
120
+ global_context, ctx_num_lconv, ctx_sum_prompt,
121
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
122
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
123
+ internet_option, serper_api_key
124
+ )
125
+
126
+ elif model_type == "openllama":
127
+ cs = alpaca.chat_stream(
128
+ idx, local_data, user_message, state,
129
+ global_context, ctx_num_lconv, ctx_sum_prompt,
130
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
131
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
132
+ internet_option, serper_api_key
133
+ )
134
+
135
+ elif model_type == "orcamini":
136
+ cs = alpaca.chat_stream(
137
+ idx, local_data, user_message, state,
138
+ global_context, ctx_num_lconv, ctx_sum_prompt,
139
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
140
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
141
+ internet_option, serper_api_key
142
+ )
143
+
144
+ elif model_type == "alpaca-gpt4":
145
+ cs = alpaca.chat_stream(
146
+ idx, local_data, user_message, state,
147
+ global_context, ctx_num_lconv, ctx_sum_prompt,
148
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
149
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
150
+ internet_option, serper_api_key
151
+ )
152
+
153
+ elif model_type == "nous-hermes":
154
+ cs = alpaca.chat_stream(
155
+ idx, local_data, user_message, state,
156
+ global_context, ctx_num_lconv, ctx_sum_prompt,
157
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
158
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
159
+ internet_option, serper_api_key
160
+ )
161
+
162
+ elif model_type == "replit-instruct":
163
+ cs = alpaca.chat_stream(
164
+ idx, local_data, user_message, state,
165
+ global_context, ctx_num_lconv, ctx_sum_prompt,
166
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
167
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
168
+ internet_option, serper_api_key
169
+ )
170
+
171
+ elif model_type == "alpacoom":
172
+ cs = alpacoom.chat_stream(
173
+ idx, local_data, user_message, state,
174
+ global_context, ctx_num_lconv, ctx_sum_prompt,
175
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
176
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
177
+ internet_option, serper_api_key
178
+ )
179
+
180
+ elif model_type == "llama-deus":
181
+ cs = alpaca.chat_stream(
182
+ idx, local_data, user_message, state,
183
+ global_context, ctx_num_lconv, ctx_sum_prompt,
184
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
185
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
186
+ internet_option, serper_api_key
187
+ )
188
+
189
+ elif model_type == "camel":
190
+ cs = alpaca.chat_stream(
191
+ idx, local_data, user_message, state,
192
+ global_context, ctx_num_lconv, ctx_sum_prompt,
193
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
194
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
195
+ internet_option, serper_api_key
196
+ )
197
+
198
+ elif model_type == "koalpaca-polyglot":
199
+ cs = koalpaca.chat_stream(
200
+ idx, local_data, user_message, state,
201
+ global_context, ctx_num_lconv, ctx_sum_prompt,
202
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
203
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
204
+ internet_option, serper_api_key
205
+ )
206
+
207
+ elif model_type == "kullm-polyglot":
208
+ cs = koalpaca.chat_stream(
209
+ idx, local_data, user_message, state,
210
+ global_context, ctx_num_lconv, ctx_sum_prompt,
211
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
212
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
213
+ internet_option, serper_api_key
214
+ )
215
+
216
+ elif model_type == "flan-alpaca":
217
+ cs = flan_alpaca.chat_stream(
218
+ idx, local_data, user_message, state,
219
+ global_context, ctx_num_lconv, ctx_sum_prompt,
220
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
221
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
222
+ internet_option, serper_api_key
223
+ )
224
+
225
+ elif model_type == "os-stablelm":
226
+ cs = os_stablelm.chat_stream(
227
+ idx, local_data, user_message, state,
228
+ global_context, ctx_num_lconv, ctx_sum_prompt,
229
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
230
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
231
+ internet_option, serper_api_key
232
+ )
233
+
234
+ elif model_type == "t5-vicuna":
235
+ cs = vicuna.chat_stream(
236
+ idx, local_data, user_message, state,
237
+ global_context, ctx_num_lconv, ctx_sum_prompt,
238
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
239
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
240
+ internet_option, serper_api_key
241
+ )
242
+
243
+ elif model_type == "stable-vicuna":
244
+ cs = stable_vicuna.chat_stream(
245
+ idx, local_data, user_message, state,
246
+ global_context, ctx_num_lconv, ctx_sum_prompt,
247
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
248
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
249
+ internet_option, serper_api_key
250
+ )
251
+
252
+ elif model_type == "vicuna":
253
+ cs = vicuna.chat_stream(
254
+ idx, local_data, user_message, state,
255
+ global_context, ctx_num_lconv, ctx_sum_prompt,
256
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
257
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
258
+ internet_option, serper_api_key
259
+ )
260
+
261
+ elif model_type == "wizardlm":
262
+ cs = vicuna.chat_stream(
263
+ idx, local_data, user_message, state,
264
+ global_context, ctx_num_lconv, ctx_sum_prompt,
265
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
266
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
267
+ internet_option, serper_api_key
268
+ )
269
+
270
+ elif model_type == "wizard-vicuna":
271
+ cs = vicuna.chat_stream(
272
+ idx, local_data, user_message, state,
273
+ global_context, ctx_num_lconv, ctx_sum_prompt,
274
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
275
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
276
+ internet_option, serper_api_key
277
+ )
278
+
279
+ elif model_type == "airoboros":
280
+ cs = vicuna.chat_stream(
281
+ idx, local_data, user_message, state,
282
+ global_context, ctx_num_lconv, ctx_sum_prompt,
283
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
284
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
285
+ internet_option, serper_api_key
286
+ )
287
+
288
+ elif model_type == "samantha-vicuna":
289
+ cs = vicuna.chat_stream(
290
+ idx, local_data, user_message, state,
291
+ global_context, ctx_num_lconv, ctx_sum_prompt,
292
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
293
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
294
+ internet_option, serper_api_key
295
+ )
296
+
297
+ elif model_type == "evolinstruct-vicuna":
298
+ cs = vicuna.chat_stream(
299
+ idx, local_data, user_message, state,
300
+ global_context, ctx_num_lconv, ctx_sum_prompt,
301
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
302
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
303
+ internet_option, serper_api_key
304
+ )
305
+
306
+ elif model_type == "starchat":
307
+ cs = starchat.chat_stream(
308
+ idx, local_data, user_message, state,
309
+ global_context, ctx_num_lconv, ctx_sum_prompt,
310
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
311
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
312
+ internet_option, serper_api_key
313
+ )
314
+
315
+ elif model_type == "wizard-coder":
316
+ cs = wizard_coder.chat_stream(
317
+ idx, local_data, user_message, state,
318
+ global_context, ctx_num_lconv, ctx_sum_prompt,
319
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
320
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
321
+ internet_option, serper_api_key
322
+ )
323
+
324
+ elif model_type == "mpt":
325
+ cs = mpt.chat_stream(
326
+ idx, local_data, user_message, state,
327
+ global_context, ctx_num_lconv, ctx_sum_prompt,
328
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
329
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
330
+ internet_option, serper_api_key
331
+ )
332
+
333
+ elif model_type == "redpajama":
334
+ cs = redpajama.chat_stream(
335
+ idx, local_data, user_message, state,
336
+ global_context, ctx_num_lconv, ctx_sum_prompt,
337
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
338
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
339
+ internet_option, serper_api_key
340
+ )
341
+
342
+ elif model_type == "redpajama-instruct":
343
+ cs = redpajama.chat_stream(
344
+ idx, local_data, user_message, state,
345
+ global_context, ctx_num_lconv, ctx_sum_prompt,
346
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
347
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
348
+ internet_option, serper_api_key
349
+ )
350
+
351
+ elif model_type == "guanaco":
352
+ cs = guanaco.chat_stream(
353
+ idx, local_data, user_message, state,
354
+ global_context, ctx_num_lconv, ctx_sum_prompt,
355
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
356
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
357
+ internet_option, serper_api_key
358
+ )
359
+
360
+ elif model_type == "lazarus":
361
+ cs = alpaca.chat_stream(
362
+ idx, local_data, user_message, state,
363
+ global_context, ctx_num_lconv, ctx_sum_prompt,
364
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
365
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
366
+ internet_option, serper_api_key
367
+ )
368
+
369
+ elif model_type == "chronos":
370
+ cs = alpaca.chat_stream(
371
+ idx, local_data, user_message, state,
372
+ global_context, ctx_num_lconv, ctx_sum_prompt,
373
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
374
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
375
+ internet_option, serper_api_key
376
+ )
377
+
378
+ for idx, x in enumerate(cs):
379
+ yield x
380
+
chats/custom.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import global_vars
4
+ from chats import pre, post
5
+ from pingpong import PingPong
6
+ from gens.batch_gen import get_output_batch
7
+
8
+ from chats.utils import build_prompts, internet_search
9
+
10
+ def text_stream(ppmanager, streamer):
11
+ count = 0
12
+ thumbnail_tiny = "https://i.ibb.co/f80BpgR/byom.png"
13
+
14
+ for new_text in streamer:
15
+ if count == 0:
16
+ ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
17
+ count = count + 1
18
+
19
+ ppmanager.append_pong(new_text)
20
+ yield ppmanager, ppmanager.build_uis()
21
+
22
+ yield ppmanager, ppmanager.build_uis()
23
+
24
+ def chat_stream(
25
+ idx, local_data, user_message, state,
26
+ global_context, ctx_num_lconv, ctx_sum_prompt,
27
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
28
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
29
+ internet_option, serper_api_key
30
+ ):
31
+ res = [
32
+ state["ppmanager_type"].from_json(json.dumps(ppm))
33
+ for ppm in local_data
34
+ ]
35
+
36
+ ppm = res[idx]
37
+
38
+ # add_ping returns a prompt structured in Alpaca form
39
+ ppm.add_pingpong(
40
+ PingPong(user_message, "")
41
+ )
42
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
43
+
44
+ #######
45
+ if internet_option:
46
+ search_prompt = None
47
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
48
+ search_prompt = tmp_prompt
49
+ yield "", uis, prompt, str(res)
50
+
51
+ # prepare text generating streamer & start generating
52
+ gen_kwargs, streamer = pre.build(
53
+ search_prompt if internet_option else prompt,
54
+ res_temp, res_topp, res_topk, res_rpen, res_mnts,
55
+ res_beams, res_cache, res_sample, res_eosid, res_padid,
56
+ return_token_type_ids=False
57
+ )
58
+ pre.start_gen(gen_kwargs)
59
+
60
+ # handling stream
61
+ for ppmanager, uis in text_stream(ppm, streamer):
62
+ yield "", uis, prompt, str(res)
63
+
64
+ ppm = post.strip_pong(ppm)
65
+ yield "", ppm.build_uis(), prompt, str(res)
chats/falcon.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import StoppingCriteria, StoppingCriteriaList
3
+
4
+ import copy
5
+ import json
6
+ import global_vars
7
+ from chats import pre, post
8
+ from pingpong import PingPong
9
+ from gens.batch_gen import get_output_batch
10
+
11
+ from chats.utils import build_prompts, text_stream, internet_search
12
+
13
+ class StopOnTokens(StoppingCriteria):
14
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
15
+ stop_ids = [11]
16
+ for stop_id in stop_ids:
17
+ if input_ids[0][-1] == stop_id:
18
+ return True
19
+ return False
20
+
21
+ def chat_stream(
22
+ idx, local_data, user_message, state,
23
+ global_context, ctx_num_lconv, ctx_sum_prompt,
24
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
25
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
26
+ internet_option, serper_api_key
27
+ ):
28
+ res = [
29
+ state["ppmanager_type"].from_json(json.dumps(ppm))
30
+ for ppm in local_data
31
+ ]
32
+
33
+ ppm = res[idx]
34
+
35
+ # add_ping returns a prompt structured in Alpaca form
36
+ ppm.add_pingpong(
37
+ PingPong(user_message, "")
38
+ )
39
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
40
+
41
+ #######
42
+ if internet_option:
43
+ search_prompt = None
44
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
45
+ search_prompt = tmp_prompt
46
+ yield "", uis, prompt, str(res)
47
+
48
+ # prepare text generating streamer & start generating
49
+ gen_kwargs, streamer = pre.build(
50
+ search_prompt if internet_option else prompt,
51
+ res_temp, res_topp, res_topk, res_rpen, res_mnts,
52
+ res_beams, res_cache, res_sample, res_eosid, res_padid,
53
+ return_token_type_ids=False
54
+ )
55
+ pre.start_gen(gen_kwargs)
56
+
57
+ # handling stream
58
+ for ppmanager, uis in text_stream(ppm, streamer):
59
+ yield "", uis, prompt, str(res)
60
+
61
+ ppm = post.strip_pong(ppm)
62
+ yield "", ppm.build_uis(), prompt, str(res)
chats/flan_alpaca.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import global_vars
4
+ from chats import pre, post
5
+ from pingpong import PingPong
6
+ from gens.batch_gen import get_output_batch
7
+
8
+ from chats.utils import build_prompts, text_stream, internet_search
9
+
10
+ def chat_stream(
11
+ idx, local_data, user_message, state,
12
+ global_context, ctx_num_lconv, ctx_sum_prompt,
13
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
14
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
15
+ internet_option, serper_api_key
16
+ ):
17
+ res = [
18
+ state["ppmanager_type"].from_json(json.dumps(ppm))
19
+ for ppm in local_data
20
+ ]
21
+
22
+ ppm = res[idx]
23
+
24
+ # add_ping returns a prompt structured in Alpaca form
25
+ ppm.add_pingpong(
26
+ PingPong(user_message, "")
27
+ )
28
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
29
+
30
+ #######
31
+ if internet_option:
32
+ search_prompt = None
33
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
34
+ search_prompt = tmp_prompt
35
+ yield "", uis, prompt, str(res)
36
+
37
+ # prepare text generating streamer & start generating
38
+ gen_kwargs, streamer = pre.build(
39
+ search_prompt if internet_option else prompt,
40
+ res_temp, res_topp, res_topk, res_rpen, res_mnts,
41
+ res_beams, res_cache, res_sample, res_eosid, res_padid,
42
+ return_token_type_ids=False
43
+ )
44
+ pre.start_gen(gen_kwargs)
45
+
46
+ # handling stream
47
+ for ppmanager, uis in text_stream(ppm, streamer):
48
+ yield "", uis, prompt, str(res)
49
+
50
+ ppm = post.strip_pong(ppm)
51
+ yield "", ppm.build_uis(), prompt, str(res)
chats/freewilly.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import global_vars
4
+ from chats import pre, post
5
+ from pingpong import PingPong
6
+ from gens.batch_gen import get_output_batch
7
+
8
+ from chats.utils import build_prompts, text_stream, internet_search
9
+
10
+ def chat_stream(
11
+ idx, local_data, user_message, state,
12
+ global_context, ctx_num_lconv, ctx_sum_prompt,
13
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
14
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
15
+ internet_option, serper_api_key
16
+ ):
17
+ res = [
18
+ state["ppmanager_type"].from_json(json.dumps(ppm))
19
+ for ppm in local_data
20
+ ]
21
+
22
+ ppm = res[idx]
23
+
24
+ # add_ping returns a prompt structured in Alpaca form
25
+ ppm.add_pingpong(
26
+ PingPong(user_message, "")
27
+ )
28
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
29
+
30
+ #######
31
+ if internet_option:
32
+ search_prompt = None
33
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
34
+ search_prompt = tmp_prompt
35
+ yield "", uis, prompt, str(res)
36
+
37
+ # prepare text generating streamer & start generating
38
+ gen_kwargs, streamer = pre.build(
39
+ search_prompt if internet_option else prompt,
40
+ res_temp, res_topp, res_topk, res_rpen, res_mnts,
41
+ res_beams, res_cache, res_sample, res_eosid, res_padid,
42
+ return_token_type_ids=False
43
+ )
44
+ pre.start_gen(gen_kwargs)
45
+
46
+ # handling stream
47
+ for ppmanager, uis in text_stream(ppm, streamer):
48
+ yield "", uis, prompt, str(res)
49
+
50
+ # output = f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n"
51
+
52
+ # inputs = global_vars.tokenizer(
53
+ # prompt, return_tensors="pt"
54
+ # ).to(global_vars.device)
55
+
56
+ # output = output + global_vars.model.generate(
57
+ # **inputs,
58
+ # temperature=res_temp,
59
+ # do_sample=res_sample,
60
+ # top_p=res_topp,
61
+ # top_k=res_topk,
62
+ # repetition_penalty=res_rpen,
63
+ # num_beams=res_beams,
64
+ # use_cache=res_cache,
65
+ # eos_token_id=res_eosid,
66
+ # pad_token_id=res_padid,
67
+ # max_new_tokens=res_mnts
68
+ # )
69
+
70
+ # ppm.replace_last_pong(output)
71
+ # yield "", ppm.build_uis(), prompt, str(res)
72
+
73
+ ppm = post.strip_pong(ppm)
74
+ yield "", ppm.build_uis(), prompt, str(res)
chats/guanaco.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import StoppingCriteria, StoppingCriteriaList
3
+
4
+ import copy
5
+ import json
6
+ import global_vars
7
+ from chats import pre, post
8
+ from pingpong import PingPong
9
+ from gens.batch_gen import get_output_batch
10
+
11
+ from chats.utils import build_prompts, text_stream, internet_search
12
+
13
+ class StopOnTokens(StoppingCriteria):
14
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
15
+ stop_token_ids = [0]
16
+
17
+ for stop_id in stop_token_ids:
18
+ if input_ids[0][-1] == stop_id:
19
+ return True
20
+ return False
21
+
22
+ def chat_stream(
23
+ idx, local_data, user_message, state,
24
+ global_context, ctx_num_lconv, ctx_sum_prompt,
25
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
26
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
27
+ internet_option, serper_api_key
28
+ ):
29
+ res = [
30
+ state["ppmanager_type"].from_json(json.dumps(ppm))
31
+ for ppm in local_data
32
+ ]
33
+
34
+ ppm = res[idx]
35
+
36
+ # add_ping returns a prompt structured in Alpaca form
37
+ ppm.add_pingpong(
38
+ PingPong(user_message, "")
39
+ )
40
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
41
+
42
+ #######
43
+ if internet_option:
44
+ search_prompt = None
45
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
46
+ search_prompt = tmp_prompt
47
+ yield "", uis, prompt, str(res)
48
+
49
+ # prepare text generating streamer & start generating
50
+ gen_kwargs, streamer = pre.build(
51
+ search_prompt if internet_option else prompt,
52
+ res_temp, res_topp, res_topk, res_rpen, res_mnts,
53
+ res_beams, res_cache, res_sample, res_eosid, res_padid,
54
+ return_token_type_ids=False
55
+ )
56
+ pre.start_gen(gen_kwargs)
57
+
58
+ # handling stream
59
+ for ppmanager, uis in text_stream(ppm, streamer):
60
+ yield "", uis, prompt, str(res)
61
+
62
+ ppm = post.strip_pong(ppm)
63
+ yield "", ppm.build_uis(), prompt, str(res)
chats/koalpaca.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import global_vars
4
+ from chats import pre, post
5
+ from pingpong import PingPong
6
+ from gens.batch_gen import get_output_batch
7
+
8
+ from chats.utils import build_prompts, text_stream, internet_search
9
+
10
+ def chat_stream(
11
+ idx, local_data, user_message, state,
12
+ global_context, ctx_num_lconv, ctx_sum_prompt,
13
+ res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
14
+ sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
15
+ internet_option, serper_api_key
16
+ ):
17
+ res = [
18
+ state["ppmanager_type"].from_json(json.dumps(ppm))
19
+ for ppm in local_data
20
+ ]
21
+
22
+ ppm = res[idx]
23
+
24
+ # add_ping returns a prompt structured in Alpaca form
25
+ ppm.add_pingpong(
26
+ PingPong(user_message, "")
27
+ )
28
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
29
+
30
+ #######
31
+ if internet_option:
32
+ search_prompt = None
33
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
34
+ search_prompt = tmp_prompt
35
+ yield "", uis, prompt, str(res)
36
+
37
+ # prepare text generating streamer & start generating
38
+ gen_kwargs, streamer = pre.build(
39
+ search_prompt if internet_option else prompt,
40
+ res_temp, res_topp, res_topk, res_rpen, res_mnts,
41
+ res_beams, res_cache, res_sample, res_eosid, res_padid,
42
+ return_token_type_ids=False
43
+ )
44
+ pre.start_gen(gen_kwargs)
45
+
46
+ # handling stream
47
+ for ppmanager, uis in text_stream(ppm, streamer):
48
+ yield "", uis, prompt, str(res)
49
+
50
+ ppm = post.strip_pong(ppm)
51
+ yield "", ppm.build_uis(), prompt, str(res)