zliang committed
Commit ed7625a · 1 Parent(s): 5d9d034

Create app.py

Files changed (1)
  1. app.py +149 -0
app.py ADDED
@@ -0,0 +1,149 @@
+ import openai
+ import streamlit as st
+ from langchain.llms import OpenAI
+ from langchain.chat_models import ChatOpenAI
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.chains import RetrievalQA
+
+ from langchain.prompts.prompt import PromptTemplate
+
+ from langchain.vectorstores import FAISS
+ import re
+ import time
+ # class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
+ #     def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
+ #         # Call the parent class's method to get the documents
+ #         docs = super()._get_docs(inputs)
+ #         # Modify the document metadata
+ #         for doc in docs:
+ #             doc.metadata['source'] = doc.metadata.pop('path')
+ #         return docs
+
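+ # Embedding model; it should match the model used to build the FAISS index loaded below,
+ # and 'cuda' assumes a GPU is available at runtime.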
+ model_name = "intfloat/e5-large-v2"
+ model_kwargs = {'device': 'cuda'}
+ encode_kwargs = {'normalize_embeddings': False}
+ embeddings = HuggingFaceEmbeddings(
+     model_name=model_name,
+     model_kwargs=model_kwargs,
+     encode_kwargs=encode_kwargs
+ )
+
+ db = FAISS.load_local("IPCC_index_e5_1000_pdf", embeddings)
+
+
+
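+ # Retrieve the top-5 most relevant chunks, build a prompt that carries their
+ # source/page metadata, and answer the question with a RetrievalQA chain.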
+ def generate_response(input_text):
+     docs = db.similarity_search(input_text, k=5)
+
+     json1 = docs[0].metadata
+     json2 = docs[1].metadata
+     json3 = docs[2].metadata
+     json4 = docs[3].metadata
+     json5 = docs[4].metadata
+     #st.write({"source1":json1["source"], "source2":json2["source"],"source3":json3["source"]})
+
+
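+     # Prompt template: answer from the retrieved context and cite the matching IPCC AR6 WG3 PDF pages.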
+     climate_TEMPLATE = """You are ChatClimate. Take a deep breath and answer the question for an educated general audience, based on the context below. Format your answer in Markdown.
+
+ Context: {context}
+
+ Question: {question}
+
+ Answer:
+
+ Check whether you used any of the sources below. If you used one, add it as an in-text reference; if you did not use it, do not mention it.
+
+ [{source1} page {page1}]
+ [{source2} page {page2}]
+ [{source3} page {page3}]
+ [{source4} page {page4}]
+ [{source5} page {page5}]
+
+ For every source you actually used in your answer, list it with its hyperlink in a section named "Sources":
+
+ [{source1} page {page1}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source1}.pdf#page={page1})
+ [{source2} page {page2}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source2}.pdf#page={page2})
+ [{source3} page {page3}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source3}.pdf#page={page3})
+ [{source4} page {page4}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source4}.pdf#page={page4})
+ [{source5} page {page5}](https://www.ipcc.ch/report/ar6/wg3/downloads/report/{source5}.pdf#page={page5})
+
+ At the end of your answer, add a short, humorous highlight of no more than 5 words.
+
+ Highlight:
+ """
+     climate_PROMPT = PromptTemplate(
+         input_variables=["question", "context"],
+         partial_variables={"source1": json1["source"], "source2": json2["source"],
+                            "source3": json3["source"], "source4": json4["source"], "source5": json5["source"],
+                            "page1": json1["page"], "page2": json2["page"], "page3": json3["page"],
+                            "page4": json4["page"], "page5": json5["page"]},
+         template=climate_TEMPLATE,
+     )
+
+     #climate_PROMPT.partial(source = docs[0].metadata)
+
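+     # openai_api_key comes from the sidebar text_input defined at module level below.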
+     llm = ChatOpenAI(
+         model_name="gpt-3.5-turbo",
+         temperature=0.1,
+         max_tokens=2000,
+         openai_api_key=openai_api_key
+     )
+
+     # Define retriever
+     retriever = db.as_retriever(search_kwargs={"k": 5})
+
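+     # "stuff" chain: all retrieved chunks are stuffed into the single prompt above.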
+     qa_chain = RetrievalQA.from_chain_type(
+         llm,
+         retriever=retriever,
+         chain_type="stuff",  # alternatives: "map_reduce", "refine", "map_rerank"
+         return_source_documents=True,
+         verbose=True,
+         chain_type_kwargs={"prompt": climate_PROMPT}
+     )
+
+     return qa_chain({'query': input_text})
+
+
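+ # Streamlit UI: the sidebar collects the OpenAI API key; chat history lives in st.session_state.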
+ with st.sidebar:
+     openai_api_key = st.text_input("OpenAI API Key", key="chatbot_api_key", type="password")
+     "[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"
+
+ st.title("💬🌍🌡️ Ask questions about Climate Change")
+ st.caption("🚀 A climate change chatbot powered by an OpenAI LLM")
+ #col1, col2, = st.columns(2)
+
+
+ if "messages" not in st.session_state:
+     st.session_state["messages"] = [{"role": "assistant", "content": "I'm a chatbot who can answer your questions about climate change!"}]
+
+ for msg in st.session_state.messages:
+     st.chat_message(msg["role"]).write(msg["content"])
+
+ if prompt := st.chat_input():
+     if not openai_api_key:
+         st.info("Please add your OpenAI API key to continue.")
+         st.stop()
+
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     st.chat_message("user").write(prompt)
+     result = generate_response(prompt)
+     result_r = result["result"]
+     index = result_r.find("Highlight:")
+
+     # Extract everything after "Highlight:"
+     match = re.search(r"Highlight: (.+)", result_r)
+     if match:
+         highlighted_text = match.group(1)
+     else:
+         highlighted_text = "hello world"
+     st.session_state.messages.append({"role": "assistant", "content": result["result"]})
+     st.chat_message("assistant").write(result_r)
+     #display_typing_effect(st.chat_message("assistant"), result_r)
+     #st.markdown(result['source_documents'][0])
+     #st.markdown(result['source_documents'][1])
+     #st.markdown(result['source_documents'][2])
+     #st.markdown(result['source_documents'][3])
+     #st.markdown(result['source_documents'][4])
+
+
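+     # Render the short humorous highlight as a cat image via the cataas.com API.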
+     st.image("https://cataas.com/cat/says/" + highlighted_text)