Spaces:
Running
Running
File size: 5,159 Bytes
a796108 06665fc 9061790 c6f6149 d5a4cb4 06665fc a796108 c6f6149 9061790 a796108 c6f6149 a796108 c6f6149 a796108 c6f6149 a796108 c6f6149 fab8405 c6f6149 a796108 c6f6149 9061790 a796108 c6f6149 a796108 c6f6149 a796108 c6f6149 a796108 c6f6149 a796108 c6f6149 a796108 c6f6149 a796108 9061790 c6f6149 9061790 c6f6149 9061790 c6f6149 d5a4cb4 c6f6149 9061790 c6f6149 9061790 c6f6149 9061790 c6f6149 9061790 c6f6149 9061790 c6f6149 9061790 06665fc c6f6149 06665fc c6f6149 06665fc c6f6149 06665fc c6f6149 06665fc c6f6149 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import streamlit as st
import json
import textwrap
from typing import Dict, Any, List
from sql_formatter.core import format_sql
from langchain.callbacks.streamlit.streamlit_callback_handler import (
LLMThought,
StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult
from streamlit.delta_generator import DeltaGenerator
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback for the self-query search flow.

    Drives a single progress bar through fixed stages and, when a chain ends
    with a ``repr`` entry in its outputs, displays that value as the
    "Generated Filter" Python code block.

    NOTE(review): ``__init__`` does not call ``StreamlitCallbackHandler.__init__``,
    so the base class's own state is never set up; base callbacks not
    overridden here presumably fail on missing attributes — confirm against
    the installed LangChain version.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Working...")
        # Not read anywhere in this class; kept for any external readers.
        self.tokens_stream = ""

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        self.progress_bar.progress(value=0.2, text="Asking LLM...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        self.progress_bar.progress(value=0.6, text="Searching in DB...")
        if "repr" in outputs:
            st.markdown("### Generated Filter")
            # The generated filter text is shown inside a code fence, so HTML
            # rendering is unnecessary. It stays disabled so that a stray
            # closing fence in the text cannot inject raw HTML into the page.
            st.markdown(f"```python\n{outputs['repr']}\n```")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        pass
class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback that maps chain starts onto a progress bar.

    Chains listed in ``prog_map`` jump the bar to a fixed fraction. Any other
    chain, including ``LLMChain``, nudges it forward by 0.1. The value is
    always clamped to 1.0, because ``st.progress`` rejects floats outside
    [0.0, 1.0].

    NOTE(review): ``__init__`` does not call ``StreamlitCallbackHandler.__init__``;
    confirm that no un-overridden base callback relies on its state.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Searching DB...")
        self.status_bar = st.empty()
        self.prog_value = 0.0
        # Fraction of the bar to jump to when a chain with this id starts.
        self.prog_map = {
            "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
            "langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4,
            "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
        }

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        pass

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join(serialized["id"])
        if cid in self.prog_map:
            self.prog_value = self.prog_map[cid]
        else:
            # LLMChain and chains missing from prog_map (previously a KeyError)
            # advance the bar a little. Clamp so repeated steps never exceed 1.0.
            self.prog_value = min(self.prog_value + 0.1, 1.0)
        self.progress_bar.progress(
            value=self.prog_value, text=f"Running Chain `{cid}`..."
        )

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback for the vector-SQL search flow.

    Every chain start and every LLM end advances the progress bar by
    ``prog_interval``, clamped to 1.0 because ``st.progress`` rejects floats
    above that. When an LLM response starts with ``SELECT``, the SQL is
    formatted and displayed.

    NOTE(review): ``__init__`` does not call ``StreamlitCallbackHandler.__init__``;
    confirm that no un-overridden base callback relies on its state.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        # Amount added to the bar on each tracked event.
        self.prog_interval = 0.2

    def _advance(self, text: str) -> None:
        """Advance the bar by ``prog_interval`` (capped at 1.0) and show ``text``."""
        self.prog_value = min(self.prog_value + self.prog_interval, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ) -> None:
        generations = response.generations
        # Guard against an empty result instead of raising IndexError.
        text = generations[0][0].text if generations and generations[0] else ""
        # LLM output often begins with a newline, so strip all leading
        # whitespace rather than only spaces.
        if text.lstrip().upper().startswith("SELECT"):
            st.write("We generated Vector SQL for you:")
            st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
            print(f"Vector SQL: {text}")
        self._advance("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join(serialized["id"])
        self._advance(f"Running Chain `{cid}`...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
    """Vector-SQL callback for the ask flow; advances the bar in 0.1 steps."""

    def __init__(self) -> None:
        # Reuse the parent's setup (progress bar, status placeholder, counter)
        # instead of duplicating it; only the step size differs.
        super().__init__()
        self.prog_interval = 0.1
class LLMThoughtWithKB(LLMThought):
    """LLMThought that renders knowledge-base tool output as a document list."""

    def on_tool_end(
        self,
        output: str,
        color=None,
        observation_prefix=None,
        llm_prefix=None,
        **kwargs: Any,
    ) -> None:
        """Show retrieved documents, falling back to the base rendering.

        ``output`` is parsed as JSON and iterated as records that each carry a
        ``page_content`` string. Each record becomes one numbered line,
        shortened to 80 characters. Presumably this is the retrieval tool's
        serialized result; verify against the tool.

        If ``output`` does not have that shape, rendering is delegated to
        ``LLMThought.on_tool_end``.
        """
        try:
            lines = [
                f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
                for i, r in enumerate(json.loads(output))
            ]
        except (ValueError, TypeError, KeyError, AttributeError):
            # ValueError: not JSON (JSONDecodeError). TypeError: not a string,
            # not iterable, or items are not mappings. KeyError: no
            # page_content. AttributeError: page_content is not a string.
            super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
            return
        # Rendering sits outside the try so UI errors are not mistaken for a
        # parse failure.
        self._container.markdown("\n\n".join(["### Retrieved Documents:"] + lines))
class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
    """Agent callback whose thoughts render retrieved documents as a list."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Forward the LLM start to the current thought, creating one if needed.

        A newly created thought is an ``LLMThoughtWithKB`` built from this
        handler's container and display settings. ``kwargs`` is not forwarded.
        """
        thought = self._current_thought
        if thought is None:
            thought = LLMThoughtWithKB(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )
            self._current_thought = thought
        thought.on_llm_start(serialized, prompts)
|