Create app.py
app.py (new file, +176 lines)
import os
import streamlit as st
from streamlit_chat import message
from langchain_openai import OpenAIEmbeddings
from pinecone import Pinecone
import time
from langchain_pinecone.vectorstores import Pinecone as PineconeVectorStore
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain.memory import ConversationBufferMemory
from langchain_core.runnables import RunnableLambda
from operator import itemgetter

# Streamlit App Configuration
st.set_page_config(page_title="Docu-Help", page_icon="🟩")
st.markdown("<h1 style='text-align: center;'>Ask away:</h1>", unsafe_allow_html=True)

# Read API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PINE_API_KEY = os.getenv("PINE_API_KEY")
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
# Note: these are plain module-level variables; LangSmith tracing reads the
# LANGCHAIN_* settings from os.environ, so export them there if tracing is needed.
LANGCHAIN_TRACING_V2 = 'true'
LANGCHAIN_ENDPOINT = "https://api.smith.langchain.com"
LANGCHAIN_PROJECT = "docu-help"

# Sidebar for model selection and Pinecone index name input
st.sidebar.title("Sidebar")
model_name = st.sidebar.radio("Choose a model:", ("gpt-3.5-turbo-1106", "gpt-4-0125-preview", "mixtral-fireworks", "mixtral-groq"))
openai_api_key2 = st.sidebar.text_input("Enter OpenAI Key: ")
groq_api_key = st.sidebar.text_input("Groq API Key: ")
pinecone_index_name = st.sidebar.text_input("Enter Pinecone Index Name:")
namespace_name = st.sidebar.text_input("Namespace:")

# Initialize session state variables if they don't exist
if 'generated' not in st.session_state:
    st.session_state['generated'] = []

if 'past' not in st.session_state:
    st.session_state['past'] = []

if 'messages' not in st.session_state:
    st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]

if 'total_cost' not in st.session_state:
    st.session_state['total_cost'] = 0.0

def refresh_text():
    # Re-render the full chat history inside the response container.
    with response_container:
        for i in range(len(st.session_state['past'])):
            try:
                user_message_content = st.session_state["past"][i]
                message = st.chat_message("user")
                message.write(user_message_content)
            except Exception:
                print("Past error")

            try:
                # 'generated' can briefly be shorter than 'past' while a reply is still pending.
                ai_message_content = st.session_state["generated"][i]
                message = st.chat_message("assistant")
                message.write(ai_message_content)
            except Exception:
                print("Generated Error")

# Function to generate a response using App 2's functionality
def generate_response(prompt):
    st.session_state['messages'].append({"role": "user", "content": prompt})
    embed = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=OPENAI_API_KEY)

    pc = Pinecone(api_key=PINE_API_KEY)
    index = pc.Index(pinecone_index_name)
    time.sleep(1)  # Ensure index is ready
    index.describe_index_stats()

    vectorstore = PineconeVectorStore(index, embed, "text", namespace=namespace_name)
    retriever = vectorstore.as_retriever()

    template = """You are an expert software developer who specializes in APIs. Answer the user's question based only on the following context:
{context}

Chat History:
{chat_history}

Question: {question}
"""
    prompt_template = ChatPromptTemplate.from_template(template)

    if model_name == "mixtral-fireworks":
        chat_model = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
    elif model_name == "mixtral-groq":
        chat_model = ChatGroq(temperature=0, groq_api_key=groq_api_key, model_name="mixtral-8x7b-32768")
    else:
        chat_model = ChatOpenAI(temperature=0, model=model_name, openai_api_key=openai_api_key2)

    memory = ConversationBufferMemory(
        return_messages=True, output_key="answer", input_key="question"
    )

    # Loading the previous chat messages into memory
    for i in range(len(st.session_state['generated'])):
        # Replaced "Answer: " with "" to stop the model from learning to add "Answer: " to the beginning by itself
        memory.save_context({"question": st.session_state["past"][i]}, {"answer": st.session_state["generated"][i].replace("Answer: ", "")})

    # Prints the memory that the model will be using
    print(f"Memory: {memory.load_memory_variables({})}")

    rag_chain = (
        RunnablePassthrough.assign(context=(lambda x: x["context"]), chat_history=lambda x: get_buffer_string(x["chat_history"]))
        | prompt_template
        | chat_model
        | StrOutputParser()
    )

    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough(), "chat_history": RunnableLambda(memory.load_memory_variables) | itemgetter("history")}
    ).assign(answer=rag_chain)

    # Function that extracts the individual tokens from the output of the model
    def make_stream():
        sources = []
        st.session_state['generated'].append("Answer: ")
        yield st.session_state['generated'][-1]

        for chunk in rag_chain_with_source.stream(prompt):
            # Each streamed chunk is a dict that normally carries a single key
            # ('context', 'question', 'chat_history', or 'answer').
            if list(chunk.keys())[0] == 'answer':
                st.session_state['generated'][-1] += chunk['answer']
                yield chunk['answer']

            elif list(chunk.keys())[0] == 'context':
                # sources = chunk['context']
                sources = [doc.metadata['source'] for doc in chunk['context']]

                sources_txt = "\n\nSources:\n" + "\n".join(sources)
                st.session_state['generated'][-1] += sources_txt
                yield sources_txt

    # Sending the message as a stream using the function above
    print("Running the response streamer...")
    with response_container:
        message = st.chat_message("assistant")
        my_generator = make_stream()
        message.write_stream(my_generator)

    formatted_response = st.session_state['generated'][-1]

    #response = rag_chain_with_source.invoke(prompt)
    #sources = [doc.metadata['source'] for doc in response['context']]
    #answer = response['answer'] # Extracting the 'answer' part
    #formatted_response = f"Answer: {answer}\n\nSources:\n" + "\n".join(sources)

    st.session_state['messages'].append({"role": "assistant", "content": formatted_response})

    return formatted_response

# Container for chat history and text box
response_container = st.container()
container = st.container()

# Implementing chat input as opposed to a form because chat_input stays locked at the bottom
if prompt := st.chat_input("Ask a question..."):
    # I moved response here because, for some reason, I get an error if I only have an if statement for user_input later...
    st.session_state['past'].append(prompt)
    refresh_text()

    response = generate_response(prompt)