File size: 5,026 Bytes
a796108
06665fc
 
 
9061790
06665fc
d5a4cb4
06665fc
a796108
9061790
a796108
 
 
 
 
 
 
 
 
 
 
 
fab8405
 
 
a796108
 
 
 
9061790
a796108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9061790
 
 
 
 
 
 
 
 
 
 
 
 
d5a4cb4
 
 
 
 
 
 
 
9061790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06665fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import streamlit as st
import json
import textwrap
from typing import Dict, Any, List
from sql_formatter.core import format_sql
from langchain.callbacks.streamlit.streamlit_callback_handler import LLMThought, StreamlitCallbackHandler
from langchain.schema.output import LLMResult
from streamlit.delta_generator import DeltaGenerator

class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback handler for the self-query *search* flow.

    Drives a single progress bar through two visible stages (asking the
    LLM, then searching the DB) and, when the chain output contains a
    ``repr`` key, renders the generated retriever filter as a Python
    code block.
    """

    def __init__(self) -> None:
        # Deliberately does not call super().__init__(): this handler
        # replaces the parent's thought-container UI with a bare
        # progress bar, and overrides every callback it relies on.
        self.progress_bar = st.progress(value=0.0, text="Working...")
        self.tokens_stream = ""

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        self.progress_bar.progress(value=0.2, text="Asking LLM...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        self.progress_bar.progress(value=0.6, text='Searching in DB...')
        if 'repr' in outputs:
            st.markdown('### Generated Filter')
            # The filter repr is rendered inside a fenced code block, which
            # Markdown escapes anyway, so raw-HTML rendering is unnecessary
            # (and unsafe for LLM-generated content); dropped
            # unsafe_allow_html=True.
            st.markdown(f"```python\n{outputs['repr']}\n```")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        pass

class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback handler for the self-query *ask* (QA) flow.

    Advances a progress bar as each chain of the
    RetrievalQAWithSourcesChain pipeline starts, using fixed milestones
    for known chains and small increments for inner LLMChain calls.
    """

    def __init__(self) -> None:
        # Deliberately does not call super().__init__(): only a progress
        # bar is needed, not the parent's thought-container UI.
        self.progress_bar = st.progress(value=0.0, text='Searching DB...')
        self.status_bar = st.empty()
        self.prog_value = 0.0
        # Progress milestones keyed by fully-qualified chain class path.
        self.prog_map = {
            'langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain': 0.2,
            'langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain': 0.4,
            'langchain.chains.combine_documents.stuff.StuffDocumentsChain': 0.8
        }

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_text(self, text: str, **kwargs) -> None:
        pass

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Bump the progress bar when a chain in the pipeline starts."""
        cid = '.'.join(serialized['id'])
        if cid != 'langchain.chains.llm.LLMChain':
            # An unknown chain id keeps the current progress instead of
            # raising KeyError (the original indexed prog_map directly).
            self.prog_value = self.prog_map.get(cid, self.prog_value)
        else:
            # Each inner LLMChain call nudges progress a little further.
            self.prog_value += 0.1
        # st.progress raises StreamlitAPIException for values above 1.0,
        # so clamp before updating the widget.
        self.prog_value = min(self.prog_value, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=f'Running Chain `{cid}`...')

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
    

class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback handler for the vector-SQL *search* flow.

    Steps a progress bar forward by a fixed interval on each chain start
    and, when the LLM emits a SELECT statement, renders the formatted
    SQL before stepping again.
    """

    def __init__(self) -> None:
        # Deliberately does not call super().__init__(): only a progress
        # bar is needed, not the parent's thought-container UI.
        self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
        self.status_bar = st.empty()
        # Float for consistency with the sibling handlers (was the int 0).
        self.prog_value = 0.0
        self.prog_interval = 0.2

    def _advance(self, text: str) -> None:
        """Step the bar by prog_interval, clamped to st.progress's [0, 1]
        range (unbounded growth raised StreamlitAPIException before)."""
        self.prog_value = min(self.prog_value + self.prog_interval, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ):
        """Render the generated SQL once the LLM has finished writing it."""
        text = response.generations[0][0].text
        # Space-insensitive check that the generation is a SELECT statement.
        if text.replace(' ', '').upper().startswith('SELECT'):
            st.write('We generated Vector SQL for you:')
            st.markdown(f'''```sql\n{format_sql(text, max_len=80)}\n```''')
            print(f"Vector SQL: {text}")
            self._advance("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = '.'.join(serialized['id'])
        self._advance(f'Running Chain `{cid}`...')

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
    
class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
    """Vector-SQL *ask* flow handler: identical UI to the search handler,
    but with a smaller progress step because this flow runs more stages.
    """

    def __init__(self) -> None:
        # The original duplicated the parent __init__ verbatim; reuse it
        # and override only the step size.
        super().__init__()
        self.prog_interval = 0.1
        
        
class LLMThoughtWithKB(LLMThought):
    """LLMThought that renders knowledge-base tool output as a document list.

    Expects the tool output to be a JSON array of objects with a
    ``page_content`` field; anything else falls back to the parent's
    default rendering.
    """

    def on_tool_end(self, output: str, color: str | None = None, observation_prefix: str | None = None, llm_prefix: str | None = None, **kwargs: Any) -> None:
        try:
            docs = json.loads(output)
            lines = ["### Retrieved Documents:"]
            lines.extend(
                f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
                for i, r in enumerate(docs)
            )
            self._container.markdown("\n\n".join(lines))
        except Exception:
            # Non-JSON output or missing keys: defer to the parent renderer.
            # (The original bound the exception as `e` but never used it.)
            super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
    
        
class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
    """Agent callback handler whose thoughts render KB tool output nicely.

    Identical to the parent handler except that new thoughts are created
    as LLMThoughtWithKB, so retrieved documents show as a formatted list.
    """

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # Lazily create the thought container on the first LLM start,
        # substituting the KB-aware LLMThought subclass.
        if self._current_thought is None:
            thought = LLMThoughtWithKB(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )
            self._current_thought = thought

        self._current_thought.on_llm_start(serialized, prompts)