changes to app.py and document_retrieval.py

- app.py: +40 -40
- src/document_retrieval.py: +11 -9
app.py
CHANGED
@@ -1,9 +1,7 @@
 import os
 import sys
-import logging
 import yaml
 import gradio as gr
-import time
 
 current_dir = os.path.dirname(os.path.abspath(__file__))
 print(current_dir)
@@ -16,61 +14,61 @@ from utils.vectordb.vector_db import VectorDb
 CONFIG_PATH = os.path.join(current_dir,'config.yaml')
 PERSIST_DIRECTORY = os.path.join(current_dir,f"data/my-vector-db") # changed to current_dir
 
-
-
+#class ChatState:
+#    def __init__(self):
+#        self.conversation = None
+#        self.chat_history = []
+#        self.show_sources = True
+#        self.sources_history = []
+#        self.vectorstore = None
+#        self.input_disabled = True
+#        self.document_retrieval = None
 
-class ChatState:
-    def __init__(self):
-        self.conversation = None
-        self.chat_history = []
-        self.show_sources = True
-        self.sources_history = []
-        self.vectorstore = None
-        self.input_disabled = True
-        self.document_retrieval = None
+chat_history = gr.State()
+chat_history = []
+vectorstore = gr.State()
+document_retrieval = gr.State()
 
-
+document_retrieval = DocumentRetrieval()
 
-
-
-def handle_userinput(user_question):
+def handle_userinput(user_question, conversation):
     if user_question:
         try:
-
-
-            response_time = time.time() - response_time
-            chat_state.chat_history.append((user_question, response["answer"]))
+            response = conversation.invoke({"question": user_question})
+            chat_history.append((user_question, response["answer"]))
 
             #sources = set([f'{sd.metadata["filename"]}' for sd in response["source_documents"]])
            #sources_text = "\n".join([f"{i+1}. {source}" for i, source in enumerate(sources)])
            #state.sources_history.append(sources_text)
 
-            return
+            return chat_history, "" #, state.sources_history
         except Exception as e:
             return f"An error occurred: {str(e)}", "" #, state.sources_history
-
+    else:
+        return "An error occurred", ""
+    #return chat_history, "" #, state.sources_history
 
-def process_documents(files, save_location=None):
+def process_documents(files, conversation, save_location=None):
     try:
         #for doc in files:
         _, _, text_chunks = parse_doc_universal(doc=files)
         print(text_chunks)
         #text_chunks = chat_state.document_retrieval.parse_doc(files)
-        embeddings =
+        embeddings = document_retrieval.load_embedding_model()
         collection_name = 'ekr_default_collection' if not config['prod_mode'] else None
-        vectorstore =
-
-
-
-
-        return "Complete! You can now ask questions."
+        vectorstore = document_retrieval.create_vector_store(text_chunks, embeddings, output_db=save_location, collection_name=collection_name)
+        #vectorstore = vectorstore
+        document_retrieval.init_retriever(vectorstore)
+        conversation = document_retrieval.get_qa_retrieval_chain()
+        #input_disabled = False
+        return conversation, "Complete! You can now ask questions."
     except Exception as e:
-        return f"An error occurred while processing: {str(e)}"
+        return conversation, f"An error occurred while processing: {str(e)}"
 
 def reset_conversation():
-
+    chat_history = []
     #chat_state.sources_history = []
-    return
+    return chat_history, ""
 
 def show_selection(model):
     return f"You selected: {model}"
@@ -89,7 +87,8 @@ caution_text = """⚠️ Note: depending on the size of your document, this coul
 """
 
 with gr.Blocks() as demo:
-
+    conversation = gr.State()
+
     gr.Markdown("# Enterprise Knowledge Retriever",
                 elem_id="title")
 
@@ -108,8 +107,8 @@ with gr.Blocks() as demo:
     process_btn = gr.Button("🔄 Process")
     gr.Markdown(caution_text)
 
-
-    process_btn.click(process_documents, inputs=[docs], outputs=
+    # Preprocessing events
+    process_btn.click(process_documents, inputs=[docs, conversation], outputs=[conversation, setup_output])
     #process_save_btn.click(process_documents, inputs=[file_upload, save_location], outputs=setup_output)
     #load_db_btn.click(load_existing_db, inputs=[db_path], outputs=setup_output)
 
@@ -117,13 +116,14 @@ with gr.Blocks() as demo:
     gr.Markdown("## 3️⃣ Chat with your document")
     chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
     msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
-
+    clear_btn = gr.Button("Clear chat")
     #show_sources = gr.Checkbox(label="Show sources", value=True)
     sources_output = gr.Textbox(label="Sources", visible=False)
 
+    # Chatbot events
     #msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, sources_output])
-    msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, msg])
-
+    msg.submit(handle_userinput, inputs=[msg, conversation], outputs=[chatbot, msg])
+    clear_btn.click(reset_conversation, outputs=[chatbot,msg])
    #show_sources.change(lambda x: gr.update(visible=x), show_sources, sources_output)
 
 if __name__ == "__main__":
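A note on the state handling this diff introduces: conversation is held in gr.State() and passed through the inputs/outputs lists of each event, which is Gradio's mechanism for keeping a value per user session. chat_history, however, ends up as a plain module-level list (the chat_history = gr.State() assignment is immediately overwritten by chat_history = []), so every session appends to the same list. A minimal sketch of the fully per-session pattern (hypothetical names; an echo handler stands in for the retrieval chain):

import gradio as gr

def respond(user_question, history):
    # history is this session's copy of the gr.State value.
    history = history + [(user_question, f"echo: {user_question}")]
    # Return it as an output so Gradio stores the updated value.
    return history, history, ""

with gr.Blocks() as demo:
    history_state = gr.State([])  # one list per user session
    chatbot = gr.Chatbot(label="Chatbot")
    msg = gr.Textbox(placeholder="Enter your message...")
    # The state component appears in both inputs and outputs.
    msg.submit(respond, inputs=[msg, history_state],
               outputs=[history_state, chatbot, msg])

demo.launch()

Returning the state value from the handler is what persists the update; mutating a module-level list works in a single-user demo but leaks history across users on a shared Space.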
src/document_retrieval.py
CHANGED
@@ -21,7 +21,7 @@ repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))
 sys.path.append(kit_dir)
 sys.path.append(repo_dir)
 
-import streamlit as st
+#import streamlit as st
 
 from utils.model_wrappers.api_gateway import APIGateway
 from utils.vectordb.vector_db import VectorDb
@@ -30,7 +30,7 @@ from utils.visual.env_utils import get_wandb_key
 CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
 PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')
 
-load_dotenv(os.path.join(kit_dir, '.env'))
+#load_dotenv(os.path.join(kit_dir, '.env'))
 
 
 from utils.parsing.sambaparse import parse_doc_universal
@@ -153,13 +153,15 @@ class DocumentRetrieval:
         return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
 
     def set_llm(self):
-        if self.prod_mode:
-            sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
-        else:
-            if 'SAMBANOVA_API_KEY' in st.session_state:
-                sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY') or st.session_state.SAMBANOVA_API_KEY
-            else:
-                sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
+        #if self.prod_mode:
+        #    sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
+        #else:
+        #    if 'SAMBANOVA_API_KEY' in st.session_state:
+        #        sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY') or st.session_state.SAMBANOVA_API_KEY
+        #    else:
+        #        sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
+
+        sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
 
         llm = APIGateway.load_llm(
             type=self.api_info,
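With the Streamlit branches commented out, set_llm now reads the key from the process environment only; and because load_dotenv(os.path.join(kit_dir, '.env')) is commented out as well, SAMBANOVA_API_KEY has to be set in the environment itself (for a Space, as a repository secret). If local .env support is still wanted, a minimal sketch using python-dotenv (an assumption for illustration, not something this commit does):

import os

from dotenv import load_dotenv

# Reads .env if present; variables already set in the real
# environment take precedence over values from the file.
load_dotenv()

sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
if not sambanova_api_key:
    # Fail fast with a clear message instead of passing None
    # on to APIGateway.load_llm.
    raise RuntimeError('SAMBANOVA_API_KEY is not set')

os.environ.get returns None rather than raising when the variable is missing, so the explicit check surfaces the misconfiguration at startup instead of deep inside the LLM call.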