Spaces:
Sleeping
Sleeping
hanoch.rahimi@gmail
committed on
Commit
·
09df805
1
Parent(s):
7c3b5b3
ui changes
Browse files
app.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
import json
|
2 |
|
3 |
-
|
4 |
from langchain.chains import RetrievalQA
|
5 |
from langchain.embeddings.openai import OpenAIEmbeddings
|
6 |
from langchain.prompts import PromptTemplate
|
7 |
from langchain.vectorstores import Pinecone
|
8 |
import pandas as pd
|
|
|
9 |
from streamlit.runtime.state import session_state
|
10 |
import openai
|
11 |
import streamlit as st
|
@@ -45,6 +45,9 @@ country_geo = pd.read_csv(COUNTRIES_FN)
|
|
45 |
st.session_state.index = utils.init_pinecone()
|
46 |
st.session_state.db_search_results = []
|
47 |
|
|
|
|
|
|
|
48 |
carddict = {
|
49 |
"name": [],
|
50 |
"company_id": [],
|
@@ -62,14 +65,21 @@ def init_models():
|
|
62 |
retriever = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
63 |
#reader = pipeline(tokenizer=model_name, model=model_name, task='question-answering')
|
64 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
65 |
-
|
66 |
-
# client.beta.assistants.create(
|
67 |
-
# instructions=utils.assistant_instructions,
|
68 |
-
# model="gpt-4-1106-preview",
|
69 |
-
# tools=[{"type": "code_interpreter"}])
|
70 |
return retriever, tokenizer#, vectorstore
|
71 |
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
st.session_state.openai_client = oai.get_client()
|
74 |
retriever, tokenizer = init_models()
|
75 |
st.session_state.retriever = retriever
|
@@ -179,15 +189,15 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
|
|
179 |
with content_container:
|
180 |
for message in list(messages)[::-1]:
|
181 |
if hasattr(message, 'role'):
|
182 |
-
print(f"\n-----\nMessage: {message}\n")
|
183 |
-
with st.chat_message(name = message.role):
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
# st.session_state.messages.append({"role": "user", "content": query})
|
192 |
# st.session_state.messages.append({"role": "system", "content": m_text})
|
193 |
|
@@ -296,6 +306,9 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
|
|
296 |
use_container_width=True)
|
297 |
|
298 |
|
|
|
|
|
|
|
299 |
def render_history():
|
300 |
with st.session_state.history_container:
|
301 |
|
@@ -328,7 +341,7 @@ if utils.check_password():
|
|
328 |
if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
|
329 |
st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
|
330 |
st.session_state.new_conversation = True
|
331 |
-
st.session_state.messages = [
|
332 |
|
333 |
st.markdown("<h1 style='text-align: center; color: red; position: relative; top: -3rem;'>Raized.AI – Startups discovery demo</h1>", unsafe_allow_html=True)
|
334 |
|
@@ -366,32 +379,8 @@ if utils.check_password():
|
|
366 |
unsafe_allow_html=True
|
367 |
)
|
368 |
|
369 |
-
tab_search, tab_advanced = st.tabs(["Search", "Settings"])
|
370 |
-
|
371 |
-
assistants = st.session_state.openai_client.beta.assistants.list(
|
372 |
-
order="desc",
|
373 |
-
limit="20",
|
374 |
-
)
|
375 |
-
|
376 |
-
with tab_advanced:
|
377 |
-
#prompt_title = st.selectbox("Report Type", index = 0, options = utils.get_prompts(), on_change=on_prompt_selected, key="advanced_prompts_select", )
|
378 |
-
#prompt_title_editable = st.text_input("Title", key="prompt_title_editable")
|
379 |
-
report_type = st.selectbox(label="Response Type", options=["assistant", "standard", "guided", "company_list", "clustered"], index=0)
|
380 |
-
#assistant_id = st.text_input(label="Assistant ID", key="assistant_id", value = "asst_NHoxEosVlemDY7y5TYg8ftku") #value="asst_fkZtxo127nxKOCcwrwznuCs2")
|
381 |
-
assistant_id = st.selectbox(label="Assistant", options = [f"{a.id}|||{a.name}" for a in assistants])
|
382 |
-
default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
|
383 |
-
clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
|
384 |
-
#prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
|
385 |
-
#prompt_delete = st.button("Del", on_click = utils.del_prompt(prompt_title_editable))
|
386 |
-
#prompt_save = st.button("Save", on_click = utils.save_prompt(prompt_title_editable, prompt))
|
387 |
-
#scrape_boost = st.number_input('Web to API content ratio', value=1.)
|
388 |
-
top_k = st.number_input('# Top Results', value=20)
|
389 |
-
is_debug = st.checkbox("Debug output", value = False, key="debug")
|
390 |
-
openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
|
391 |
-
index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
|
392 |
-
liked_companies = st.text_input(label="liked companies", key='liked_companies')
|
393 |
-
disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
|
394 |
-
|
395 |
|
396 |
|
397 |
with tab_search:
|
@@ -423,10 +412,31 @@ if utils.check_password():
|
|
423 |
#cluster = st.checkbox("Cluster the results", value = False, key = "cluster")
|
424 |
#prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
|
425 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
if not "assistant_thread" in st.session_state:
|
427 |
st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
|
428 |
|
429 |
-
|
|
|
430 |
# if report_type=="standard":
|
431 |
# prompt = default_prompt
|
432 |
# elif report_type=="clustered":
|
@@ -443,5 +453,8 @@ if utils.check_password():
|
|
443 |
st.session_state.index_namespace = index_namespace
|
444 |
st.session_state.region = region_selectbox
|
445 |
st.session_state.country = countries_selectbox
|
|
|
446 |
run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)
|
|
|
|
|
447 |
|
|
|
1 |
import json
|
2 |
|
|
|
3 |
from langchain.chains import RetrievalQA
|
4 |
from langchain.embeddings.openai import OpenAIEmbeddings
|
5 |
from langchain.prompts import PromptTemplate
|
6 |
from langchain.vectorstores import Pinecone
|
7 |
import pandas as pd
|
8 |
+
from PIL import Image
|
9 |
from streamlit.runtime.state import session_state
|
10 |
import openai
|
11 |
import streamlit as st
|
|
|
45 |
st.session_state.index = utils.init_pinecone()
|
46 |
st.session_state.db_search_results = []
|
47 |
|
48 |
+
#st.image("resources/raized_logo.png")
|
49 |
+
assistant_avatar = Image.open('resources/raized_logo.png')
|
50 |
+
|
51 |
carddict = {
|
52 |
"name": [],
|
53 |
"company_id": [],
|
|
|
65 |
retriever = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
66 |
#reader = pipeline(tokenizer=model_name, model=model_name, task='question-answering')
|
67 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
68 |
+
|
|
|
|
|
|
|
|
|
69 |
return retriever, tokenizer#, vectorstore
|
70 |
|
71 |
|
72 |
+
@st.cache_resource
|
73 |
+
def init_openai():
|
74 |
+
|
75 |
+
assistants = st.session_state.openai_client.beta.assistants.list(
|
76 |
+
order="desc",
|
77 |
+
limit="20",
|
78 |
+
)
|
79 |
+
return assistants
|
80 |
+
|
81 |
+
assistants = init_openai()
|
82 |
+
|
83 |
st.session_state.openai_client = oai.get_client()
|
84 |
retriever, tokenizer = init_models()
|
85 |
st.session_state.retriever = retriever
|
|
|
189 |
with content_container:
|
190 |
for message in list(messages)[::-1]:
|
191 |
if hasattr(message, 'role'):
|
192 |
+
# print(f"\n-----\nMessage: {message}\n")
|
193 |
+
# with st.chat_message(name = message.role):
|
194 |
+
# st.write(message.content[0].text.value)
|
195 |
+
if message.role == "assistant":
|
196 |
+
with st.chat_message(name = message.role, avatar = assistant_avatar):
|
197 |
+
st.write(message.content[0].text.value)
|
198 |
+
else:
|
199 |
+
with st.chat_message(name = message.role):
|
200 |
+
st.write(message.content[0].text.value)
|
201 |
# st.session_state.messages.append({"role": "user", "content": query})
|
202 |
# st.session_state.messages.append({"role": "system", "content": m_text})
|
203 |
|
|
|
306 |
use_container_width=True)
|
307 |
|
308 |
|
309 |
+
def query_sent():
|
310 |
+
st.session_state.user_query = ""
|
311 |
+
|
312 |
def render_history():
|
313 |
with st.session_state.history_container:
|
314 |
|
|
|
341 |
if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
|
342 |
st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
|
343 |
st.session_state.new_conversation = True
|
344 |
+
st.session_state.messages = []
|
345 |
|
346 |
st.markdown("<h1 style='text-align: center; color: red; position: relative; top: -3rem;'>Raized.AI – Startups discovery demo</h1>", unsafe_allow_html=True)
|
347 |
|
|
|
379 |
unsafe_allow_html=True
|
380 |
)
|
381 |
|
382 |
+
#tab_search, tab_advanced = st.tabs(["Search", "Settings"])
|
383 |
+
tab_search = st.container()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
|
385 |
|
386 |
with tab_search:
|
|
|
412 |
#cluster = st.checkbox("Cluster the results", value = False, key = "cluster")
|
413 |
#prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
|
414 |
|
415 |
+
tab_advanced = st.sidebar.expander("Settings")
|
416 |
+
with tab_advanced:
|
417 |
+
#prompt_title = st.selectbox("Report Type", index = 0, options = utils.get_prompts(), on_change=on_prompt_selected, key="advanced_prompts_select", )
|
418 |
+
#prompt_title_editable = st.text_input("Title", key="prompt_title_editable")
|
419 |
+
report_type = st.selectbox(label="Response Type", options=["assistant", "standard", "guided", "company_list", "clustered"], index=0)
|
420 |
+
#assistant_id = st.text_input(label="Assistant ID", key="assistant_id", value = "asst_NHoxEosVlemDY7y5TYg8ftku") #value="asst_fkZtxo127nxKOCcwrwznuCs2")
|
421 |
+
assistant_id = st.selectbox(label="Assistant", options = [f"{a.id}|||{a.name}" for a in assistants])
|
422 |
+
default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
|
423 |
+
clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
|
424 |
+
#prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
|
425 |
+
#prompt_delete = st.button("Del", on_click = utils.del_prompt(prompt_title_editable))
|
426 |
+
#prompt_save = st.button("Save", on_click = utils.save_prompt(prompt_title_editable, prompt))
|
427 |
+
#scrape_boost = st.number_input('Web to API content ratio', value=1.)
|
428 |
+
top_k = st.number_input('# Top Results', value=20)
|
429 |
+
is_debug = st.checkbox("Debug output", value = False, key="debug")
|
430 |
+
openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
|
431 |
+
index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
|
432 |
+
liked_companies = st.text_input(label="liked companies", key='liked_companies')
|
433 |
+
disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
|
434 |
+
|
435 |
if not "assistant_thread" in st.session_state:
|
436 |
st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
|
437 |
|
438 |
+
|
439 |
+
if query != "" and not st.session_state.new_conversation:
|
440 |
# if report_type=="standard":
|
441 |
# prompt = default_prompt
|
442 |
# elif report_type=="clustered":
|
|
|
453 |
st.session_state.index_namespace = index_namespace
|
454 |
st.session_state.region = region_selectbox
|
455 |
st.session_state.country = countries_selectbox
|
456 |
+
#st.session_state.user_query = ''
|
457 |
run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)
|
458 |
+
else:
|
459 |
+
st.session_state.new_conversation = False
|
460 |
|