File size: 3,866 Bytes
e931b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
from streamlit_extras.add_vertical_space import add_vertical_space

from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, RetrieverButtons
from backend.retrievers.self_query import process_self_query
from backend.retrievers.vector_sql_query import process_sql_query
from backend.constants.variables import JUMP_QUERY_ASK, USER_NAME, USER_INFO


def back_to_main():
    """Drop the user-specific entries from ``st.session_state``.

    Registered as the ``on_click`` callback of the "⬅️ Back" button in
    ``render_retrievers``. Each of ``USER_INFO``, ``USER_NAME`` and
    ``JUMP_QUERY_ASK`` is removed if present; missing keys are skipped.
    Presumably clearing these is what sends the app back to its main view —
    confirm against the page-routing code.
    """
    for session_key in (USER_INFO, USER_NAME, JUMP_QUERY_ASK):
        if session_key in st.session_state:
            del st.session_state[session_key]


def _render_table_selector() -> str:
    """Let the user pick a public MyScaleDB knowledge-base table and show its hints.

    The left column holds the table selectbox followed by the chosen table's
    ``hint()`` output. The right column shows an info banner followed by the
    table's ``hint_sql()`` output.

    Returns:
        The ``MYSCALE_TABLES`` key chosen in the selectbox.
    """
    col1, col2 = st.columns(2)
    with col1:
        selected_table = st.selectbox(
            label='Each public knowledge base is stored in a MyScaleDB table, which is read-only.',
            options=MYSCALE_TABLES.keys(),
        )
        # Resolve the table entry once; both hint renderers below use it.
        table = MYSCALE_TABLES[selected_table]
        table.hint()
    with col2:
        add_vertical_space(1)
        st.info("Here is your selected public knowledge base schema in MyScaleDB",
                icon='📚')
        table.hint_sql()

    return selected_table


def render_retrievers():
    """Render the retriever page.

    The page has, in order: a "Back" button wired to ``back_to_main``, a
    subheader, the knowledge-base table selector, and two tabs. The tabs are
    "Vector SQL" and "Self-querying Retriever", and both receive the selected
    table name.
    """
    st.button("⬅️ Back", key="back_sql", on_click=back_to_main)
    st.subheader('Please choose a public knowledge base to search.')
    table_name = _render_table_selector()

    sql_tab, self_query_tab = st.tabs(tabs=['Vector SQL', 'Self-querying Retriever'])

    with sql_tab:
        render_tab_sql(table_name)
    with self_query_tab:
        render_tab_self_query(table_name)


def _render_metadata_hint(selected_table: str) -> None:
    """Warn that filters must target known metadata, then list the allowed columns.

    The columns come from
    ``st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"]``.
    """
    st.warning(
        "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
        "the metadata we provide. This table allows filters to be established on the following metadata fields:",
        icon="⚠️")
    st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])


def _render_query_controls(query_key: str, from_db_key, with_llm_key) -> None:
    """Lay out the question input and the two retrieval buttons on one row.

    Args:
        query_key: Streamlit widget key for the question text input.
        from_db_key: widget key for the "Retrieve from MyScaleDB" button.
        with_llm_key: widget key for the "Retrieve and answer with LLM" button.
    """
    cols = st.columns([8, 3, 3, 2])  # cols[3] is left unused (acts as a spacer)
    cols[0].text_input("Input your question:", key=query_key)
    with cols[1].container():
        # Vertical space keeps the button roughly level with the text input.
        add_vertical_space(2)
        st.button("Retrieve from MyScaleDB ➡️", key=from_db_key)
    with cols[2].container():
        add_vertical_space(2)
        st.button("Retrieve and answer with LLM ➡️", key=with_llm_key)


def render_tab_sql(selected_table: str):
    """Render the "Vector SQL" tab and run the query for whichever button was clicked.

    The button state is read back from ``st.session_state`` under the same
    ``RetrieverButtons`` key the button was registered with. That key is also
    passed on to ``process_sql_query``.
    """
    _render_metadata_hint(selected_table)
    _render_query_controls(
        'query_sql',
        RetrieverButtons.vector_sql_query_from_db,
        RetrieverButtons.vector_sql_query_with_llm,
    )

    if st.session_state[RetrieverButtons.vector_sql_query_from_db]:
        process_sql_query(selected_table, RetrieverButtons.vector_sql_query_from_db)

    if st.session_state[RetrieverButtons.vector_sql_query_with_llm]:
        process_sql_query(selected_table, RetrieverButtons.vector_sql_query_with_llm)


def render_tab_self_query(selected_table: str):
    """Render the "Self-querying Retriever" tab and run the query for the clicked button."""
    _render_metadata_hint(selected_table)
    # NOTE(review): unlike the SQL tab, these button keys are plain strings
    # rather than RetrieverButtons members. They are kept unchanged in case
    # process_self_query or other code reads them from st.session_state;
    # confirm before unifying.
    _render_query_controls('query_self', 'search_self', 'ask_self')

    if st.session_state.search_self:
        process_self_query(selected_table, RetrieverButtons.self_query_from_db)

    if st.session_state.ask_self:
        process_self_query(selected_table, RetrieverButtons.self_query_with_llm)