hanoch.rahimi@gmail commited on
Commit
0c14e18
·
1 Parent(s): 9f89884

fix history log

Browse files
Files changed (3) hide show
  1. app.py +29 -22
  2. openai_utils.py +27 -22
  3. utils.py +6 -1
app.py CHANGED
@@ -19,19 +19,21 @@ import openai_utils as oai
19
  from streamlit_extras.stylable_container import stylable_container
20
 
21
 
22
- OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # app.pinecone.io
23
- model_name = 'text-embedding-ada-002'
24
 
25
- embed = OpenAIEmbeddings(
26
- model=model_name,
27
- openai_api_key=OPENAI_API_KEY
28
- )
 
 
29
 
30
  st.set_page_config(
31
  layout="wide",
32
  initial_sidebar_state="collapsed",
33
  page_title="RaizedAI Startup Discovery Assistant",
34
- page_icon=":robot:"
35
  )
36
 
37
  COUNTRIES_FN="data/countries.csv"
@@ -69,6 +71,9 @@ st.session_state.openai_client = oai.get_client()
69
  retriever, tokenizer = init_models()
70
  st.session_state.retriever = retriever
71
 
 
 
 
72
  #st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
73
 
74
 
@@ -124,9 +129,6 @@ def card(company_id, name, description, score, data_type, region, country, metad
124
  #print(f" markdown for {company_id}\n{markdown}")
125
  return markdown
126
 
127
-
128
-
129
-
130
  def run_query(query, report_type, top_k , regions, countries, is_debug, index_namespace, openai_model):
131
 
132
  #Summarize the results
@@ -135,7 +137,7 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
135
  # Create a summarized report focusing on the top3 companies.
136
  # For every company find its uniqueness over the other companies. Use only information from the descriptions.
137
  # """
138
- col_content, col_sidepanel = st.columns([4, 1], gap="small")
139
  if report_type=="guided":
140
  prompt_txt = utils.query_finetune_prompt + """
141
  User query: {query}
@@ -143,7 +145,7 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
143
  prompt_template = PromptTemplate(template=prompt_txt, input_variables=["query"])
144
  prompt = prompt_template.format(query = query)
145
  m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message = False)
146
-
147
  print(f"Keywords: {m_text}")
148
  results = utils.search_index(m_text, top_k, regions, countries, retriever, index_namespace)
149
 
@@ -168,13 +170,20 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
168
  elif report_type=="assistant":
169
  #results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
170
  #descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
171
- m_text = oai.call_openai(query, engine=openai_model, temp=0, top_p=1.0)
172
  results = st.session_state.db_search_results
173
- with col_content:
174
- with st.chat_message("assistant"): #, "assets/raized_logo.webp"):
175
- st.write(m_text)
176
- st.session_state.messages.append({"role": "user", "content": query})
177
- st.session_state.messages.append({"role": "system", "content": m_text})
 
 
 
 
 
 
 
178
 
179
  else:
180
  st.session_state.new_conversation = False
@@ -259,14 +268,14 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
259
  pins = country_geo[country_geo['name'].isin(locations)].loc[:, ['latitude', 'longitude']]
260
 
261
  if len(pins)>0:
262
- with col_sidepanel:
263
  st.map(pins)
264
  #st.markdown(list_html, unsafe_allow_html=True)
265
 
266
  df = pd.DataFrame.from_dict(carddict, orient="columns")
267
 
268
  if len(df)>0:
269
- with col_content:
270
  st.dataframe(df,
271
  hide_index=False,
272
  column_config ={
@@ -419,8 +428,6 @@ if utils.check_password():
419
  # prompt = "guided"
420
  # else:
421
  # prompt = ""
422
- with st.chat_message("user"):
423
- st.write(query)
424
  #oai.start_conversation()
425
  i = assistant_id.index("|||")
426
  st.session_state.assistant_id = assistant_id[:i]
 
19
  from streamlit_extras.stylable_container import stylable_container
20
 
21
 
22
+ # OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # app.pinecone.io
23
+ #model_name = 'text-embedding-ada-002'
24
 
25
+ # embed = OpenAIEmbeddings(
26
+ # model=model_name,
27
+ # openai_api_key=OPENAI_API_KEY
28
+ # )
29
+
30
+ #"🤖",
31
 
32
  st.set_page_config(
33
  layout="wide",
34
  initial_sidebar_state="collapsed",
35
  page_title="RaizedAI Startup Discovery Assistant",
36
+ #page_icon=":robot:"
37
  )
38
 
39
  COUNTRIES_FN="data/countries.csv"
 
71
  retriever, tokenizer = init_models()
72
  st.session_state.retriever = retriever
73
 
74
+ # AVATAR_PATHS = {"assistant": st.image("resources/raized_logo.png"),
75
+ # "user": "👩‍⚖️"}
76
+
77
  #st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
78
 
79
 
 
129
  #print(f" markdown for {company_id}\n{markdown}")
130
  return markdown
131
 
 
 
 
132
  def run_query(query, report_type, top_k , regions, countries, is_debug, index_namespace, openai_model):
133
 
134
  #Summarize the results
 
137
  # Create a summarized report focusing on the top3 companies.
138
  # For every company find its uniqueness over the other companies. Use only information from the descriptions.
139
  # """
140
+ content_container = st.container() #, col_sidepanel = st.columns([4, 1], gap="small")
141
  if report_type=="guided":
142
  prompt_txt = utils.query_finetune_prompt + """
143
  User query: {query}
 
145
  prompt_template = PromptTemplate(template=prompt_txt, input_variables=["query"])
146
  prompt = prompt_template.format(query = query)
147
  m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message = False)
148
+
149
  print(f"Keywords: {m_text}")
150
  results = utils.search_index(m_text, top_k, regions, countries, retriever, index_namespace)
151
 
 
170
  elif report_type=="assistant":
171
  #results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
172
  #descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
173
+ messages = oai.call_assistant(query, engine=openai_model)
174
  results = st.session_state.db_search_results
175
+ with content_container:
176
+ for message in list(messages)[::-1]:
177
+ with st.chat_message(name = message.role):
178
+ st.write(message.content[0].text.value)
179
+ # if message.role == "assistant":
180
+ # with st.chat_message(name = message.role, avatar = st.image("resources/raized_logo.png")):
181
+ # st.write(message.content[0].text.value)
182
+ # else:
183
+ # with st.chat_message(name = message.role):
184
+ # st.write(message.content[0].text.value)
185
+ # st.session_state.messages.append({"role": "user", "content": query})
186
+ # st.session_state.messages.append({"role": "system", "content": m_text})
187
 
188
  else:
189
  st.session_state.new_conversation = False
 
268
  pins = country_geo[country_geo['name'].isin(locations)].loc[:, ['latitude', 'longitude']]
269
 
270
  if len(pins)>0:
271
+ with st.expander("Map view"):
272
  st.map(pins)
273
  #st.markdown(list_html, unsafe_allow_html=True)
274
 
275
  df = pd.DataFrame.from_dict(carddict, orient="columns")
276
 
277
  if len(df)>0:
278
+ with content_container:
279
  st.dataframe(df,
280
  hide_index=False,
281
  column_config ={
 
428
  # prompt = "guided"
429
  # else:
430
  # prompt = ""
 
 
431
  #oai.start_conversation()
432
  i = assistant_id.index("|||")
433
  st.session_state.assistant_id = assistant_id[:i]
openai_utils.py CHANGED
@@ -93,31 +93,36 @@ def wait_for_response(thread, run):
93
  print(f"Run status: {run_status.status}")
94
  return run_status
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
98
  if st.session_state.report_type=="assistant":
99
- try:
100
- thread = st.session_state.assistant_thread
101
- assistant_id = st.session_state.assistant_id
102
- message = st.session_state.openai_client.beta.threads.messages.create(
103
- thread.id,
104
- role="user",
105
- content=prompt,
106
- )
107
- run = st.session_state.openai_client.beta.threads.runs.create(
108
- thread_id=thread.id,
109
- assistant_id=assistant_id,
110
- )
111
- messages = wait_for_response(thread, run)
112
-
113
- print(f"====================\nOpen AI response\n {str(messages)[:1000]}\n====================\n")
114
- text = ""
115
- for message in messages:
116
- text = text + "\n" + message.content[0].text.value
117
- return text
118
- except Exception as e:
119
- #except openai.error.OpenAIError as e:
120
- print(f"An error occurred: {str(e)}")
121
  else:
122
  try:
123
  response = st.session_state.openai_client.chat.completions.create(
 
93
  print(f"Run status: {run_status.status}")
94
  return run_status
95
 
96
+ def call_assistant(query, engine="gpt-3.5-turbo"): #, temp=0, top_p=1.0, max_tokens=4048):
97
+ try:
98
+ thread = st.session_state.assistant_thread
99
+ assistant_id = st.session_state.assistant_id
100
+ message = st.session_state.openai_client.beta.threads.messages.create(
101
+ thread.id,
102
+ role="user",
103
+ content=query,
104
+ )
105
+ run = st.session_state.openai_client.beta.threads.runs.create(
106
+ thread_id=thread.id,
107
+ assistant_id=assistant_id,
108
+ )
109
+ messages = wait_for_response(thread, run)
110
+
111
+ print(f"====================\nOpen AI response\n {str(messages)[:1000]}\n====================\n")
112
+ return messages
113
+ # text = ""
114
+ # for message in messages:
115
+ # print(message)
116
+ # text = text + "\n" + message.content[0].text.value
117
+ # return text
118
+ except Exception as e:
119
+ #except openai.error.OpenAIError as e:
120
+ print(f"An error occurred: {str(e)}")
121
+
122
 
123
  def call_openai(prompt, engine="gpt-3.5-turbo", temp=0, top_p=1.0, max_tokens=4048):
124
  if st.session_state.report_type=="assistant":
125
+ raise Exception("use call_assistant instead of call_openai")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  else:
127
  try:
128
  response = st.session_state.openai_client.chat.completions.create(
utils.py CHANGED
@@ -15,7 +15,12 @@ import streamlit as st
15
  # response = gcp_client.access_secret_version(request={"name": version.name})
16
 
17
  def get_variable(name):
18
- return os.getenv(name, st.secrets[name])
 
 
 
 
 
19
 
20
  OPENAI_API_KEY = get_variable("OPENAI_API_KEY") # app.pinecone.io
21
  OPENAI_ORGANIZATION_ID = get_variable("OPENAI_ORGANIZATION_ID")
 
15
  # response = gcp_client.access_secret_version(request={"name": version.name})
16
 
17
  def get_variable(name):
18
+ res = ""
19
+ try:
20
+ res = os.getenv(name, st.secrets[name])
21
+ except Exception as e:
22
+ pass
23
+ return res
24
 
25
  OPENAI_API_KEY = get_variable("OPENAI_API_KEY") # app.pinecone.io
26
  OPENAI_ORGANIZATION_ID = get_variable("OPENAI_ORGANIZATION_ID")