hanoch.rahimi@gmail commited on
Commit
09df805
·
1 Parent(s): 7c3b5b3

ui changes

Browse files
Files changed (1) hide show
  1. app.py +56 -43
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import json
2
 
3
-
4
  from langchain.chains import RetrievalQA
5
  from langchain.embeddings.openai import OpenAIEmbeddings
6
  from langchain.prompts import PromptTemplate
7
  from langchain.vectorstores import Pinecone
8
  import pandas as pd
 
9
  from streamlit.runtime.state import session_state
10
  import openai
11
  import streamlit as st
@@ -45,6 +45,9 @@ country_geo = pd.read_csv(COUNTRIES_FN)
45
  st.session_state.index = utils.init_pinecone()
46
  st.session_state.db_search_results = []
47
 
 
 
 
48
  carddict = {
49
  "name": [],
50
  "company_id": [],
@@ -62,14 +65,21 @@ def init_models():
62
  retriever = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
63
  #reader = pipeline(tokenizer=model_name, model=model_name, task='question-answering')
64
  tokenizer = AutoTokenizer.from_pretrained(model_name)
65
- #vectorstore = Pinecone(st.session_state.index, embed.embed_query, text_field)
66
- # client.beta.assistants.create(
67
- # instructions=utils.assistant_instructions,
68
- # model="gpt-4-1106-preview",
69
- # tools=[{"type": "code_interpreter"}])
70
  return retriever, tokenizer#, vectorstore
71
 
72
 
 
 
 
 
 
 
 
 
 
 
 
73
  st.session_state.openai_client = oai.get_client()
74
  retriever, tokenizer = init_models()
75
  st.session_state.retriever = retriever
@@ -179,15 +189,15 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
179
  with content_container:
180
  for message in list(messages)[::-1]:
181
  if hasattr(message, 'role'):
182
- print(f"\n-----\nMessage: {message}\n")
183
- with st.chat_message(name = message.role):
184
- st.write(message.content[0].text.value)
185
- # if message.role == "assistant":
186
- # with st.chat_message(name = message.role, avatar = st.image("resources/raized_logo.png")):
187
- # st.write(message.content[0].text.value)
188
- # else:
189
- # with st.chat_message(name = message.role):
190
- # st.write(message.content[0].text.value)
191
  # st.session_state.messages.append({"role": "user", "content": query})
192
  # st.session_state.messages.append({"role": "system", "content": m_text})
193
 
@@ -296,6 +306,9 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
296
  use_container_width=True)
297
 
298
 
 
 
 
299
  def render_history():
300
  with st.session_state.history_container:
301
 
@@ -328,7 +341,7 @@ if utils.check_password():
328
  if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
329
  st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
330
  st.session_state.new_conversation = True
331
- st.session_state.messages = [{"role":"system", "content":"Hello. I'm your startups discovery assistant."}]
332
 
333
  st.markdown("<h1 style='text-align: center; color: red; position: relative; top: -3rem;'>Raized.AI – Startups discovery demo</h1>", unsafe_allow_html=True)
334
 
@@ -366,32 +379,8 @@ if utils.check_password():
366
  unsafe_allow_html=True
367
  )
368
 
369
- tab_search, tab_advanced = st.tabs(["Search", "Settings"])
370
-
371
- assistants = st.session_state.openai_client.beta.assistants.list(
372
- order="desc",
373
- limit="20",
374
- )
375
-
376
- with tab_advanced:
377
- #prompt_title = st.selectbox("Report Type", index = 0, options = utils.get_prompts(), on_change=on_prompt_selected, key="advanced_prompts_select", )
378
- #prompt_title_editable = st.text_input("Title", key="prompt_title_editable")
379
- report_type = st.selectbox(label="Response Type", options=["assistant", "standard", "guided", "company_list", "clustered"], index=0)
380
- #assistant_id = st.text_input(label="Assistant ID", key="assistant_id", value = "asst_NHoxEosVlemDY7y5TYg8ftku") #value="asst_fkZtxo127nxKOCcwrwznuCs2")
381
- assistant_id = st.selectbox(label="Assistant", options = [f"{a.id}|||{a.name}" for a in assistants])
382
- default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
383
- clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
384
- #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
385
- #prompt_delete = st.button("Del", on_click = utils.del_prompt(prompt_title_editable))
386
- #prompt_save = st.button("Save", on_click = utils.save_prompt(prompt_title_editable, prompt))
387
- #scrape_boost = st.number_input('Web to API content ratio', value=1.)
388
- top_k = st.number_input('# Top Results', value=20)
389
- is_debug = st.checkbox("Debug output", value = False, key="debug")
390
- openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
391
- index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
392
- liked_companies = st.text_input(label="liked companies", key='liked_companies')
393
- disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
394
-
395
 
396
 
397
  with tab_search:
@@ -423,10 +412,31 @@ if utils.check_password():
423
  #cluster = st.checkbox("Cluster the results", value = False, key = "cluster")
424
  #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
425
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
  if not "assistant_thread" in st.session_state:
427
  st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
428
 
429
- if query != "":
 
430
  # if report_type=="standard":
431
  # prompt = default_prompt
432
  # elif report_type=="clustered":
@@ -443,5 +453,8 @@ if utils.check_password():
443
  st.session_state.index_namespace = index_namespace
444
  st.session_state.region = region_selectbox
445
  st.session_state.country = countries_selectbox
 
446
  run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)
 
 
447
 
 
1
  import json
2
 
 
3
  from langchain.chains import RetrievalQA
4
  from langchain.embeddings.openai import OpenAIEmbeddings
5
  from langchain.prompts import PromptTemplate
6
  from langchain.vectorstores import Pinecone
7
  import pandas as pd
8
+ from PIL import Image
9
  from streamlit.runtime.state import session_state
10
  import openai
11
  import streamlit as st
 
45
  st.session_state.index = utils.init_pinecone()
46
  st.session_state.db_search_results = []
47
 
48
+ #st.image("resources/raized_logo.png")
49
+ assistant_avatar = Image.open('resources/raized_logo.png')
50
+
51
  carddict = {
52
  "name": [],
53
  "company_id": [],
 
65
  retriever = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
66
  #reader = pipeline(tokenizer=model_name, model=model_name, task='question-answering')
67
  tokenizer = AutoTokenizer.from_pretrained(model_name)
68
+
 
 
 
 
69
  return retriever, tokenizer#, vectorstore
70
 
71
 
72
+ @st.cache_resource
73
+ def init_openai():
74
+
75
+ assistants = st.session_state.openai_client.beta.assistants.list(
76
+ order="desc",
77
+ limit="20",
78
+ )
79
+ return assistants
80
+
81
+ assistants = init_openai()
82
+
83
  st.session_state.openai_client = oai.get_client()
84
  retriever, tokenizer = init_models()
85
  st.session_state.retriever = retriever
 
189
  with content_container:
190
  for message in list(messages)[::-1]:
191
  if hasattr(message, 'role'):
192
+ # print(f"\n-----\nMessage: {message}\n")
193
+ # with st.chat_message(name = message.role):
194
+ # st.write(message.content[0].text.value)
195
+ if message.role == "assistant":
196
+ with st.chat_message(name = message.role, avatar = assistant_avatar):
197
+ st.write(message.content[0].text.value)
198
+ else:
199
+ with st.chat_message(name = message.role):
200
+ st.write(message.content[0].text.value)
201
  # st.session_state.messages.append({"role": "user", "content": query})
202
  # st.session_state.messages.append({"role": "system", "content": m_text})
203
 
 
306
  use_container_width=True)
307
 
308
 
309
+ def query_sent():
310
+ st.session_state.user_query = ""
311
+
312
  def render_history():
313
  with st.session_state.history_container:
314
 
 
341
  if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
342
  st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
343
  st.session_state.new_conversation = True
344
+ st.session_state.messages = []
345
 
346
  st.markdown("<h1 style='text-align: center; color: red; position: relative; top: -3rem;'>Raized.AI – Startups discovery demo</h1>", unsafe_allow_html=True)
347
 
 
379
  unsafe_allow_html=True
380
  )
381
 
382
+ #tab_search, tab_advanced = st.tabs(["Search", "Settings"])
383
+ tab_search = st.container()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
 
386
  with tab_search:
 
412
  #cluster = st.checkbox("Cluster the results", value = False, key = "cluster")
413
  #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
414
 
415
+ tab_advanced = st.sidebar.expander("Settings")
416
+ with tab_advanced:
417
+ #prompt_title = st.selectbox("Report Type", index = 0, options = utils.get_prompts(), on_change=on_prompt_selected, key="advanced_prompts_select", )
418
+ #prompt_title_editable = st.text_input("Title", key="prompt_title_editable")
419
+ report_type = st.selectbox(label="Response Type", options=["assistant", "standard", "guided", "company_list", "clustered"], index=0)
420
+ #assistant_id = st.text_input(label="Assistant ID", key="assistant_id", value = "asst_NHoxEosVlemDY7y5TYg8ftku") #value="asst_fkZtxo127nxKOCcwrwznuCs2")
421
+ assistant_id = st.selectbox(label="Assistant", options = [f"{a.id}|||{a.name}" for a in assistants])
422
+ default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
423
+ clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
424
+ #prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
425
+ #prompt_delete = st.button("Del", on_click = utils.del_prompt(prompt_title_editable))
426
+ #prompt_save = st.button("Save", on_click = utils.save_prompt(prompt_title_editable, prompt))
427
+ #scrape_boost = st.number_input('Web to API content ratio', value=1.)
428
+ top_k = st.number_input('# Top Results', value=20)
429
+ is_debug = st.checkbox("Debug output", value = False, key="debug")
430
+ openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
431
+ index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
432
+ liked_companies = st.text_input(label="liked companies", key='liked_companies')
433
+ disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
434
+
435
  if not "assistant_thread" in st.session_state:
436
  st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
437
 
438
+
439
+ if query != "" and not st.session_state.new_conversation:
440
  # if report_type=="standard":
441
  # prompt = default_prompt
442
  # elif report_type=="clustered":
 
453
  st.session_state.index_namespace = index_namespace
454
  st.session_state.region = region_selectbox
455
  st.session_state.country = countries_selectbox
456
+ #st.session_state.user_query = ''
457
  run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)
458
+ else:
459
+ st.session_state.new_conversation = False
460