File size: 5,399 Bytes
e7de495 746ce21 e7de495 746ce21 e7de495 fa183f9 e7de495 dfe4dc9 e7de495 5a1498d e7de495 5a1498d e7de495 7803fd6 e7de495 7803fd6 e7de495 fa183f9 e7de495 8a72047 7803fd6 60b15ef 8a72047 fb1fbe0 a4a183e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import chainlit as cl
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings import CacheBackedEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.storage import LocalFileStore
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
import chainlit as cl
# Chunking for the loaded CSV rows: pieces of up to 1000 characters, with a
# 100-character overlap between neighbouring pieces.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)

# System instructions for the support assistant. "{context}" is filled by the
# "stuff" chain with the retrieved documents; the user's query arrives through
# the "{question}" human template below.
# NOTE(review): the contact address reads "[email protected]", which looks like
# scraper email obfuscation rather than the real address — confirm.
system_template = """
Use the following pieces of context to answer the user's question.
Please respond as if you are a customer support assistant ('kundeservice AI assistent') for Daysoff.
By default, you respond in Norwegian language (unless asked otherwise)
using a warm, direct, and professional tone.
Your expertise covers FAQs, and privacy policies.
If you don't know the answer, just say that you don't know, don't try to make up an answer and
politely redirect users to customer service at [email protected].
You can make inferences based on the context as long as it still faithfully represents the feedback.
Example of your response should be:
```
The answer is foo
```
Begin!
----------------
{context}"""

_system_prompt = SystemMessagePromptTemplate.from_template(system_template)
_human_prompt = HumanMessagePromptTemplate.from_template("{question}")
messages = [_system_prompt, _human_prompt]

# Chat prompt handed to RetrievalQA's combine-documents step.
prompt = ChatPromptTemplate(messages=messages)
chain_type_kwargs = dict(prompt=prompt)
@cl.author_rename
def rename(orig_author: str) -> str:
    """Return the author label shown in the Chainlit UI.

    Steps authored by LangChain's ``RetrievalQA`` are shown as
    "Consulting Daysoff data"; any other author name passes through unchanged.
    """
    is_retrieval_step = orig_author == "RetrievalQA"
    return "Consulting Daysoff data" if is_retrieval_step else orig_author
@cl.on_chat_start
async def init():
    """Build the FAQ retrieval chain for a newly started chat session.

    The steps are:

    - Load ``./data/total_faq.csv`` and split it with ``text_splitter``.
    - Embed the chunks with OpenAI embeddings, cached in a ``LocalFileStore``
      under ``./cache/``.
    - Index the chunks in FAISS.
    - Store a ``RetrievalQA`` chain in the user session under ``"chain"``.

    A status message is shown while the index is built and is edited in
    place once the chain is ready.
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    # Columns in the CSV: Answer, Question, Info_Url.
    # NOTE(review): source_column="Answer" puts the answer text into
    # metadata["source"], which main() then prefixes with
    # "https://www.daysoff.no". "Info_Url" looks like the intended column;
    # confirm its format (relative path vs. full URL) before switching.
    loader = CSVLoader(file_path="./data/total_faq.csv", source_column="Answer")
    # Reading the CSV is blocking file I/O; run it via make_async, like the
    # indexing step below, so the event loop is not stalled.
    data = await cl.make_async(loader.load)()
    documents = text_splitter.transform_documents(data)

    # Embeddings are cached on disk, namespaced by the embedding model name.
    store = LocalFileStore("./cache/")
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model, store, namespace=core_embeddings_model.model
    )

    # FAISS.from_documents is a synchronous call; make_async keeps it off the
    # event loop.
    docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)

    chain = RetrievalQA.from_chain_type(
        ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7, streaming=True),
        chain_type="stuff",
        # main() reads res["source_documents"]; with False that key is absent
        # and every incoming message failed with KeyError.
        return_source_documents=True,
        retriever=docsearch.as_retriever(),
        # Reuse the module-level kwargs instead of rebuilding the same dict.
        chain_type_kwargs=chain_type_kwargs,
    )

    # Edit the status message in place rather than sending a second message.
    msg.content = "Index built!"
    await msg.update()

    cl.user_session.set("chain", chain)
@cl.on_message
async def main(message):
    """Answer one user message with the session's retrieval chain.

    Args:
        message: The incoming user message. In older Chainlit releases this is
            a plain string; in newer ones it is a ``cl.Message`` whose text is
            in ``.content``. Both forms are accepted.

    The chain output is streamed through an ``AsyncLangchainCallbackHandler``.
    Source links for qualifying rows are appended to the answer (and attached
    as ``Info_Url`` text elements); otherwise "No sources found" is appended.
    Exactly one reply is produced per message.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True,
        answer_prefix_tokens=["FINAL", "ANSWER"],
    )
    # The prompt never asks the model for a "FINAL ANSWER" prefix, so mark the
    # answer as already reached and stream tokens from the start.
    cb.answer_reached = True

    query = getattr(message, "content", message)
    res = await chain.acall(query, callbacks=[cb])
    answer = res["result"]

    # The key is absent when the chain was built with
    # return_source_documents=False; treat that as "no sources".
    docs = res.get("source_documents", [])

    source_elements = []
    visited_sources = set()
    for doc in docs:
        metadata = doc.metadata
        # NOTE(review): CSVLoader appears to record the row number under "row",
        # not "row_index", so this filter may never match. Confirm the intended
        # rows and key before relying on it.
        row_index = metadata.get("row_index", -1)
        if row_index not in (2, 8, 14):  # include only rows 2, 8 and 14
            continue
        source = metadata.get("source", "")
        if source and source not in visited_sources:
            visited_sources.add(source)
            source_elements.append(
                cl.Text(content="https://www.daysoff.no" + source, name="Info_Url")
            )

    if source_elements:
        answer += f"\nSources: {', '.join(e.content for e in source_elements)}"
    else:
        answer += "\nNo sources found"

    # If the handler already streamed the answer into a message, finish that
    # message instead of posting the answer a second time.
    final_stream = getattr(cb, "final_stream", None)
    if getattr(cb, "has_streamed_final_answer", False) and final_stream is not None:
        final_stream.content = answer
        final_stream.elements = source_elements
        await final_stream.update()
    else:
        await cl.Message(content=answer, elements=source_elements).send()
|