Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
import time
import urllib.parse

import openai
import streamlit as st
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import OpenAI
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores import FAISS
|
13 |
+
# class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
|
14 |
+
# def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
15 |
+
# # Call the parent class's method to get the documents
|
16 |
+
# docs = super()._get_docs(inputs)
|
17 |
+
# # Modify the document metadata
|
18 |
+
# for doc in docs:
|
19 |
+
# doc.metadata['source'] = doc.metadata.pop('path')
|
20 |
+
# return docs
|
21 |
+
|
22 |
+
# --- Embedding model + vector store --------------------------------------
# e5-large-v2 sentence embeddings; queries must be embedded with the same
# model that built the FAISS index.
model_name = "intfloat/e5-large-v2"
model_kwargs = {'device': 'cuda'}  # NOTE(review): assumes a GPU host — confirm the deployment target
encode_kwargs = {'normalize_embeddings': False}

# Embedder used at query time for similarity search against the index.
embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)

# Pre-built FAISS index over IPCC report PDF chunks, loaded from local disk.
db = FAISS.load_local("IPCC_index_e5_1000_pdf", embeddings)
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
def generate_response(input_text):
    """Answer a climate question with retrieval-augmented QA over the IPCC index.

    Retrieves the 5 most similar chunks from the FAISS store, injects their
    source/page metadata into the prompt (so the model can add in-text
    citations and hyperlinks), and runs a "stuff" RetrievalQA chain over
    gpt-3.5-turbo.

    Parameters
    ----------
    input_text : str
        The user's question.

    Returns
    -------
    dict
        Chain output with keys 'result' (the answer text) and
        'source_documents' (the retrieved chunks).
    """
    docs = db.similarity_search(input_text, k=5)

    # Map the top-5 hits' metadata onto the template placeholders
    # source1..source5 / page1..page5 (replaces five hand-written
    # json1..json5 temporaries).
    partials = {}
    for i, doc in enumerate(docs, start=1):
        partials[f"source{i}"] = doc.metadata["source"]
        partials[f"page{i}"] = doc.metadata["page"]

    # Fixed typo "ansewer" -> "answer" in the instruction text below.
    climate_TEMPLATE = """ You are ChatClimate, take a deep breath and provide an answer to educated general audience based on the context, and Format your answer in Markdown. :"

Context: {context}

Question: {question}

Answer:


check if you use the info below, if you used please add used source for in-text reference, if not used, do not add them .


[{source1} page {page1}]
[{source2} page {page2}]
[{source3} page {page3}]
[{source4} page {page4}]
[{source5} page {page5}]

Check if you use the source in your answer, make sure list used sources you refer to and their hyperlinks as below in a section named "sources":

[{source1} page {page1}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source1}.pdf#page={page1})
[{source2} page {page2}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source2}.pdf#page={page2})
[{source3} page {page3}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source3}.pdf#page={page3})
[{source4} page {page4}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source4}.pdf#page={page4})
[{source5} page {page5}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source5}.pdf#page={page5})



at the end of your answer, make sure to add a short highlight of your answer in humor and make sure no more than 5 words.

Highlight:
"""
    climate_PROMPT = PromptTemplate(
        input_variables=["question", "context"],
        partial_variables=partials,
        template=climate_TEMPLATE,
    )

    # openai_api_key is a module-level global populated by the sidebar text
    # input; the UI stops before calling this function when it is empty.
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        temperature=0.1,
        max_tokens=2000,
        openai_api_key=openai_api_key,
    )

    retriever = db.as_retriever(search_kwargs={"k": 5})

    # The chain is rebuilt per call because the prompt's partial variables
    # depend on this query's retrieved documents.
    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=retriever,
        chain_type="stuff",  # alternatives: "map_reduce", "refine", "map_rerank"
        return_source_documents=True,
        verbose=True,
        chain_type_kwargs={"prompt": climate_PROMPT},
    )

    return qa_chain({'query': input_text})
|
105 |
+
|
106 |
+
|
107 |
+
# --- Streamlit chat UI ----------------------------------------------------
with st.sidebar:
    openai_api_key = st.text_input("OpenAI API Key", key="chatbot_api_key", type="password")
    "[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"

# Repaired mojibake emoji (UTF-8 bytes previously decoded with a Thai
# codepage); caption emoji is a best-effort reconstruction — confirm.
st.title("💬🌍🌡️Ask question about Climate Change")
st.caption("🌍 A Climate Change chatbot powered by OpenAI LLM")

# Seed the conversation with a greeting on first load.
if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "I'm a Chatbot who can answer your questions about the climate change!"}]

# Replay the chat history on every rerun (Streamlit re-executes the script).
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

if prompt := st.chat_input():
    if not openai_api_key:
        st.info("Please add your OpenAI API key to continue.")
        st.stop()

    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    result = generate_response(prompt)
    result_r = result["result"]

    # Pull the model's one-line "Highlight:" for the cat-meme caption;
    # fall back to a default when the model omitted it.
    match = re.search(r"Highlight: (.+)", result_r)
    highlighted_text = match.group(1) if match else "hello world"

    st.session_state.messages.append({"role": "assistant", "content": result_r})
    st.chat_message("assistant").write(result_r)

    # URL-encode the caption — raw spaces/punctuation would otherwise
    # produce a broken image URL.
    st.image("https://cataas.com/cat/says/" + urllib.parse.quote(highlighted_text))