hanoch.rahimi@gmail committed on
Commit
ba1f3e2
·
1 Parent(s): 4c4a1d7

increase number of retrieved items, prevent resending the same query

Browse files
Files changed (2) hide show
  1. app.py +11 -4
  2. openai_utils.py +6 -0
app.py CHANGED
@@ -185,6 +185,7 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
185
  #results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
186
  #descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
187
  messages = oai.call_assistant(query, engine=openai_model)
 
188
  results = st.session_state.db_search_results
189
  if not messages is None:
190
  with content_container:
@@ -292,6 +293,7 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
292
  df = pd.DataFrame.from_dict(carddict, orient="columns")
293
 
294
  if len(df)>0:
 
295
  with content_container:
296
  st.dataframe(df,
297
  hide_index=False,
@@ -305,6 +307,7 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
305
  "business_model": st.column_config.TextColumn(label="Business model")
306
  },
307
  use_container_width=True)
 
308
 
309
 
310
  def query_sent():
@@ -335,6 +338,11 @@ def render_history():
335
  if not 'submitted_query' in st.session_state:
336
  st.session_state.submitted_query = ''
337
 
 
 
 
 
 
338
  if utils.check_password():
339
 
340
  st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)
@@ -420,18 +428,18 @@ if utils.check_password():
420
  report_type = st.selectbox(label="Response Type", options=["assistant", "standard", "guided", "company_list", "clustered"], index=0)
421
  #assistant_id = st.text_input(label="Assistant ID", key="assistant_id", value = "asst_NHoxEosVlemDY7y5TYg8ftku") #value="asst_fkZtxo127nxKOCcwrwznuCs2")
422
  assistant_id = st.selectbox(label="Assistant", options = [f"{a.id}|||{a.name}" for a in assistants])
423
- default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
424
- clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
425
  #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
426
  #prompt_delete = st.button("Del", on_click = utils.del_prompt(prompt_title_editable))
427
  #prompt_save = st.button("Save", on_click = utils.save_prompt(prompt_title_editable, prompt))
428
  #scrape_boost = st.number_input('Web to API content ratio', value=1.)
429
- top_k = st.number_input('# Top Results', value=20)
430
  is_debug = st.checkbox("Debug output", value = False, key="debug")
431
  openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
432
  index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
433
  liked_companies = st.text_input(label="liked companies", key='liked_companies')
434
  disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
 
 
435
 
436
  if not "assistant_thread" in st.session_state:
437
  st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
@@ -454,7 +462,6 @@ if utils.check_password():
454
  st.session_state.index_namespace = index_namespace
455
  st.session_state.region = region_selectbox
456
  st.session_state.country = countries_selectbox
457
- #st.session_state.user_query = ''
458
  run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)
459
  else:
460
  st.session_state.new_conversation = False
 
185
  #results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
186
  #descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
187
  messages = oai.call_assistant(query, engine=openai_model)
188
+ st.session_state.messages = messages
189
  results = st.session_state.db_search_results
190
  if not messages is None:
191
  with content_container:
 
293
  df = pd.DataFrame.from_dict(carddict, orient="columns")
294
 
295
  if len(df)>0:
296
+ df.index += 1
297
  with content_container:
298
  st.dataframe(df,
299
  hide_index=False,
 
307
  "business_model": st.column_config.TextColumn(label="Business model")
308
  },
309
  use_container_width=True)
310
+ st.session_state.last_user_query = query
311
 
312
 
313
  def query_sent():
 
338
  if not 'submitted_query' in st.session_state:
339
  st.session_state.submitted_query = ''
340
 
341
+ if not 'messages' in st.session_state:
342
+ st.session_state.messages = []
343
+ if not 'last_user_query' in st.session_state:
344
+ st.session_state.last_user_query = ''
345
+
346
  if utils.check_password():
347
 
348
  st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)
 
428
  report_type = st.selectbox(label="Response Type", options=["assistant", "standard", "guided", "company_list", "clustered"], index=0)
429
  #assistant_id = st.text_input(label="Assistant ID", key="assistant_id", value = "asst_NHoxEosVlemDY7y5TYg8ftku") #value="asst_fkZtxo127nxKOCcwrwznuCs2")
430
  assistant_id = st.selectbox(label="Assistant", options = [f"{a.id}|||{a.name}" for a in assistants])
 
 
431
  #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
432
  #prompt_delete = st.button("Del", on_click = utils.del_prompt(prompt_title_editable))
433
  #prompt_save = st.button("Save", on_click = utils.save_prompt(prompt_title_editable, prompt))
434
  #scrape_boost = st.number_input('Web to API content ratio', value=1.)
435
+ top_k = st.number_input('# Top Results', value=30)
436
  is_debug = st.checkbox("Debug output", value = False, key="debug")
437
  openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
438
  index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
439
  liked_companies = st.text_input(label="liked companies", key='liked_companies')
440
  disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
441
+ default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
442
+ clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
443
 
444
  if not "assistant_thread" in st.session_state:
445
  st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
 
462
  st.session_state.index_namespace = index_namespace
463
  st.session_state.region = region_selectbox
464
  st.session_state.country = countries_selectbox
 
465
  run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)
466
  else:
467
  st.session_state.new_conversation = False
openai_utils.py CHANGED
@@ -65,6 +65,7 @@ def wait_for_response(thread, run):
65
  elif run_status.status == 'requires_action':
66
  required_action = run_status.required_action
67
  if required_action.type == 'submit_tool_outputs':
 
68
  outputs = {}
69
  for tool_call in required_action.submit_tool_outputs.tool_calls:
70
  if tool_call.function.name =="getListOfCompanies":
@@ -107,6 +108,11 @@ def wait_for_response(thread, run):
107
 
108
 
109
  def call_assistant(query, engine="gpt-3.5-turbo"): #, temp=0, top_p=1.0, max_tokens=4048):
 
 
 
 
 
110
  try:
111
  thread = st.session_state.assistant_thread
112
  assistant_id = st.session_state.assistant_id
 
65
  elif run_status.status == 'requires_action':
66
  required_action = run_status.required_action
67
  if required_action.type == 'submit_tool_outputs':
68
+ print(f"Requires tool outputs: {required_action}")
69
  outputs = {}
70
  for tool_call in required_action.submit_tool_outputs.tool_calls:
71
  if tool_call.function.name =="getListOfCompanies":
 
108
 
109
 
110
  def call_assistant(query, engine="gpt-3.5-turbo"): #, temp=0, top_p=1.0, max_tokens=4048):
111
+ #Prevent re sending the last message over and over
112
+ print(f"Last query {st.session_state.last_user_query}, current query {query}")
113
+ if st.session_state.last_user_query == query:
114
+ report_error(f"That query '{query}' was just sent. We don't send the same query twice in a row. ")
115
+ return st.session_state.messages
116
  try:
117
  thread = st.session_state.assistant_thread
118
  assistant_id = st.session_state.assistant_id