seansullivan committed e486ecf (verified) · Parent: 31e81b0

Create app.py

Files changed (1): app.py (+176, -0)
app.py ADDED
import os
import time
from operator import itemgetter

import streamlit as st
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_models.fireworks import ChatFireworks
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_pinecone.vectorstores import Pinecone as PineconeVectorStore
from pinecone import Pinecone

# Streamlit App Configuration
st.set_page_config(page_title="Docu-Help", page_icon="🟩")
st.markdown("<h1 style='text-align: center;'>Ask away:</h1>", unsafe_allow_html=True)

# Read API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PINE_API_KEY = os.getenv("PINE_API_KEY")
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")

# LangSmith tracing settings: set via os.environ so LangChain actually sees them;
# plain module-level variables would be ignored
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "docu-help"

# Sidebar for model selection, API keys, and Pinecone settings
st.sidebar.title("Sidebar")
model_name = st.sidebar.radio("Choose a model:", ("gpt-3.5-turbo-1106", "gpt-4-0125-preview", "mixtral-fireworks", "mixtral-groq"))
openai_api_key2 = st.sidebar.text_input("Enter OpenAI Key: ")
groq_api_key = st.sidebar.text_input("Groq API Key: ")
pinecone_index_name = st.sidebar.text_input("Enter Pinecone Index Name:")
namespace_name = st.sidebar.text_input("Namespace:")

# Initialize session state variables if they don't exist
if 'generated' not in st.session_state:
    st.session_state['generated'] = []

if 'past' not in st.session_state:
    st.session_state['past'] = []

if 'messages' not in st.session_state:
    st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]

if 'total_cost' not in st.session_state:
    st.session_state['total_cost'] = 0.0

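# 'past' holds the user's prompts and 'generated' the assistant's replies, index-aligned;
# they drive both the chat redraw below and the conversation memory rebuilt per request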
def refresh_text():
    # Redraw the full chat history inside the response container
    with response_container:
        for i in range(len(st.session_state['past'])):
            try:
                user_message_content = st.session_state["past"][i]
                message = st.chat_message("user")
                message.write(user_message_content)
            except IndexError:
                print("Past error")

            try:
                ai_message_content = st.session_state["generated"][i]
                message = st.chat_message("assistant")
                message.write(ai_message_content)
            except IndexError:
                print("Generated Error")

# Generate a response to the user's prompt via retrieval-augmented generation
def generate_response(prompt):
    st.session_state['messages'].append({"role": "user", "content": prompt})
    embed = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=OPENAI_API_KEY)

    pc = Pinecone(api_key=PINE_API_KEY)
    index = pc.Index(pinecone_index_name)
    time.sleep(1)  # Ensure index is ready
    index.describe_index_stats()

    vectorstore = PineconeVectorStore(index, embed, "text", namespace=namespace_name)
    retriever = vectorstore.as_retriever()

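    # Note on the vector store above: "text" is the metadata field each record's raw
    # chunk content is read from (assumes documents were upserted under that key)
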
    template = """You are an expert software developer who specializes in APIs. Answer the user's question based only on the following context:
{context}

Chat History:
{chat_history}

Question: {question}
"""
    prompt_template = ChatPromptTemplate.from_template(template)

    if model_name == "mixtral-fireworks":
        chat_model = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
    elif model_name == "mixtral-groq":
        chat_model = ChatGroq(temperature=0, groq_api_key=groq_api_key, model_name="mixtral-8x7b-32768")
    else:
        chat_model = ChatOpenAI(temperature=0, model=model_name, openai_api_key=openai_api_key2)
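    # ChatFireworks reads FIREWORKS_API_KEY from the environment when no key is
    # passed explicitly, so no argument is needed for it here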

    memory = ConversationBufferMemory(
        return_messages=True, output_key="answer", input_key="question"
    )

    # Load the previous chat messages into memory
    for i in range(len(st.session_state['generated'])):
        # Strip any "Answer: " prefix so the model doesn't learn to emit it itself
        memory.save_context({"question": st.session_state["past"][i]}, {"answer": st.session_state["generated"][i].replace("Answer: ", "")})

    # Print the memory the model will use
    print(f"Memory: {memory.load_memory_variables({})}")

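    # load_memory_variables({}) returns {"history": [...]}; itemgetter("history")
    # below pulls that message list out for the prompt's chat_history slot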
    rag_chain = (
        RunnablePassthrough.assign(context=(lambda x: x["context"]), chat_history=lambda x: get_buffer_string(x["chat_history"]))
        | prompt_template
        | chat_model
        | StrOutputParser()
    )

    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough(), "chat_history": RunnableLambda(memory.load_memory_variables) | itemgetter("history")}
    ).assign(answer=rag_chain)

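    # When streamed, the parallel chain emits one single-key dict per chunk
    # ('context', 'question', 'chat_history', or incremental 'answer' text)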
    # Generator that yields the answer tokens as they arrive, then a trailing
    # list of the retrieved documents' sources
    def make_stream():
        sources = []
        st.session_state['generated'].append("Answer: ")
        yield st.session_state['generated'][-1]

        for chunk in rag_chain_with_source.stream(prompt):
            if 'answer' in chunk:
                st.session_state['generated'][-1] += chunk['answer']
                yield chunk['answer']
            elif 'context' in chunk:
                sources = [doc.metadata['source'] for doc in chunk['context']]

        sources_txt = "\n\nSources:\n" + "\n".join(sources)
        st.session_state['generated'][-1] += sources_txt
        yield sources_txt

    # Send the message as a stream using the generator above
    print("Running the response streamer...")
    with response_container:
        message = st.chat_message("assistant")
        my_generator = make_stream()
        message.write_stream(my_generator)

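    # generated[-1] now holds the full answer plus the appended sources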
    formatted_response = st.session_state['generated'][-1]

    # Non-streaming alternative kept for reference:
    # response = rag_chain_with_source.invoke(prompt)
    # sources = [doc.metadata['source'] for doc in response['context']]
    # answer = response['answer']  # Extracting the 'answer' part
    # formatted_response = f"Answer: {answer}\n\nSources:\n" + "\n".join(sources)

    st.session_state['messages'].append({"role": "assistant", "content": formatted_response})

    return formatted_response

# Containers for the chat history and the input box
response_container = st.container()
container = st.container()

# chat_input is used instead of a form because it stays pinned to the bottom
if prompt := st.chat_input("Ask a question..."):
    # Generate the response right here; gating it behind a separate user_input
    # check later caused an error
    st.session_state['past'].append(prompt)
    refresh_text()

    response = generate_response(prompt)
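
For reference, a minimal sketch of the environment this script expects before streamlit run app.py is launched; the names match the os.getenv() calls above, and every value is a placeholder:

# Hypothetical setup sketch; export these (or set them here) before launching the app
import os

os.environ["OPENAI_API_KEY"] = "<openai-key>"        # embeddings + OpenAI chat models
os.environ["PINE_API_KEY"] = "<pinecone-key>"        # Pinecone client
os.environ["FIREWORKS_API_KEY"] = "<fireworks-key>"  # only for the mixtral-fireworks option
os.environ["LANGCHAIN_API_KEY"] = "<langsmith-key>"  # only if LangSmith tracing is used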