import os
import pickle
import faiss
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
# Set the OpenAI API key (redacted here; supply your own key or export it in the environment)
os.environ['OPENAI_API_KEY'] = 'sk-...'
pickle_file = "open_ai.pkl"
index_file = "open_ai.index"
# Note: despite the variable name, this chat model is configured as GPT-4
gpt_3_5 = ChatOpenAI(model_name='gpt-4', temperature=0.1)
embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')
chat_history = []
memory = ConversationBufferWindowMemory(memory_key="chat_history")
gpt_3_5_index = None
system_template = """You are Coursera QA Bot. Have a conversation with a human, answering their questions as best you can.
You are a teaching assistant for the Coursera course "The 3D Printing Evolution" and can answer any question about it using the vector store.
Use the following pieces of context to answer the user's question.
----------------
{context}"""
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
def get_search_index():
    global gpt_3_5_index
    if os.path.isfile(pickle_file) and os.path.isfile(index_file) and os.path.getsize(pickle_file) > 0:
        # Load index from pickle file
        with open(pickle_file, "rb") as f:
            search_index = pickle.load(f)
    else:
        search_index = create_index()
    gpt_3_5_index = search_index
    return search_index
def create_index():
    source_chunks = create_chunk_documents()
    search_index = search_index_from_docs(source_chunks)
    faiss.write_index(search_index.index, index_file)
    # Save index to pickle file
    with open(pickle_file, "wb") as f:
        pickle.dump(search_index, f)
    return search_index
def search_index_from_docs(source_chunks):
    # Build a FAISS vector store from the chunked documents
    search_index = FAISS.from_documents(source_chunks, embeddings)
    return search_index
def get_html_files():
    # Recursively load all HTML files under docs/
    loader = DirectoryLoader('docs', glob="**/*.html", loader_cls=UnstructuredHTMLLoader, recursive=True)
    document_list = loader.load()
    return document_list
def fetch_data_for_embeddings():
    # Gather both text and HTML documents for embedding
    document_list = get_text_files()
    document_list.extend(get_html_files())
    print("document list: " + str(len(document_list)))
    return document_list
def get_text_files():
    # Recursively load all plain-text files under docs/
    loader = DirectoryLoader('docs', glob="**/*.txt", loader_cls=TextLoader, recursive=True)
    document_list = loader.load()
    return document_list
def create_chunk_documents():
    sources = fetch_data_for_embeddings()
    # Split documents into ~800-character chunks with no overlap
    splitter = CharacterTextSplitter(separator=" ", chunk_size=800, chunk_overlap=0)
    source_chunks = splitter.split_documents(sources)
    print("source chunks: " + str(len(source_chunks)))
    return source_chunks
def get_qa_chain(gpt_3_5_index):
    global gpt_3_5
    # embeddings_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
    # compression_retriever = ContextualCompressionRetriever(base_compressor=embeddings_filter, base_retriever=gpt_3_5_index.as_retriever())
    # Build a conversational retrieval chain over the FAISS retriever
    chain = ConversationalRetrievalChain.from_llm(gpt_3_5, gpt_3_5_index.as_retriever(), return_source_documents=True,
                                                  verbose=True, get_chat_history=get_chat_history,
                                                  combine_docs_chain_kwargs={"prompt": CHAT_PROMPT})
    return chain
def get_chat_history(inputs) -> str:
    # Format (human, ai) exchanges as a plain-text transcript
    res = []
    for human, ai in inputs:
        res.append(f"Human:{human}\nAI:{ai}")
    return "\n".join(res)
def generate_answer(question) -> str:
    global chat_history, gpt_3_5_index
    gpt_3_5_chain = get_qa_chain(gpt_3_5_index)
    result = gpt_3_5_chain(
        {"question": question, "chat_history": chat_history, "vectordbkwargs": {"search_distance": 0.6}})
    # Keep only the most recent exchange as chat history
    chat_history = [(question, result["answer"])]
    sources = []
    print(result['answer'])
    # Collect de-duplicated source file names (basename without extension)
    for document in result['source_documents']:
        source = document.metadata['source']
        sources.append(source.split('/')[-1].split('.')[0])
    source = ',\n'.join(set(sources))
    return result['answer'] + '\nSOURCES: ' + source