Spaces:
Running
Running
hanoch.rahimi@gmail
commited on
Commit
·
0c14e18
1
Parent(s):
9f89884
fix history log
Browse files- app.py +29 -22
- openai_utils.py +27 -22
- utils.py +6 -1
app.py
CHANGED
@@ -19,19 +19,21 @@ import openai_utils as oai
|
|
19 |
from streamlit_extras.stylable_container import stylable_container
|
20 |
|
21 |
|
22 |
-
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # app.pinecone.io
|
23 |
-
model_name = 'text-embedding-ada-002'
|
24 |
|
25 |
-
embed = OpenAIEmbeddings(
|
26 |
-
|
27 |
-
|
28 |
-
)
|
|
|
|
|
29 |
|
30 |
st.set_page_config(
|
31 |
layout="wide",
|
32 |
initial_sidebar_state="collapsed",
|
33 |
page_title="RaizedAI Startup Discovery Assistant",
|
34 |
-
page_icon=":robot:"
|
35 |
)
|
36 |
|
37 |
COUNTRIES_FN="data/countries.csv"
|
@@ -69,6 +71,9 @@ st.session_state.openai_client = oai.get_client()
|
|
69 |
retriever, tokenizer = init_models()
|
70 |
st.session_state.retriever = retriever
|
71 |
|
|
|
|
|
|
|
72 |
#st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
|
73 |
|
74 |
|
@@ -124,9 +129,6 @@ def card(company_id, name, description, score, data_type, region, country, metad
|
|
124 |
#print(f" markdown for {company_id}\n{markdown}")
|
125 |
return markdown
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
def run_query(query, report_type, top_k , regions, countries, is_debug, index_namespace, openai_model):
|
131 |
|
132 |
#Summarize the results
|
@@ -135,7 +137,7 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
|
|
135 |
# Create a summarized report focusing on the top3 companies.
|
136 |
# For every company find its uniqueness over the other companies. Use only information from the descriptions.
|
137 |
# """
|
138 |
-
|
139 |
if report_type=="guided":
|
140 |
prompt_txt = utils.query_finetune_prompt + """
|
141 |
User query: {query}
|
@@ -143,7 +145,7 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
|
|
143 |
prompt_template = PromptTemplate(template=prompt_txt, input_variables=["query"])
|
144 |
prompt = prompt_template.format(query = query)
|
145 |
m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message = False)
|
146 |
-
|
147 |
print(f"Keywords: {m_text}")
|
148 |
results = utils.search_index(m_text, top_k, regions, countries, retriever, index_namespace)
|
149 |
|
@@ -168,13 +170,20 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
|
|
168 |
elif report_type=="assistant":
|
169 |
#results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
|
170 |
#descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
|
171 |
-
|
172 |
results = st.session_state.db_search_results
|
173 |
-
with
|
174 |
-
|
175 |
-
st.
|
176 |
-
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
else:
|
180 |
st.session_state.new_conversation = False
|
@@ -259,14 +268,14 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
|
|
259 |
pins = country_geo[country_geo['name'].isin(locations)].loc[:, ['latitude', 'longitude']]
|
260 |
|
261 |
if len(pins)>0:
|
262 |
-
with
|
263 |
st.map(pins)
|
264 |
#st.markdown(list_html, unsafe_allow_html=True)
|
265 |
|
266 |
df = pd.DataFrame.from_dict(carddict, orient="columns")
|
267 |
|
268 |
if len(df)>0:
|
269 |
-
with
|
270 |
st.dataframe(df,
|
271 |
hide_index=False,
|
272 |
column_config ={
|
@@ -419,8 +428,6 @@ if utils.check_password():
|
|
419 |
# prompt = "guided"
|
420 |
# else:
|
421 |
# prompt = ""
|
422 |
-
with st.chat_message("user"):
|
423 |
-
st.write(query)
|
424 |
#oai.start_conversation()
|
425 |
i = assistant_id.index("|||")
|
426 |
st.session_state.assistant_id = assistant_id[:i]
|
|
|
19 |
from streamlit_extras.stylable_container import stylable_container
|
20 |
|
21 |
|
22 |
+
# OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # app.pinecone.io
|
23 |
+
#model_name = 'text-embedding-ada-002'
|
24 |
|
25 |
+
# embed = OpenAIEmbeddings(
|
26 |
+
# model=model_name,
|
27 |
+
# openai_api_key=OPENAI_API_KEY
|
28 |
+
# )
|
29 |
+
|
30 |
+
#"🤖",
|
31 |
|
32 |
st.set_page_config(
|
33 |
layout="wide",
|
34 |
initial_sidebar_state="collapsed",
|
35 |
page_title="RaizedAI Startup Discovery Assistant",
|
36 |
+
#page_icon=":robot:"
|
37 |
)
|
38 |
|
39 |
COUNTRIES_FN="data/countries.csv"
|
|
|
71 |
retriever, tokenizer = init_models()
|
72 |
st.session_state.retriever = retriever
|
73 |
|
74 |
+
# AVATAR_PATHS = {"assistant": st.image("resources/raized_logo.png"),
|
75 |
+
# "user": "👩⚖️"}
|
76 |
+
|
77 |
#st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
|
78 |
|
79 |
|
|
|
129 |
#print(f" markdown for {company_id}\n{markdown}")
|
130 |
return markdown
|
131 |
|
|
|
|
|
|
|
132 |
def run_query(query, report_type, top_k , regions, countries, is_debug, index_namespace, openai_model):
|
133 |
|
134 |
#Summarize the results
|
|
|
137 |
# Create a summarized report focusing on the top3 companies.
|
138 |
# For every company find its uniqueness over the other companies. Use only information from the descriptions.
|
139 |
# """
|
140 |
+
content_container = st.container() #, col_sidepanel = st.columns([4, 1], gap="small")
|
141 |
if report_type=="guided":
|
142 |
prompt_txt = utils.query_finetune_prompt + """
|
143 |
User query: {query}
|
|
|
145 |
prompt_template = PromptTemplate(template=prompt_txt, input_variables=["query"])
|
146 |
prompt = prompt_template.format(query = query)
|
147 |
m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message = False)
|
148 |
+
|
149 |
print(f"Keywords: {m_text}")
|
150 |
results = utils.search_index(m_text, top_k, regions, countries, retriever, index_namespace)
|
151 |
|
|
|
170 |
elif report_type=="assistant":
|
171 |
#results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
|
172 |
#descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
|
173 |
+
messages = oai.call_assistant(query, engine=openai_model)
|
174 |
results = st.session_state.db_search_results
|
175 |
+
with content_container:
|
176 |
+
for message in list(messages)[::-1]:
|
177 |
+
with st.chat_message(name = message.role):
|
178 |
+
st.write(message.content[0].text.value)
|
179 |
+
# if message.role == "assistant":
|
180 |
+
# with st.chat_message(name = message.role, avatar = st.image("resources/raized_logo.png")):
|
181 |
+
# st.write(message.content[0].text.value)
|
182 |
+
# else:
|
183 |
+
# with st.chat_message(name = message.role):
|
184 |
+
# st.write(message.content[0].text.value)
|
185 |
+
# st.session_state.messages.append({"role": "user", "content": query})
|
186 |
+
# st.session_state.messages.append({"role": "system", "content": m_text})
|
187 |
|
188 |
else:
|
189 |
st.session_state.new_conversation = False
|
|
|
268 |
pins = country_geo[country_geo['name'].isin(locations)].loc[:, ['latitude', 'longitude']]
|
269 |
|
270 |
if len(pins)>0:
|
271 |
+
with st.expander("Map view"):
|
272 |
st.map(pins)
|
273 |
#st.markdown(list_html, unsafe_allow_html=True)
|
274 |
|
275 |
df = pd.DataFrame.from_dict(carddict, orient="columns")
|
276 |
|
277 |
if len(df)>0:
|
278 |
+
with content_container:
|
279 |
st.dataframe(df,
|
280 |
hide_index=False,
|
281 |
column_config ={
|
|
|
428 |
# prompt = "guided"
|
429 |
# else:
|
430 |
# prompt = ""
|
|
|
|
|
431 |
#oai.start_conversation()
|
432 |
i = assistant_id.index("|||")
|
433 |
st.session_state.assistant_id = assistant_id[:i]
|
openai_utils.py
CHANGED
@@ -93,31 +93,36 @@ def wait_for_response(thread, run):
|
|
93 |
print(f"Run status: {run_status.status}")
|
94 |
return run_status
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
|
98 |
if st.session_state.report_type=="assistant":
|
99 |
-
|
100 |
-
thread = st.session_state.assistant_thread
|
101 |
-
assistant_id = st.session_state.assistant_id
|
102 |
-
message = st.session_state.openai_client.beta.threads.messages.create(
|
103 |
-
thread.id,
|
104 |
-
role="user",
|
105 |
-
content=prompt,
|
106 |
-
)
|
107 |
-
run = st.session_state.openai_client.beta.threads.runs.create(
|
108 |
-
thread_id=thread.id,
|
109 |
-
assistant_id=assistant_id,
|
110 |
-
)
|
111 |
-
messages = wait_for_response(thread, run)
|
112 |
-
|
113 |
-
print(f"====================\nOpen AI response\n {str(messages)[:1000]}\n====================\n")
|
114 |
-
text = ""
|
115 |
-
for message in messages:
|
116 |
-
text = text + "\n" + message.content[0].text.value
|
117 |
-
return text
|
118 |
-
except Exception as e:
|
119 |
-
#except openai.error.OpenAIError as e:
|
120 |
-
print(f"An error occurred: {str(e)}")
|
121 |
else:
|
122 |
try:
|
123 |
response = st.session_state.openai_client.chat.completions.create(
|
|
|
93 |
print(f"Run status: {run_status.status}")
|
94 |
return run_status
|
95 |
|
96 |
+
def call_assistant(query, engine="gpt-3.5-turbo"): #, temp=0, top_p=1.0, max_tokens=4048):
|
97 |
+
try:
|
98 |
+
thread = st.session_state.assistant_thread
|
99 |
+
assistant_id = st.session_state.assistant_id
|
100 |
+
message = st.session_state.openai_client.beta.threads.messages.create(
|
101 |
+
thread.id,
|
102 |
+
role="user",
|
103 |
+
content=query,
|
104 |
+
)
|
105 |
+
run = st.session_state.openai_client.beta.threads.runs.create(
|
106 |
+
thread_id=thread.id,
|
107 |
+
assistant_id=assistant_id,
|
108 |
+
)
|
109 |
+
messages = wait_for_response(thread, run)
|
110 |
+
|
111 |
+
print(f"====================\nOpen AI response\n {str(messages)[:1000]}\n====================\n")
|
112 |
+
return messages
|
113 |
+
# text = ""
|
114 |
+
# for message in messages:
|
115 |
+
# print(message)
|
116 |
+
# text = text + "\n" + message.content[0].text.value
|
117 |
+
# return text
|
118 |
+
except Exception as e:
|
119 |
+
#except openai.error.OpenAIError as e:
|
120 |
+
print(f"An error occurred: {str(e)}")
|
121 |
+
|
122 |
|
123 |
def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
|
124 |
if st.session_state.report_type=="assistant":
|
125 |
+
raise Exception("use call_assistant instead of call_openai")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
else:
|
127 |
try:
|
128 |
response = st.session_state.openai_client.chat.completions.create(
|
utils.py
CHANGED
@@ -15,7 +15,12 @@ import streamlit as st
|
|
15 |
# response = gcp_client.access_secret_version(request={"name": version.name})
|
16 |
|
17 |
def get_variable(name):
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
OPENAI_API_KEY = get_variable("OPENAI_API_KEY") # app.pinecone.io
|
21 |
OPENAI_ORGANIZATION_ID = get_variable("OPENAI_ORGANIZATION_ID")
|
|
|
15 |
# response = gcp_client.access_secret_version(request={"name": version.name})
|
16 |
|
17 |
def get_variable(name):
|
18 |
+
res = ""
|
19 |
+
try:
|
20 |
+
res = os.getenv(name, st.secrets[name])
|
21 |
+
except Exception as e:
|
22 |
+
pass
|
23 |
+
return res
|
24 |
|
25 |
OPENAI_API_KEY = get_variable("OPENAI_API_KEY") # app.pinecone.io
|
26 |
OPENAI_ORGANIZATION_ID = get_variable("OPENAI_ORGANIZATION_ID")
|