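"""Streamlit callback handlers for ChatData's LangChain pipelines.

Each handler maps chain/LLM lifecycle events onto an `st.progress` bar and,
where useful, renders intermediate artifacts such as the generated metadata
filter or the vector SQL. `LLMThoughtWithKB` additionally pretty-prints
retrieved documents inside the agent's thought container.
"""
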
import json
import textwrap
from typing import Any, Dict, List

import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    LLMThought,
    StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult
from sql_formatter.core import format_sql


class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
    """Drives a progress bar while a self-query chain builds and runs a filter."""

    def __init__(self) -> None:
        # Note: the parent __init__ is not called; this handler only drives a
        # progress bar rather than the parent's thought containers.
        self.progress_bar = st.progress(value=0.0, text="Working...")
        self.tokens_stream = ""

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        self.progress_bar.progress(value=0.2, text="Asking LLM...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        self.progress_bar.progress(value=0.6, text="Searching in DB...")
        # The self-query chain exposes the constructed filter as "repr";
        # surface it so users can inspect what will be searched.
        if "repr" in outputs:
            st.markdown("### Generated Filter")
            st.markdown(f"```python\n{outputs['repr']}\n```")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        pass
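
# A minimal wiring sketch (assumption: `self_search_chain` is a hypothetical
# chain built around a self-query retriever elsewhere in the app). Handlers
# are passed per call through LangChain's standard `callbacks` kwarg:
#
#     handler = ChatDataSelfSearchCallBackHandler()
#     outputs = self_search_chain({"query": question}, callbacks=[handler])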


class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
    """Maps each sub-chain of a QA-with-sources pipeline to a progress milestone."""

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Searching DB...")
        self.status_bar = st.empty()
        self.prog_value = 0.0
        # Fixed progress milestones for the known sub-chains, keyed by the
        # chain's fully qualified import path.
        self.prog_map = {
            "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
            "langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4,
            "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
        }

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        pass

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        # serialized["id"] holds the chain's import path as a list of parts.
        cid = ".".join(serialized["id"])
        if cid != "langchain.chains.llm.LLMChain":
            # Jump to this sub-chain's milestone; fall back to the current
            # value for chains missing from the map (avoids a KeyError).
            self.prog_value = self.prog_map.get(cid, self.prog_value)
        else:
            # Inner LLM calls just nudge the bar forward.
            self.prog_value += 0.1
        self.progress_bar.progress(
            value=min(self.prog_value, 1.0), text=f"Running Chain `{cid}`..."
        )

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
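
# Usage sketch (assumption: `qa_chain` is a hypothetical
# RetrievalQAWithSourcesChain built elsewhere in the app). Each nested
# sub-chain start fires on_chain_start, advancing the bar to its milestone:
#
#     handler = ChatDataSelfAskCallBackHandler()
#     result = qa_chain({"question": question}, callbacks=[handler])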


class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
    """Shows the generated vector SQL, advancing the bar in fixed steps."""

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0.0
        self.prog_interval = 0.2

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ):
        text = response.generations[0][0].text
        # Only surface generations that look like SQL; stripping spaces keeps
        # leading whitespace from hiding the SELECT keyword.
        if text.replace(" ", "").upper().startswith("SELECT"):
            st.write("We generated Vector SQL for you:")
            st.markdown(f"```sql\n{format_sql(text, max_len=80)}\n```")
            print(f"Vector SQL: {text}")
            self.prog_value += self.prog_interval
            self.progress_bar.progress(
                # Clamp: st.progress rejects values above 1.0.
                value=min(self.prog_value, 1.0),
                text="Searching in DB...",
            )

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join(serialized["id"])
        self.prog_value += self.prog_interval
        # Clamp so repeated chain starts cannot push the value past 1.0,
        # which st.progress rejects.
        self.progress_bar.progress(
            value=min(self.prog_value, 1.0), text=f"Running Chain `{cid}`..."
        )

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
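
# Usage sketch (assumption: `vector_sql_chain` is a hypothetical chain that
# writes and runs vector SQL; only the callback wiring is shown):
#
#     handler = ChatDataSQLSearchCallBackHandler()
#     outputs = vector_sql_chain({"query": question}, callbacks=[handler])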


class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
    """Same flow as ChatDataSQLSearchCallBackHandler, with a smaller progress
    step for the longer ask pipeline."""

    def __init__(self) -> None:
        super().__init__()
        self.prog_interval = 0.1


class LLMThoughtWithKB(LLMThought):
    """LLMThought that renders knowledge-base tool output as a document list."""

    def on_tool_end(
        self,
        output: str,
        color=None,
        observation_prefix=None,
        llm_prefix=None,
        **kwargs: Any,
    ) -> None:
        try:
            # The retriever tool returns a JSON list of documents; show a
            # shortened preview of each one's page content.
            self._container.markdown(
                "\n\n".join(
                    ["### Retrieved Documents:"]
                    + [
                        f"**{i + 1}**: {textwrap.shorten(r['page_content'], width=80)}"
                        for i, r in enumerate(json.loads(output))
                    ]
                )
            )
        except Exception:
            # Output was not JSON (or had an unexpected shape); fall back to
            # the default rendering.
            super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)


class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
    """StreamlitCallbackHandler that spawns LLMThoughtWithKB for each new thought."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # Mirrors the parent implementation, but substitutes LLMThoughtWithKB
        # so knowledge-base tool output is rendered as a document list.
        if self._current_thought is None:
            self._current_thought = LLMThoughtWithKB(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)
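

# Usage sketch (assumptions: `agent_executor` is a hypothetical LangChain
# agent built elsewhere; the inherited StreamlitCallbackHandler constructor
# takes the Streamlit container to draw agent thoughts into):
#
#     handler = ChatDataAgentCallBackHandler(st.container())
#     answer = agent_executor.run(question, callbacks=[handler])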