import os |
import pickle |
import faiss |
import langchain |
from langchain import HuggingFaceHub |
from langchain.cache import InMemoryCache |
from langchain.chains import ConversationalRetrievalChain |
from langchain.chat_models import ChatOpenAI |
from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader |
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings |
from langchain.memory import ConversationBufferWindowMemory |
from langchain.prompts.chat import ( |
ChatPromptTemplate, |
HumanMessagePromptTemplate, |
SystemMessagePromptTemplate, |
) |
from langchain.text_splitter import CharacterTextSplitter |
from langchain.vectorstores.faiss import FAISS |
from mapping import FILE_URL_MAPPING |
from memory import CustomMongoDBChatMessageHistory |
# Install an in-memory cache as langchain's global LLM cache.
langchain.llm_cache = InMemoryCache()

# Display names of the selectable models; these are the strings compared
# against in set_model, set_embeddings and get_file_path.
models = ["GPT-3.5", "Flan UL2", "GPT-4", "Flan T5"]

# File-name suffixes and folder used by get_file_path to locate the
# persisted vector store (pickle) and the raw FAISS index.
pickle_file = "_vs.pkl"
index_file = "_vs.index"
models_folder = "models/"

# Read at import time; a missing variable raises KeyError on import.
MONGO_DB_URL = os.environ["MONGO_DB_URL"]

# Defaults used until set_model / set_embeddings replace them.
llm = ChatOpenAI(model_name="gpt-4", temperature=0.1)
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")

# Placeholder history bound to the literal session id 'session_id'.
# NOTE(review): this collection name ('3d_printing_revolution') differs from
# the one set_session_id uses ('printing_3d_revolution') - confirm which one
# is intended.
message_history = CustomMongoDBChatMessageHistory(
    connection_string=MONGO_DB_URL,
    session_id="session_id",
    database_name="coursera_bots",
    collection_name="3d_printing_revolution",
)

# NOTE(review): this initial memory is not given chat_memory=message_history
# (set_session_id builds one that is) - confirm that is intentional.
memory = ConversationBufferWindowMemory(memory_key="chat_history", k=4)

# Assigned by get_search_index; read by generate_answer.
vectorstore_index = None

# System prompt; {context} is the placeholder for the retrieved documents.
system_template = """You are Coursera QA Bot. Have a conversation with a human, answering the following questions as best you can.
You are a teaching assistant for a Coursera Course: The 3D Printing Revolution and can answer any question about that using vectorstore or context.
Use the following pieces of context to answer the users question.
----------------
{context}"""

messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
def set_session_id(session_id):
    """Bind the module-level chat history and memory to ``session_id``.

    If the current ``message_history`` already carries this session id, only
    a message is printed. Otherwise the ``message_history`` and ``memory``
    globals are both replaced: a new MongoDB-backed history for the session,
    and a window memory (k=10, returning messages) built on top of it.
    """
    global message_history, memory

    # Nothing to rebuild when the session has not changed.
    if message_history.session_id == session_id:
        print("Session id already set: " + str(message_history.session_id))
        return

    print("Setting session id to " + str(session_id))
    message_history = CustomMongoDBChatMessageHistory(
        connection_string=MONGO_DB_URL,
        session_id=session_id,
        database_name='coursera_bots',
        collection_name='printing_3d_revolution',
    )
    memory = ConversationBufferWindowMemory(
        memory_key="chat_history",
        chat_memory=message_history,
        k=10,
        return_messages=True,
    )
def set_model_and_embeddings(model):
    """Select the active LLM for ``model`` by delegating to ``set_model``.

    NOTE(review): despite the name, only ``set_model`` is called here;
    ``set_embeddings`` is not invoked, so the module-level ``embeddings``
    is left untouched. Confirm whether that is intentional.
    """
    set_model(model)
def set_model(model):
    """Replace the module-level ``llm`` with the model named by ``model``.

    Args:
        model: One of the display names "GPT-3.5", "GPT-4", "Flan UL2" or
            "Flan T5". Any other value falls back to GPT-3.5.

    The fallback branch previously built ``ChatOpenAI`` with
    ``text-davinci-002``, which is a completions-only model and not usable
    through the chat client, while announcing "GPT-3.5". It now really loads
    ``gpt-3.5-turbo`` so the fallback matches its message and works.
    """
    global llm
    print("Setting model to " + str(model))
    if model == "GPT-3.5":
        print("Loading GPT-3.5")
        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.1)
    elif model == "GPT-4":
        print("Loading GPT-4")
        llm = ChatOpenAI(model_name="gpt-4", temperature=0.1)
    elif model == "Flan UL2":
        print("Loading Flan-UL2")
        llm = HuggingFaceHub(
            repo_id="google/flan-ul2",
            model_kwargs={"temperature": 0.1, "max_new_tokens": 500},
        )
    elif model == "Flan T5":
        print("Loading Flan T5")
        llm = HuggingFaceHub(
            repo_id="google/flan-t5-base",
            model_kwargs={"temperature": 0.1},
        )
    else:
        # Unknown name: fall back to the GPT-3.5 chat model.
        print("Loading GPT-3.5 from else")
        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.1)
def set_embeddings(model):
    """Replace the module-level ``embeddings`` to match ``model``.

    The two GPT names select OpenAI embeddings, the two Flan names select
    Hugging Face Hub embeddings. Any other value leaves ``embeddings``
    unchanged.
    """
    global embeddings

    if model in ("GPT-3.5", "GPT-4"):
        print("Loading OpenAI embeddings")
        embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')
        return

    if model in ("Flan UL2", "Flan T5"):
        print("Loading Hugging Face embeddings")
        embeddings = HuggingFaceHubEmbeddings(
            repo_id="sentence-transformers/all-MiniLM-L6-v2"
        )
def get_search_index(model):
    """Return the vector store for ``model``, loading or building it.

    The pickled store is loaded when both persisted files exist and the
    pickle is non-empty; otherwise a fresh index is built via
    ``create_index``. The result is also stored in the module-level
    ``vectorstore_index``.
    """
    global vectorstore_index

    pickle_path = get_file_path(model, pickle_file)
    index_path = get_file_path(model, index_file)

    persisted = (
        os.path.isfile(pickle_path)
        and os.path.isfile(index_path)
        and os.path.getsize(pickle_path) > 0
    )

    if not persisted:
        loaded_index = create_index(model)
        print("Created index")
    else:
        # NOTE: pickle.load executes arbitrary code from the file; the
        # persisted store must only ever come from a trusted location.
        with open(pickle_path, "rb") as handle:
            loaded_index = pickle.load(handle)
        print("Loaded index")

    vectorstore_index = loaded_index
    return loaded_index
def create_index(model):
    """Build the vector store for ``model`` and persist it to disk.

    The documents are chunked and embedded, then the raw FAISS index is
    written with ``faiss.write_index`` and the whole store is pickled, both
    at the paths given by ``get_file_path``.

    The target directory is created up front if it does not exist, so a
    missing folder cannot fail the write after the embedding step has
    already been carried out.

    Args:
        model: Model display name; selects the file-name prefix.

    Returns:
        The newly created vector store.
    """
    index_path = get_file_path(model, index_file)
    pickle_path = get_file_path(model, pickle_file)

    # Fail fast / prepare the folder before doing the embedding work.
    for path in (index_path, pickle_path):
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    source_chunks = create_chunk_documents()
    search_index = search_index_from_docs(source_chunks)

    faiss.write_index(search_index.index, index_path)
    with open(pickle_path, "wb") as f:
        pickle.dump(search_index, f)
    return search_index
def get_file_path(model, file):
    """Return the persisted-file path for ``model`` with suffix ``file``.

    The two GPT names share the "openai" prefix; every other model name
    uses the "hf" prefix. The path is ``models_folder`` + prefix + ``file``.
    """
    prefix = "openai" if model in ("GPT-3.5", "GPT-4") else "hf"
    return models_folder + prefix + file
def search_index_from_docs(source_chunks):
    """Build a FAISS vector store from ``source_chunks``.

    Uses the module-level ``embeddings`` that is current at call time.
    """
    return FAISS.from_documents(source_chunks, embeddings)
def get_html_files():
    """Load every ``*.html`` file found recursively under ``docs``."""
    html_loader = DirectoryLoader(
        'docs',
        glob="**/*.html",
        loader_cls=UnstructuredHTMLLoader,
        recursive=True,
    )
    return html_loader.load()
def fetch_data_for_embeddings():
    """Load all text and HTML documents and tag each with its URL.

    Text files come first, followed by HTML files. Each document's
    ``metadata["url"]`` is set from ``FILE_URL_MAPPING`` keyed by
    ``metadata["source"]``; it is ``None`` when the source has no mapping.
    """
    documents = get_text_files()
    documents.extend(get_html_files())

    for doc in documents:
        source_path = doc.metadata["source"]
        doc.metadata["url"] = FILE_URL_MAPPING.get(source_path)

    print("document list: " + str(len(documents)))
    return documents
def get_text_files():
    """Load every ``*.txt`` file found recursively under ``docs``."""
    text_loader = DirectoryLoader(
        'docs',
        glob="**/*.txt",
        loader_cls=TextLoader,
        recursive=True,
    )
    return text_loader.load()
def create_chunk_documents():
    """Load the source documents and split them into chunks.

    Splitting uses a ``CharacterTextSplitter`` with a single-space
    separator, chunk size 800 and no overlap.
    """
    chunker = CharacterTextSplitter(
        separator=" ",
        chunk_size=800,
        chunk_overlap=0,
    )
    chunks = chunker.split_documents(fetch_data_for_embeddings())
    print("chunks: " + str(len(chunks)))
    return chunks
def get_qa_chain(vectorstore_index):
    """Build a conversational retrieval chain over ``vectorstore_index``.

    The chain uses the module-level ``llm`` current at call time, a
    score-threshold retriever (threshold 0.7), the ``CHAT_PROMPT`` for the
    document-combining step, and returns its source documents.
    """
    # ``llm`` is only read here, so no ``global`` declaration is required.
    print(llm)

    threshold_retriever = vectorstore_index.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": .7},
    )
    return ConversationalRetrievalChain.from_llm(
        llm,
        threshold_retriever,
        return_source_documents=True,
        verbose=True,
        combine_docs_chain_kwargs={"prompt": CHAT_PROMPT},
    )
def get_chat_history(inputs) -> str:
    """Render ``(human, ai)`` pairs as a newline-separated transcript.

    Each pair becomes ``Human:<human>`` and ``AI:<ai>`` on consecutive
    lines; an empty input yields an empty string.
    """
    return "\n".join(
        f"Human:{user_turn}\nAI:{bot_turn}" for user_turn, bot_turn in inputs
    )
def generate_answer(question) -> str:
    """Answer ``question`` with the retrieval chain and list its sources.

    The chain is built over the module-level ``vectorstore_index`` and is
    given the last four messages of the chat memory as history. The
    question and answer are then appended to the chat memory.

    Args:
        question: The user's question.

    Returns:
        The answer followed by ``'\\nSOURCES: '`` and the source URLs, each
        prefixed with a newline and joined by ``',\\n'``.

    Source documents whose ``url`` metadata is missing or ``None`` (a file
    with no entry in the URL mapping) are skipped instead of raising
    ``TypeError`` on string concatenation. Duplicate URLs are removed while
    keeping first-seen order, so the output is deterministic.
    """
    chain = get_qa_chain(vectorstore_index)
    history = memory.chat_memory.messages[-4:]
    result = chain({"question": question, "chat_history": history})
    save_chat_history(question, result)
    print(result)

    sources = []
    for document in result['source_documents']:
        url = document.metadata.get('url')
        if url:  # unmapped files carry no URL; nothing to cite
            sources.append("\n" + url)
    print(sources)

    # dict.fromkeys de-duplicates while preserving insertion order.
    source = ',\n'.join(dict.fromkeys(sources))
    return result['answer'] + '\nSOURCES: ' + source
def save_chat_history(question, result):
    """Append the question and ``result["answer"]`` to the chat memory.

    The user message is added first, then the AI message; the full message
    list is printed afterwards.
    """
    chat_log = memory.chat_memory
    chat_log.add_user_message(question)
    chat_log.add_ai_message(result["answer"])
    print("chat history after saving: " + str(memory.chat_memory.messages))