import openai
import gradio as gr
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings

from retriever import initialize_document_retriever
from chain import create_question_answering_chain

def load_embeddings_database_from_disk(persistence_directory, embeddings_generator):
    """
    Load a Chroma vector database from disk.

    This loads a Chroma vector database from the specified directory. It expects
    the same persistence directory and embedding function that were used when the
    database was created.

    Args:
        persistence_directory (str): The directory where the database is stored on disk.
        embeddings_generator (obj): The embeddings generator that was used when creating the database.

    Returns:
        vector_database (obj): The loaded Chroma vector database.
    """
    # The embedding_function parameter must match the one used when the database
    # was created, so that query embeddings live in the same vector space.
    vector_database = Chroma(persist_directory=persistence_directory, embedding_function=embeddings_generator)
    return vector_database

# Directory where the persisted vector database is stored.
persistence_directory = 'db'

# Create the embeddings generator; it must be the same embedding model that was
# used when the database was built. This assumes openai.api_key has already been
# set (the openai library reads it from the OPENAI_API_KEY environment variable).
embeddings_generator = OpenAIEmbeddings(openai_api_key=openai.api_key)
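# The persisted `db` directory is assumed to have been built beforehand in a
# one-time indexing step, roughly like this sketch (the loader and splitter
# settings are illustrative, not taken from the original pipeline):
#
#     from langchain.document_loaders import TextLoader
#     from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#     documents = TextLoader("corpus.txt").load()
#     chunks = RecursiveCharacterTextSplitter(
#         chunk_size=1000, chunk_overlap=100
#     ).split_documents(documents)
#     Chroma.from_documents(
#         chunks, embeddings_generator, persist_directory=persistence_directory
#     ).persist()
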
# Load the Chroma vector database from disk.
vector_database = load_embeddings_database_from_disk(persistence_directory, embeddings_generator)

# Number of documents to retrieve for each query.
topk_documents = 2

# Create a retriever that returns the top-k most relevant documents.
retriever = initialize_document_retriever(topk_documents, vector_database)
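# `initialize_document_retriever` is imported from the local `retriever` module;
# a minimal sketch of what it is assumed to do (expose the vector store as a
# top-k similarity retriever):
#
#     def initialize_document_retriever(topk_documents, vector_database):
#         return vector_database.as_retriever(search_kwargs={"k": topk_documents})
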
# Build the question-answering chain on top of the retriever.
qa_chain = create_question_answering_chain(retriever)
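
# `create_question_answering_chain` is imported from the local `chain` module;
# a minimal sketch of an assumed implementation based on LangChain's RetrievalQA
# (the actual LLM and chain settings may differ):
#
#     from langchain.chat_models import ChatOpenAI
#     from langchain.chains import RetrievalQA
#
#     def create_question_answering_chain(retriever):
#         llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
#         return RetrievalQA.from_chain_type(
#             llm=llm, chain_type="stuff", retriever=retriever
#         )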

def add_text(history, text):
    # Append the user's message to the chat history with a placeholder for the
    # bot's reply, then clear and disable the textbox while the answer is generated.
    # A list (rather than a tuple) is used so the reply slot can be filled in later.
    history = history + [[text, None]]
    return history, gr.update(value="", interactive=False)


def bot(history):
    # Run the question-answering chain on the latest user message and fill in
    # the bot's side of the last chat-history entry.
    llm_response = qa_chain.run({"query": history[-1][0]})
    history[-1][1] = llm_response
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        [], elem_id="chatbot", label="Retrieval Augmented Question Answering"
    ).style(height=750)
    with gr.Row():
        with gr.Column(scale=0.95):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press enter",
            ).style(container=False)

    # Chain the callbacks: append the user's message, generate the answer, then
    # re-enable the textbox.
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot
    )
    txt_msg.then(lambda: gr.update(interactive=True), None, txt, queue=False)
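
# demo.launch() serves the app locally (Gradio defaults to http://127.0.0.1:7860);
# passing share=True would create a temporary public link if remote access is needed.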
demo.launch()