# ChatClimate — a Streamlit chatbot that answers climate-change questions
# using RetrievalQA over a local FAISS index of IPCC AR6 report chunks.
# Standard library
import re
import time
from urllib.parse import quote

# Third-party
import openai
import streamlit as st
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import OpenAI
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores import FAISS
# class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
#     def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
#         # Call the parent class's method to get the documents
#         docs = super()._get_docs(inputs)
#         # Modify the document metadata
#         for doc in docs:
#             doc.metadata['source'] = doc.metadata.pop('path')
#         return docs
# Embedding model used to encode queries against the prebuilt FAISS index.
# e5-large-v2 runs on GPU; embeddings are left unnormalized to match how the
# index was built (presumably — confirm against the index-building script).
model_name = "intfloat/e5-large-v2"
model_kwargs = {'device': 'cuda'}
encode_kwargs = {'normalize_embeddings': False}
embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
def generate_response(input_text): | |
docs = db.similarity_search(input_text,k=5) | |
json1 = docs[0].metadata | |
json2 = docs[1].metadata | |
json3 = docs[2].metadata | |
json4 = docs[3].metadata | |
json5 = docs[4].metadata | |
#st.write({"source1":json1["source"], "source2":json2["source"],"source3":json3["source"]}) | |
climate_TEMPLATE = """ You are ChatClimate, take a deep breath and provide an answer to educated general audience based on the context, and Format your answer in Markdown. :" | |
Context: {context} | |
Question: {question} | |
Answer: | |
check if you use the info below, if you used please add used source for in-text reference, if not used, do not add them . | |
[{source1} page {page1}] | |
[{source2} page {page2}] | |
[{source3} page {page3}] | |
[{source4} page {page4}] | |
[{source5} page {page5}] | |
Check if you use the source in your ansewer, make sure list used sources you refer to and their hyperlinks as below in a section named "sources": | |
[{source1} page {page1}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source1}.pdf#page={page1}) | |
[{source2} page {page2}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source2}.pdf#page={page2}) | |
[{source3} page {page3}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source3}.pdf#page={page3}) | |
[{source4} page {page4}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source4}.pdf#page={page4}) | |
[{source5} page {page5}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source5}.pdf#page={page5}) | |
at the end of your answer, make sure to add a short highlight of your answer in humor and make sure no more than 5 words. | |
Highlight: | |
""" | |
climate_PROMPT = PromptTemplate(input_variables=["question", "context"], | |
partial_variables={"source1":json1["source"], "source2":json2["source"], | |
"source3":json3["source"],"source4":json4["source"],"source5":json5["source"],"page1":json1["page"], | |
"page2":json2["page"],"page3":json3["page"],"page4":json4["page"],"page5":json5["page"]}, | |
template=climate_TEMPLATE, ) | |
#climate_PROMPT.partial(source = docs[0].metadata) | |
llm = ChatOpenAI( | |
model_name="gpt-3.5-turbo", | |
temperature=0.1, | |
max_tokens=2000, | |
openai_api_key=openai_api_key | |
) | |
# Define retriever | |
retriever = db.as_retriever(search_kwargs={"k": 5}) | |
qa_chain = RetrievalQA.from_chain_type(llm, | |
retriever=retriever, | |
chain_type="stuff", #"stuff", "map_reduce","refine", "map_rerank" | |
return_source_documents=True, | |
verbose=True, | |
chain_type_kwargs={"prompt": climate_PROMPT} | |
) | |
return qa_chain({'query': input_text}) | |
# Sidebar: collect the user's OpenAI API key. The assignment creates the
# module-level global read by generate_response().
with st.sidebar:
    openai_api_key = st.text_input("OpenAI API Key", key="chatbot_api_key", type="password")
    # Bare string is rendered by Streamlit's "magic" as markdown.
    "[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"
# Emoji below restored from mojibake in the original source — TODO confirm
# the intended glyphs.
st.title("💬🌍🌡️Ask question about Climate Change")
st.caption("🚀 A Climate Change chatbot powered by OpenAI LLM")
#col1, col2, = st.columns(2)

# Seed the conversation with a greeting on first load.
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {"role": "assistant",
         "content": "I'm a Chatbot who can answer your questions about the climate change!"}
    ]

# Replay the conversation so far on each rerun.
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])
# Main chat loop: take a question, require an API key, run the RAG chain,
# show the answer, and render a cat meme captioned with the model's
# "Highlight:" summary.
if prompt := st.chat_input():
    if not openai_api_key:
        st.info("Please add your OpenAI API key to continue.")
        st.stop()

    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    result = generate_response(prompt)
    result_r = result["result"]

    # Extract everything after "Highlight:" — the <=5-word humorous summary
    # the prompt asks the model to append; fall back if the model omits it.
    match = re.search(r"Highlight: (.+)", result_r)
    highlighted_text = match.group(1) if match else "hello world"

    st.session_state.messages.append({"role": "assistant", "content": result_r})
    st.chat_message("assistant").write(result_r)
    #display_typing_effect(st.chat_message("assistant"), result_r)
    #st.markdown(result['source_documents'][0])
    #st.markdown(result['source_documents'][1])
    #st.markdown(result['source_documents'][2])
    #st.markdown(result['source_documents'][3])
    #st.markdown(result['source_documents'][4])

    # quote() so spaces/punctuation in the model output form a valid URL path.
    st.image("https://cataas.com/cat/says/" + quote(highlighted_text))