Ascol57 commited on
Commit
d96a239
·
verified ·
1 Parent(s): ff363c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +566 -62
app.py CHANGED
@@ -1,63 +1,567 @@
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
60
-
61
-
62
- if __name__ == "__main__":
63
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import re
3
+ import json
4
+ import os
5
+ from datetime import datetime
6
+
7
  import gradio as gr
8
+ import torch
9
+
10
+ import modules.shared as shared
11
+ from modules import chat, ui as ui_module
12
+ from modules.utils import gradio
13
+ from modules.text_generation import generate_reply_HF, generate_reply_custom
14
+ from .llm_web_search import get_webpage_content, langchain_search_duckduckgo, langchain_search_searxng, Generator
15
+ from .langchain_websearch import LangchainCompressor
16
+
17
+
18
+ params = {
19
+ "display_name": "LLM Web Search",
20
+ "is_tab": True,
21
+ "enable": True,
22
+ "search results per query": 5,
23
+ "langchain similarity score threshold": 0.5,
24
+ "instant answers": True,
25
+ "regular search results": True,
26
+ "search command regex": "",
27
+ "default search command regex": r"Search_web\(\"(.*)\"\)",
28
+ "open url command regex": "",
29
+ "default open url command regex": r"Open_url\(\"(.*)\"\)",
30
+ "display search results in chat": True,
31
+ "display extracted URL content in chat": True,
32
+ "searxng url": "",
33
+ "cpu only": True,
34
+ "chunk size": 500,
35
+ "duckduckgo results per query": 10,
36
+ "append current datetime": False,
37
+ "default system prompt filename": None,
38
+ "force search prefix": "Search_web",
39
+ "ensemble weighting": 0.5,
40
+ "keyword retriever": "bm25",
41
+ "splade batch size": 2,
42
+ "chunking method": "character-based",
43
+ "chunker breakpoint_threshold_amount": 30
44
+ }
45
+ custom_system_message_filename = None
46
+ extension_path = os.path.dirname(os.path.abspath(__file__))
47
+ langchain_compressor = None
48
+ update_history = None
49
+ force_search = False
50
+
51
+
52
+ def setup():
53
+ """
54
+ Is executed when the extension gets imported.
55
+ :return:
56
+ """
57
+ global params
58
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
59
+ os.environ["QDRANT__TELEMETRY_DISABLED"] = "true"
60
+
61
+ try:
62
+ with open(os.path.join(extension_path, "settings.json"), "r") as f:
63
+ saved_params = json.load(f)
64
+ params.update(saved_params)
65
+ save_settings() # add keys of newly added feature to settings.json
66
+ except FileNotFoundError:
67
+ save_settings()
68
+
69
+ if not os.path.exists(os.path.join(extension_path, "system_prompts")):
70
+ os.makedirs(os.path.join(extension_path, "system_prompts"))
71
+
72
+ toggle_extension(params["enable"])
73
+
74
+
75
+ def save_settings():
76
+ global params
77
+ with open(os.path.join(extension_path, "settings.json"), "w") as f:
78
+ json.dump(params, f, indent=4)
79
+ current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
80
+ return gr.HTML(f'<span style="color:lawngreen"> Settings were saved at {current_datetime}</span>',
81
+ visible=True)
82
+
83
+
84
+ def toggle_extension(_enable: bool):
85
+ global langchain_compressor, custom_system_message_filename
86
+ if _enable:
87
+ langchain_compressor = LangchainCompressor(device="cpu" if params["cpu only"] else "cuda",
88
+ keyword_retriever=params["keyword retriever"],
89
+ model_cache_dir=os.path.join(extension_path, "hf_models"))
90
+ compressor_model = langchain_compressor.embeddings.client
91
+ compressor_model.to(compressor_model._target_device)
92
+ custom_system_message_filename = params.get("default system prompt filename")
93
+ else:
94
+ if not params["cpu only"] and 'langchain_compressor' in globals(): # free some VRAM
95
+ model_attrs = ["embeddings", "splade_doc_model", "splade_query_model"]
96
+ for model_attr in model_attrs:
97
+ if hasattr(langchain_compressor, model_attr):
98
+ model = getattr(langchain_compressor, model_attr)
99
+ if hasattr(model, "client"):
100
+ model.client.to("cpu")
101
+ del model.client
102
+ else:
103
+ if hasattr(model, "to"):
104
+ model.to("cpu")
105
+ del model
106
+ torch.cuda.empty_cache()
107
+ params.update({"enable": _enable})
108
+ return _enable
109
+
110
+
111
+ def get_available_system_prompts():
112
+ try:
113
+ return ["None"] + sorted(os.listdir(os.path.join(extension_path, "system_prompts")))
114
+ except FileNotFoundError:
115
+ return ["None"]
116
+
117
+
118
+ def load_system_prompt(filename: str or None):
119
+ global custom_system_message_filename
120
+ if not filename:
121
+ return
122
+ if filename == "None" or filename == "Select custom system message to load...":
123
+ custom_system_message_filename = None
124
+ return ""
125
+ with open(os.path.join(extension_path, "system_prompts", filename), "r") as f:
126
+ prompt_str = f.read()
127
+
128
+ if params["append current datetime"]:
129
+ prompt_str += f"\nDate and time of conversation: {datetime.now().strftime('%A %d %B %Y %H:%M')}"
130
+
131
+ shared.settings['custom_system_message'] = prompt_str
132
+ custom_system_message_filename = filename
133
+ return prompt_str
134
+
135
+
136
+ def save_system_prompt(filename, prompt):
137
+ if not filename:
138
+ return
139
+
140
+ with open(os.path.join(extension_path, "system_prompts", filename), "w") as f:
141
+ f.write(prompt)
142
+
143
+ return gr.HTML(f'<span style="color:lawngreen"> Saved successfully</span>',
144
+ visible=True)
145
+
146
+
147
+ def check_file_exists(filename):
148
+ if filename == "":
149
+ return gr.HTML("", visible=False)
150
+ if os.path.exists(os.path.join(extension_path, "system_prompts", filename)):
151
+ return gr.HTML(f'<span style="color:orange"> Warning: Filename already exists</span>', visible=True)
152
+ return gr.HTML("", visible=False)
153
+
154
+
155
+ def timeout_save_message():
156
+ time.sleep(2)
157
+ return gr.HTML("", visible=False)
158
+
159
+
160
+ def deactivate_system_prompt():
161
+ shared.settings['custom_system_message'] = None
162
+ return "None"
163
+
164
+
165
+ def toggle_forced_search(value):
166
+ global force_search
167
+ force_search = value
168
+
169
+
170
+ def ui():
171
+ """
172
+ Creates custom gradio elements when the UI is launched.
173
+ :return:
174
+ """
175
+ # Inject custom system message into the main textbox if a default one is set
176
+ shared.gradio['custom_system_message'].value = load_system_prompt(custom_system_message_filename)
177
+
178
+ def update_result_type_setting(choice: str):
179
+ if choice == "Instant answers":
180
+ params.update({"instant answers": True})
181
+ params.update({"regular search results": False})
182
+ elif choice == "Regular results":
183
+ params.update({"instant answers": False})
184
+ params.update({"regular search results": True})
185
+ elif choice == "Regular results and instant answers":
186
+ params.update({"instant answers": True})
187
+ params.update({"regular search results": True})
188
+
189
+ def update_regex_setting(input_str: str, setting_key: str, error_html_element: gr.component):
190
+ if input_str == "":
191
+ params.update({setting_key: params[f"default {setting_key}"]})
192
+ return {error_html_element: gr.HTML("", visible=False)}
193
+ try:
194
+ compiled = re.compile(input_str)
195
+ if compiled.groups > 1:
196
+ raise re.error(f"Only 1 capturing group allowed in regex, but there are {compiled.groups}.")
197
+ params.update({setting_key: input_str})
198
+ return {error_html_element: gr.HTML("", visible=False)}
199
+ except re.error as e:
200
+ return {error_html_element: gr.HTML(f'<span style="color:red"> Invalid regex. {str(e).capitalize()}</span>',
201
+ visible=True)}
202
+
203
+ def update_default_custom_system_message(check: bool):
204
+ if check:
205
+ params.update({"default system prompt filename": custom_system_message_filename})
206
+ else:
207
+ params.update({"default system prompt filename": None})
208
+
209
+ with gr.Row():
210
+ enable = gr.Checkbox(value=lambda: params['enable'], label='Enable LLM web search')
211
+ use_cpu_only = gr.Checkbox(value=lambda: params['cpu only'],
212
+ label='Run extension on CPU only '
213
+ '(Save settings and restart for the change to take effect)')
214
+ with gr.Column():
215
+ save_settings_btn = gr.Button("Save settings")
216
+ saved_success_elem = gr.HTML("", visible=False)
217
+
218
+ with gr.Row():
219
+ result_radio = gr.Radio(
220
+ ["Regular results", "Regular results and instant answers"],
221
+ label="What kind of search results should be returned?",
222
+ value=lambda: "Regular results and instant answers" if
223
+ (params["regular search results"] and params["instant answers"]) else "Regular results"
224
+ )
225
+ with gr.Column():
226
+ search_command_regex = gr.Textbox(label="Search command regex string",
227
+ placeholder=params["default search command regex"],
228
+ value=lambda: params["search command regex"])
229
+ search_command_regex_error_label = gr.HTML("", visible=False)
230
+
231
+ with gr.Column():
232
+ open_url_command_regex = gr.Textbox(label="Open URL command regex string",
233
+ placeholder=params["default open url command regex"],
234
+ value=lambda: params["open url command regex"])
235
+ open_url_command_regex_error_label = gr.HTML("", visible=False)
236
+
237
+ with gr.Column():
238
+ show_results = gr.Checkbox(value=lambda: params['display search results in chat'],
239
+ label='Display search results in chat')
240
+ show_url_content = gr.Checkbox(value=lambda: params['display extracted URL content in chat'],
241
+ label='Display extracted URL content in chat')
242
+ gr.Markdown(value='---')
243
+ with gr.Row():
244
+ with gr.Column():
245
+ gr.Markdown(value='#### Load custom system message\n'
246
+ 'Select a saved custom system message from within the system_prompts folder or "None" '
247
+ 'to clear the selection')
248
+ system_prompt = gr.Dropdown(
249
+ choices=get_available_system_prompts(), label="Select custom system message",
250
+ value=lambda: 'Select custom system message to load...' if custom_system_message_filename is None else
251
+ custom_system_message_filename, elem_classes='slim-dropdown')
252
+ with gr.Row():
253
+ set_system_message_as_default = gr.Checkbox(
254
+ value=lambda: custom_system_message_filename == params["default system prompt filename"],
255
+ label='Set this custom system message as the default')
256
+ refresh_button = ui_module.create_refresh_button(system_prompt, lambda: None,
257
+ lambda: {'choices': get_available_system_prompts()},
258
+ 'refresh-button', interactive=True)
259
+ refresh_button.elem_id = "custom-sysprompt-refresh"
260
+ delete_button = gr.Button('🗑️', elem_classes='refresh-button', interactive=True)
261
+ append_datetime = gr.Checkbox(value=lambda: params['append current datetime'],
262
+ label='Append current date and time when loading custom system message')
263
+ with gr.Column():
264
+ gr.Markdown(value='#### Create custom system message')
265
+ system_prompt_text = gr.Textbox(label="Custom system message", lines=3,
266
+ value=lambda: load_system_prompt(custom_system_message_filename))
267
+ sys_prompt_filename = gr.Text(label="Filename")
268
+ sys_prompt_save_button = gr.Button("Save Custom system message")
269
+ system_prompt_saved_success_elem = gr.HTML("", visible=False)
270
+
271
+ gr.Markdown(value='---')
272
+ with gr.Accordion("Advanced settings", open=False):
273
+ ensemble_weighting = gr.Slider(minimum=0, maximum=1, step=0.05, value=lambda: params["ensemble weighting"],
274
+ label="Ensemble Weighting", info="Smaller values = More keyword oriented, "
275
+ "Larger values = More focus on semantic similarity")
276
+ with gr.Row():
277
+ keyword_retriever = gr.Radio([("Okapi BM25", "bm25"),("SPLADE", "splade")], label="Sparse keyword retriever",
278
+ info="For change to take effect, toggle the extension off and on again",
279
+ value=lambda: params["keyword retriever"])
280
+ splade_batch_size = gr.Slider(minimum=2, maximum=256, step=2, value=lambda: params["splade batch size"],
281
+ label="SPLADE batch size",
282
+ info="Smaller values = Slower retrieval (but lower VRAM usage), "
283
+ "Larger values = Faster retrieval (but higher VRAM usage). "
284
+ "A good trade-off seems to be setting it = 8",
285
+ precision=0)
286
+ with gr.Row():
287
+ chunker = gr.Radio([("Character-based", "character-based"),
288
+ ("Semantic", "semantic")], label="Chunking method",
289
+ value=lambda: params["chunking method"])
290
+ chunker_breakpoint_threshold_amount = gr.Slider(minimum=1, maximum=100, step=1,
291
+ value=lambda: params["chunker breakpoint_threshold_amount"],
292
+ label="Semantic chunking: sentence split threshold (%)",
293
+ info="Defines how different two consecutive sentences have"
294
+ " to be for them to be split into separate chunks",
295
+ precision=0)
296
+ gr.Markdown("**Note: Changing the following might result in DuckDuckGo rate limiting or the LM being overwhelmed**")
297
+ num_search_results = gr.Number(label="Max. search results to return per query", minimum=1, maximum=100,
298
+ value=lambda: params["search results per query"], precision=0)
299
+ num_process_search_results = gr.Number(label="Number of search results to process per query", minimum=1,
300
+ maximum=100, value=lambda: params["duckduckgo results per query"],
301
+ precision=0)
302
+ langchain_similarity_threshold = gr.Number(label="Langchain Similarity Score Threshold", minimum=0., maximum=1.,
303
+ value=lambda: params["langchain similarity score threshold"])
304
+ chunk_size = gr.Number(label="Max. chunk size", info="The maximal size of the individual chunks that each webpage will"
305
+ " be split into, in characters", minimum=2, maximum=10000,
306
+ value=lambda: params["chunk size"], precision=0)
307
+
308
+ with gr.Row():
309
+ searxng_url = gr.Textbox(label="SearXNG URL",
310
+ value=lambda: params["searxng url"])
311
+
312
+ # Event functions to update the parameters in the backend
313
+ enable.input(toggle_extension, enable, enable)
314
+ use_cpu_only.change(lambda x: params.update({"cpu only": x}), use_cpu_only, None)
315
+ save_settings_btn.click(save_settings, None, [saved_success_elem])
316
+ ensemble_weighting.change(lambda x: params.update({"ensemble weighting": x}), ensemble_weighting, None)
317
+ keyword_retriever.change(lambda x: params.update({"keyword retriever": x}), keyword_retriever, None)
318
+ splade_batch_size.change(lambda x: params.update({"splade batch size": x}), splade_batch_size, None)
319
+ chunker.change(lambda x: params.update({"chunking method": x}), chunker, None)
320
+ chunker_breakpoint_threshold_amount.change(lambda x: params.update({"chunker breakpoint_threshold_amount": x}),
321
+ chunker_breakpoint_threshold_amount, None)
322
+ num_search_results.change(lambda x: params.update({"search results per query": x}), num_search_results, None)
323
+ num_process_search_results.change(lambda x: params.update({"duckduckgo results per query": x}),
324
+ num_process_search_results, None)
325
+ langchain_similarity_threshold.change(lambda x: params.update({"langchain similarity score threshold": x}),
326
+ langchain_similarity_threshold, None)
327
+ chunk_size.change(lambda x: params.update({"chunk size": x}), chunk_size, None)
328
+ result_radio.change(update_result_type_setting, result_radio, None)
329
+
330
+ search_command_regex.change(lambda x: update_regex_setting(x, "search command regex",
331
+ search_command_regex_error_label),
332
+ search_command_regex, search_command_regex_error_label, show_progress="hidden")
333
+
334
+ open_url_command_regex.change(lambda x: update_regex_setting(x, "open url command regex",
335
+ open_url_command_regex_error_label),
336
+ open_url_command_regex, open_url_command_regex_error_label, show_progress="hidden")
337
+
338
+ show_results.change(lambda x: params.update({"display search results in chat": x}), show_results, None)
339
+ show_url_content.change(lambda x: params.update({"display extracted URL content in chat": x}), show_url_content,
340
+ None)
341
+ searxng_url.change(lambda x: params.update({"searxng url": x}), searxng_url, None)
342
+
343
+ delete_button.click(
344
+ lambda x: x, system_prompt, gradio('delete_filename')).then(
345
+ lambda: os.path.join(extension_path, "system_prompts", ""), None, gradio('delete_root')).then(
346
+ lambda: gr.update(visible=True), None, gradio('file_deleter'))
347
+ shared.gradio['delete_confirm'].click(
348
+ lambda: "None", None, system_prompt).then(
349
+ None, None, None, _js="() => { document.getElementById('custom-sysprompt-refresh').click() }")
350
+ system_prompt.change(load_system_prompt, system_prompt, shared.gradio['custom_system_message'])
351
+ system_prompt.change(load_system_prompt, system_prompt, system_prompt_text)
352
+ # restore checked state if chosen system prompt matches the default
353
+ system_prompt.change(lambda x: x == params["default system prompt filename"], system_prompt,
354
+ set_system_message_as_default)
355
+ sys_prompt_filename.change(check_file_exists, sys_prompt_filename, system_prompt_saved_success_elem)
356
+ sys_prompt_save_button.click(save_system_prompt, [sys_prompt_filename, system_prompt_text],
357
+ system_prompt_saved_success_elem,
358
+ show_progress="hidden").then(timeout_save_message,
359
+ None,
360
+ system_prompt_saved_success_elem,
361
+ _js="() => { document.getElementById('custom-sysprompt-refresh').click() }",
362
+ show_progress="hidden").then(lambda: "", None,
363
+ sys_prompt_filename,
364
+ show_progress="hidden")
365
+ append_datetime.change(lambda x: params.update({"append current datetime": x}), append_datetime, None)
366
+ # '.input' = only triggers when user changes the value of the component, not a function
367
+ set_system_message_as_default.input(update_default_custom_system_message, set_system_message_as_default, None)
368
+
369
+ # A dummy checkbox to enable the actual "Force web search" checkbox to trigger a gradio event
370
+ force_search_checkbox = gr.Checkbox(value=False, visible=False, elem_id="Force-search-checkbox")
371
+ force_search_checkbox.change(toggle_forced_search, force_search_checkbox, None)
372
+
373
+
374
+ def custom_generate_reply(question, original_question, seed, state, stopping_strings, is_chat):
375
+ """
376
+ Overrides the main text generation function.
377
+ :return:
378
+ """
379
+ global update_history, langchain_compressor
380
+ if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model',
381
+ 'CtransformersModel']:
382
+ generate_func = generate_reply_custom
383
+ else:
384
+ generate_func = generate_reply_HF
385
+
386
+ if not params['enable']:
387
+ for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
388
+ yield reply
389
+ return
390
+
391
+ web_search = False
392
+ read_webpage = False
393
+ max_search_results = int(params["search results per query"])
394
+ instant_answers = params["instant answers"]
395
+ # regular_search_results = params["regular search results"]
396
+
397
+ langchain_compressor.num_results = int(params["duckduckgo results per query"])
398
+ langchain_compressor.similarity_threshold = params["langchain similarity score threshold"]
399
+ langchain_compressor.chunk_size = params["chunk size"]
400
+ langchain_compressor.ensemble_weighting = params["ensemble weighting"]
401
+ langchain_compressor.splade_batch_size = params["splade batch size"]
402
+ langchain_compressor.chunking_method = params["chunking method"]
403
+ langchain_compressor.chunker_breakpoint_threshold_amount = params["chunker breakpoint_threshold_amount"]
404
+
405
+ search_command_regex = params["search command regex"]
406
+ open_url_command_regex = params["open url command regex"]
407
+ searxng_url = params["searxng url"]
408
+ display_search_results = params["display search results in chat"]
409
+ display_webpage_content = params["display extracted URL content in chat"]
410
+
411
+ if search_command_regex == "":
412
+ search_command_regex = params["default search command regex"]
413
+ if open_url_command_regex == "":
414
+ open_url_command_regex = params["default open url command regex"]
415
+
416
+ compiled_search_command_regex = re.compile(search_command_regex)
417
+ compiled_open_url_command_regex = re.compile(open_url_command_regex)
418
+
419
+ if force_search:
420
+ question += f" {params['force search prefix']}"
421
+
422
+ reply = None
423
+ for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
424
+
425
+ if force_search:
426
+ reply = params["force search prefix"] + reply
427
+
428
+ search_re_match = compiled_search_command_regex.search(reply)
429
+ if search_re_match is not None:
430
+ yield reply
431
+ original_model_reply = reply
432
+ web_search = True
433
+ search_term = search_re_match.group(1)
434
+ print(f"LLM_Web_search | Searching for {search_term}...")
435
+ reply += "\n```plaintext"
436
+ reply += "\nSearch tool:\n"
437
+ if searxng_url == "":
438
+ search_generator = Generator(langchain_search_duckduckgo(search_term,
439
+ langchain_compressor,
440
+ max_search_results,
441
+ instant_answers))
442
+ else:
443
+ search_generator = Generator(langchain_search_searxng(search_term,
444
+ searxng_url,
445
+ langchain_compressor,
446
+ max_search_results))
447
+ try:
448
+ for status_message in search_generator:
449
+ yield original_model_reply + f"\n*{status_message}*"
450
+ search_results = search_generator.value
451
+ except Exception as exc:
452
+ exception_message = str(exc)
453
+ reply += f"The search tool encountered an error: {exception_message}"
454
+ print(f'LLM_Web_search | {search_term} generated an exception: {exception_message}')
455
+ else:
456
+ if search_results != "":
457
+ reply += search_results
458
+ else:
459
+ reply += f"\nThe search tool did not return any results."
460
+ reply += "```"
461
+ if display_search_results:
462
+ yield reply
463
+ break
464
+
465
+ open_url_re_match = compiled_open_url_command_regex.search(reply)
466
+ if open_url_re_match is not None:
467
+ yield reply
468
+ original_model_reply = reply
469
+ read_webpage = True
470
+ url = open_url_re_match.group(1)
471
+ print(f"LLM_Web_search | Reading {url}...")
472
+ reply += "\n```plaintext"
473
+ reply += "\nURL opener tool:\n"
474
+ try:
475
+ webpage_content = get_webpage_content(url)
476
+ except Exception as exc:
477
+ reply += f"Couldn't open {url}. Error message: {str(exc)}"
478
+ print(f'LLM_Web_search | {url} generated an exception: {str(exc)}')
479
+ else:
480
+ reply += f"\nText content of {url}:\n"
481
+ reply += webpage_content
482
+ reply += "```\n"
483
+ if display_webpage_content:
484
+ yield reply
485
+ break
486
+ yield reply
487
+
488
+ if web_search or read_webpage:
489
+ display_results = web_search and display_search_results or read_webpage and display_webpage_content
490
+ # Add results to context and continue model output
491
+ new_question = chat.generate_chat_prompt(f"{question}{reply}", state)
492
+ new_reply = ""
493
+ for new_reply in generate_func(new_question, new_question, seed, state,
494
+ stopping_strings, is_chat=is_chat):
495
+ if display_results:
496
+ yield f"{reply}\n{new_reply}"
497
+ else:
498
+ yield f"{original_model_reply}\n{new_reply}"
499
+
500
+ if not display_results:
501
+ update_history = [state["textbox"], f"{reply}\n{new_reply}"]
502
+
503
+
504
+ def output_modifier(string, state, is_chat=False):
505
+ """
506
+ Modifies the output string before it is presented in the UI. In chat mode,
507
+ it is applied to the bot's reply. Otherwise, it is applied to the entire
508
+ output.
509
+ :param string:
510
+ :param state:
511
+ :param is_chat:
512
+ :return:
513
+ """
514
+ return string
515
+
516
+
517
+ def custom_css():
518
+ """
519
+ Returns custom CSS as a string. It is applied whenever the web UI is loaded.
520
+ :return:
521
+ """
522
+ return ''
523
+
524
+
525
+ def custom_js():
526
+ """
527
+ Returns custom javascript as a string. It is applied whenever the web UI is
528
+ loaded.
529
+ :return:
530
+ """
531
+ with open(os.path.join(extension_path, "script.js"), "r") as f:
532
+ return f.read()
533
+
534
+
535
+ def chat_input_modifier(text, visible_text, state):
536
+ """
537
+ Modifies both the visible and internal inputs in chat mode. Can be used to
538
+ hijack the chat input with custom content.
539
+ :param text:
540
+ :param visible_text:
541
+ :param state:
542
+ :return:
543
+ """
544
+ return text, visible_text
545
+
546
+
547
+ def state_modifier(state):
548
+ """
549
+ Modifies the dictionary containing the UI input parameters before it is
550
+ used by the text generation functions.
551
+ :param state:
552
+ :return:
553
+ """
554
+ return state
555
+
556
+
557
+ def history_modifier(history):
558
+ """
559
+ Modifies the chat history before the text generation in chat mode begins.
560
+ :param history:
561
+ :return:
562
+ """
563
+ global update_history
564
+ if update_history:
565
+ history["internal"].append(update_history)
566
+ update_history = None
567
+ return history