import streamlit as st
import os, time
from app.vdr_session import *
from app.vdr_schemas import *
from st_clickable_images import clickable_images
from app.prompt_template import VDR_PROMPT
  
def page_vdr():
    """Render the Visual Document Retrieval (VDR) page.

    Layout:
      * Sidebar  - API-key validation, VLM/embedding model pickers, PDF
        upload + indexing controls, retrieval settings, prompt template.
      * Main area - query box, gallery of retrieved page images, and a
        streamed answer generated by the selected VLM.

    All state is kept in ``st.session_state`` so it survives Streamlit's
    script reruns.
    """
    st.header("Visual Document Retrieval")

    # One VDRSession per browser session, persisted across reruns.
    # (Membership test on st.session_state directly — no .keys() needed.)
    if "vdr_session" not in st.session_state:
        st.session_state["vdr_session"] = VDRSession()
    session = st.session_state["vdr_session"]

    with st.sidebar:
        # Key comes from the environment (the interactive text_input was
        # retired). NOTE(review): os.getenv may return None; set_api_key is
        # assumed to report a missing/invalid key as falsy — confirm.
        api_key = os.getenv("GLOBAL_AIFS_API_KEY")
        check_api_key = session.set_api_key(api_key)

        if check_api_key:
            st.success('API Key is valid!', icon='✅')
            avai_llms = session.get_available_vlms()
            avai_embeds = session.get_available_image_embeds()
            # `key=` also mirrors each selection into st.session_state.
            # (Already inside `with st.sidebar:` — no st.sidebar.* prefix.)
            selected_llm = st.selectbox('Choose VLM models', avai_llms, key='selected_llm', disabled=not check_api_key)
            selected_embed = st.selectbox('Choose Embedding models', avai_embeds, key='selected_embed', disabled=not check_api_key)
        else:
            st.warning('Please enter valid credentials!', icon='⚠️')

    # Guard clause: nothing below is usable without a valid key.
    if not check_api_key:
        return

    with st.sidebar:
        _render_context_controls(session, selected_embed, check_api_key)
        top_k_sim = st.slider(label="Top k similarity", min_value=1, max_value=10, value=3, step=1, key="top_k_sim")
        chat_prompt = st.text_area("Prompt template", key="chat_prompt", value=VDR_PROMPT, height=300)

    # Query box stays disabled until at least one document has been indexed.
    query = st.text_input(label="Query", key='query', placeholder="Enter your query here", label_visibility="hidden", disabled=not session.indexed_images)

    _render_retrieved_contexts(session, query, top_k_sim, selected_embed)
    _render_answer_section(session, query, selected_llm, chat_prompt)


def _render_context_controls(session, selected_embed, check_api_key):
    """Sidebar controls: upload PDFs, index them, or clear all context."""
    uploaded_files = st.file_uploader("Upload PDF files", key="uploaded_files", accept_multiple_files=True, disabled=not check_api_key)

    if st.button("Add selected context", key="add_context", type="primary"):
        if uploaded_files:
            try:
                indexing_bar = st.progress(0, text="Indexing...")
                # indexing() drives the progress bar and returns truthy on
                # success, falsy for empty/unsupported files.
                if session.indexing(uploaded_files, selected_embed, indexing_bar):
                    st.success('Indexing completed!')
                    indexing_bar.empty()
                else:
                    st.warning('Files empty or not supported.', icon='⚠️')
            except Exception as e:
                st.error(f"Error during indexing: {e}")
        else:
            st.warning('Please upload files first!', icon='⚠️')

    if st.button("🗑️ Remove all context", key="remove_context"):
        try:
            session.clear_context()
            st.success("Context removed")
            st.rerun()
        except Exception as e:
            st.error(f"Error during removing context: {e}")


def _render_retrieved_contexts(session, query, top_k_sim, selected_embed):
    """Run image retrieval for *query* (only when it changed) and show hits."""
    with st.expander(f"**Top {top_k_sim} retrieved contexts**", expanded=True):
        try:
            # Only hit the search backend for a new, non-trivial query;
            # results are cached in session_state across reruns.
            if len(query.strip()) > 2:
                if query != st.session_state.get("last_query", None):
                    with st.spinner('Searching...'):
                        st.session_state["last_query"] = query
                        st.session_state["result_images"] = session.search_images(query, top_k_sim)

            if st.session_state.get("result_images", []):
                images = st.session_state["result_images"]

                clicked = clickable_images(
                    images,
                    titles=[f"Image #{str(i)}" for i in range(len(images))],
                    div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
                    img_style={"margin": "5px", "height": "200px"},
                )
                st.write(f"**Retrieved by: {selected_embed}**")

                @st.dialog(" ", width="large")
                def show_selected_image(id):
                    st.markdown(f"**Similarity rank: {id}**")
                    st.image(images[id])

                # clickable_images returns -1 until something is clicked;
                # open the dialog only once per new click.
                if clicked > -1 and clicked != st.session_state.get("clicked", None):
                    show_selected_image(clicked)
                    st.session_state["clicked"] = clicked

        except Exception as e:
            st.error(f"Error during search: {e}")


def _render_answer_section(session, query, selected_llm, chat_prompt):
    """'Generate answer' button plus the streamed VLM response."""
    # No retrieved context yet — nothing to answer from.
    if not st.session_state.get("result_images", None):
        return
    if st.button("Generate answer", key="ask", type="primary"):
        if len(query.strip()) > 2:
            try:
                with st.spinner('Generating response...'):
                    stream_response = session.ask(
                        query=query,
                        model=selected_llm,
                        prompt_template=chat_prompt,
                        retrieved_context=st.session_state["result_images"],
                        stream=True
                    )
                    st.write_stream(stream_response)
                    st.write(f"**Answered by: {selected_llm}**")
            except Exception as e:
                st.error(f"Error during asking: {e}")
        else:
            st.warning('Please enter query first!', icon='⚠️')