import os
import pickle

import faiss
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
# Supply your own key; never commit a real API key to source control.
os.environ['OPENAI_API_KEY'] = 'sk-...'
pickle_file = "open_ai.pkl"
index_file = "open_ai.index"

# Note: the variable is named gpt_3_5 for historical reasons, but it is configured to use GPT-4.
gpt_3_5 = ChatOpenAI(model_name='gpt-4', temperature=0.1)
embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')

chat_history = []
# Windowed conversation memory; not referenced elsewhere in this module (history is tracked via chat_history).
memory = ConversationBufferWindowMemory(memory_key="chat_history")
gpt_3_5_index = None
system_template = """You are the Coursera QA Bot. Have a conversation with a human, answering their questions as best you can.
You are a teaching assistant for the Coursera course "The 3D Printing Evolution" and can answer any question about it using the vector store.
Use the following pieces of context to answer the user's question.
----------------
{context}"""
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
def get_search_index():
    global gpt_3_5_index
    if os.path.isfile(pickle_file) and os.path.isfile(index_file) and os.path.getsize(pickle_file) > 0:
        # Load the previously built FAISS vector store from the pickle file.
        with open(pickle_file, "rb") as f:
            search_index = pickle.load(f)
    else:
        # No saved index found; build it from the documents under docs/.
        search_index = create_index()
    gpt_3_5_index = search_index
    return search_index
def create_index():
    source_chunks = create_chunk_documents()
    search_index = search_index_from_docs(source_chunks)
    # Persist the raw FAISS index alongside the pickled vector store.
    faiss.write_index(search_index.index, index_file)
    # Save the whole vector store (docstore + index) to the pickle file.
    with open(pickle_file, "wb") as f:
        pickle.dump(search_index, f)
    return search_index
def search_index_from_docs(source_chunks):
    # print("source chunks: " + str(len(source_chunks)))
    # print("embeddings: " + str(embeddings))
    # Embed the document chunks and build a FAISS vector store from them.
    search_index = FAISS.from_documents(source_chunks, embeddings)
    return search_index
def get_html_files():
    # Recursively load every HTML file under docs/.
    loader = DirectoryLoader('docs', glob="**/*.html", loader_cls=UnstructuredHTMLLoader, recursive=True)
    document_list = loader.load()
    return document_list
def fetch_data_for_embeddings():
    # Combine the text and HTML documents into a single list.
    document_list = get_text_files()
    document_list.extend(get_html_files())
    print("document count: " + str(len(document_list)))
    return document_list
def get_text_files():
    # Recursively load every plain-text file under docs/.
    loader = DirectoryLoader('docs', glob="**/*.txt", loader_cls=TextLoader, recursive=True)
    document_list = loader.load()
    return document_list
def create_chunk_documents():
    sources = fetch_data_for_embeddings()
    # Split the documents into ~800-character chunks for embedding.
    splitter = CharacterTextSplitter(separator=" ", chunk_size=800, chunk_overlap=0)
    source_chunks = splitter.split_documents(sources)
    print("chunk count: " + str(len(source_chunks)))
    return source_chunks
def get_qa_chain(gpt_3_5_index):
    global gpt_3_5
    # embeddings_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
    # compression_retriever = ContextualCompressionRetriever(base_compressor=embeddings_filter, base_retriever=gpt_3_5_index.as_retriever())
    # Build a conversational retrieval chain that answers from the FAISS index using the custom prompt.
    chain = ConversationalRetrievalChain.from_llm(gpt_3_5, gpt_3_5_index.as_retriever(), return_source_documents=True,
                                                  verbose=True, get_chat_history=get_chat_history,
                                                  combine_docs_chain_kwargs={"prompt": CHAT_PROMPT})
    return chain
def get_chat_history(inputs) -> str:
    # Format (human, ai) pairs into a single transcript string for the chain.
    res = []
    for human, ai in inputs:
        res.append(f"Human:{human}\nAI:{ai}")
    return "\n".join(res)
def generate_answer(question) -> str:
    global chat_history, gpt_3_5_index
    gpt_3_5_chain = get_qa_chain(gpt_3_5_index)
    result = gpt_3_5_chain(
        {"question": question, "chat_history": chat_history, "vectordbkwargs": {"search_distance": 0.6}})
    # Keep only the most recent exchange as chat history for the next turn.
    chat_history = [(question, result["answer"])]
    sources = []
    print(result['answer'])
    # Collect de-duplicated source file names (path and extension stripped) for attribution.
    for document in result['source_documents']:
        source = document.metadata['source']
        sources.append(source.split('/')[-1].split('.')[0])
    source = ',\n'.join(set(sources))
    return result['answer'] + '\nSOURCES: ' + source
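

# Usage sketch: a minimal local smoke test, assuming OPENAI_API_KEY is set in the
# environment and a docs/ directory with the course .txt / .html files sits next to
# this script. The question below is only an illustrative example.
if __name__ == "__main__":
    get_search_index()  # builds the FAISS index on the first run, loads it afterwards
    print(generate_answer("What materials can a consumer 3D printer use?"))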