ClimateChat / app.py
zliang's picture
Create app.py
ed7625a
raw
history blame
5.87 kB
import openai
import streamlit as st
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores import FAISS
import re
import time
# class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
# def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
# # Call the parent class's method to get the documents
# docs = super()._get_docs(inputs)
# # Modify the document metadata
# for doc in docs:
# doc.metadata['source'] = doc.metadata.pop('path')
# return docs
# --- Embedding model + vector store setup (runs once at import time) ---
import torch

# e5-large-v2 sentence-embedding model; normalization is left off to match
# how the FAISS index was built.
model_name = "intfloat/e5-large-v2"
# Fall back to CPU when no GPU is present (e.g. free Spaces hardware);
# a hard-coded 'cuda' crashes on CPU-only hosts.
model_kwargs = {'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)
# Prebuilt FAISS index of IPCC report chunks, loaded from local disk.
db = FAISS.load_local("IPCC_index_e5_1000_pdf", embeddings)
def generate_response(input_text):
    """Answer a climate question via RetrievalQA over the IPCC FAISS index.

    Retrieves the 5 most similar chunks, injects each chunk's source
    document name and page number into the prompt as partial variables
    (so the model can emit IPCC PDF hyperlinks), then runs a "stuff"
    RetrievalQA chain with gpt-3.5-turbo.

    Parameters
    ----------
    input_text : str
        The user's question.

    Returns
    -------
    dict
        Chain output: "result" holds the answer text and
        "source_documents" the retrieved chunks.
    """
    # Retrieve top-5 chunks; their metadata feeds the {sourceN}/{pageN}
    # placeholders in the template below.
    docs = db.similarity_search(input_text, k=5)

    climate_TEMPLATE = """ You are ChatClimate, take a deep breath and provide an answer to educated general audience based on the context, and Format your answer in Markdown. :"
Context: {context}
Question: {question}
Answer:
check if you use the info below, if you used please add used source for in-text reference, if not used, do not add them .
[{source1} page {page1}]
[{source2} page {page2}]
[{source3} page {page3}]
[{source4} page {page4}]
[{source5} page {page5}]
Check if you use the source in your answer, make sure list used sources you refer to and their hyperlinks as below in a section named "sources":
[{source1} page {page1}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source1}.pdf#page={page1})
[{source2} page {page2}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source2}.pdf#page={page2})
[{source3} page {page3}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source3}.pdf#page={page3})
[{source4} page {page4}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source4}.pdf#page={page4})
[{source5} page {page5}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source5}.pdf#page={page5})
at the end of your answer, make sure to add a short highlight of your answer in humor and make sure no more than 5 words.
Highlight:
"""
    # Build source1..source5 / page1..page5 from the retrieved metadata
    # (replaces five copy-pasted jsonN variables).
    partials = {}
    for i, doc in enumerate(docs, start=1):
        partials[f"source{i}"] = doc.metadata["source"]
        partials[f"page{i}"] = doc.metadata["page"]

    climate_PROMPT = PromptTemplate(
        input_variables=["question", "context"],
        partial_variables=partials,
        template=climate_TEMPLATE,
    )

    # openai_api_key is set by the sidebar text_input at module level.
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        temperature=0.1,
        max_tokens=2000,
        openai_api_key=openai_api_key
    )
    # The chain retrieves again through this retriever; the earlier
    # similarity_search only supplies the citation metadata.
    retriever = db.as_retriever(search_kwargs={"k": 5})
    qa_chain = RetrievalQA.from_chain_type(llm,
                                           retriever=retriever,
                                           chain_type="stuff",  # alternatives: "map_reduce", "refine", "map_rerank"
                                           return_source_documents=True,
                                           verbose=True,
                                           chain_type_kwargs={"prompt": climate_PROMPT}
                                           )
    return qa_chain({'query': input_text})
with st.sidebar:
    # Collect the user's OpenAI key; generate_response reads this
    # module-level name when constructing the ChatOpenAI client.
    openai_api_key = st.text_input("OpenAI API Key", key="chatbot_api_key", type="password")
    # Bare string literal: rendered as a markdown link by Streamlit "magic".
    "[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"
# Page header.
st.title("πŸ’¬πŸŒπŸŒ‘οΈAsk question about Climate Change")
st.caption("πŸš€ A Climate Change chatbot powered by OpenAI LLM")

# Seed the conversation with a greeting the first time the script runs;
# session_state persists across Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {
            "role": "assistant",
            "content": "I'm a Chatbot who can answer your questions about the climate change!",
        }
    ]

# Replay the conversation so far.
for message in st.session_state.messages:
    st.chat_message(message["role"]).write(message["content"])
# Handle one chat turn: validate the key, run the QA chain, render the
# answer, and show a captioned cat image from the answer's "Highlight".
if prompt := st.chat_input():
    if not openai_api_key:
        st.info("Please add your OpenAI API key to continue.")
        st.stop()

    # Record and echo the user's turn.
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    result = generate_response(prompt)
    result_r = result["result"]

    # The prompt template asks the model to end with "Highlight: <5 words>";
    # extract it for the image caption, falling back if the model omitted it.
    match = re.search(r"Highlight: (.+)", result_r)
    highlighted_text = match.group(1) if match else "hello world"

    # Record and render the assistant's turn.
    st.session_state.messages.append({"role": "assistant", "content": result_r})
    st.chat_message("assistant").write(result_r)

    # Percent-encode the caption: cataas takes it as a URL path segment, and
    # the raw text may contain spaces/punctuation that would break the URL.
    from urllib.parse import quote
    st.image("https://cataas.com/cat/says/" + quote(highlighted_text))