semsearch / app.py
hanoch.rahimi@gmail
increase number of retrieved items, prevent resending the same query
ba1f3e2
raw
history blame
20.9 kB
import json
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Pinecone
import pandas as pd
from PIL import Image
from streamlit.runtime.state import session_state
import openai
import streamlit as st
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import streamlit.components.v1 as components
# Page-level configuration: wide layout, sidebar collapsed on load, and the
# browser-tab title.
# NOTE(review): this call sits before the project-local imports that follow it,
# presumably because Streamlit expects set_page_config to be the first
# Streamlit command executed on a page — confirm before moving it.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_title="RaizedAI Startup Discovery Assistant",
    #page_icon=":robot:"
)
import utils
import openai_utils as oai
from streamlit_extras.stylable_container import stylable_container
# OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # app.pinecone.io
#model_name = 'text-embedding-ada-002'
# embed = OpenAIEmbeddings(
# model=model_name,
# openai_api_key=OPENAI_API_KEY
# )
#"🤖",
# CSV whose 'name', 'latitude' and 'longitude' columns are used to place the
# result countries on the map.
COUNTRIES_FN = "data/countries.csv"
country_geo = pd.read_csv(COUNTRIES_FN)

# Both entries are (re)assigned on every script run.
st.session_state.index = utils.init_pinecone()
st.session_state.db_search_results = []

#st.image("resources/raized_logo.png")
assistant_avatar = Image.open('resources/raized_logo.png')

# Column-oriented accumulator for the results table: card() appends one value
# to every list for each company it renders.
carddict = {
    column: []
    for column in (
        "name",
        "company_id",
        "description",
        "country",
        "customer_problem",
        "target_customer",
        "business_model",
    )
}
@st.cache_resource
def init_models():
    """Load the sentence-embedding model and the matching tokenizer.

    The function is decorated with ``st.cache_resource``.

    Returns:
        A ``(retriever, tokenizer)`` tuple, both created from the same
        ``sentence-transformers/all-MiniLM-L6-v2`` checkpoint.
    """
    checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
    # Alternatives that were tried earlier:
    #retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
    #reader = pipeline(tokenizer=model_name, model=model_name, task='question-answering')
    embedder = SentenceTransformer(checkpoint)
    return embedder, AutoTokenizer.from_pretrained(checkpoint)#, vectorstore
@st.cache_resource
def init_openai():
    """Return the assistants listed by the OpenAI client stored in session state.

    Reads ``st.session_state.openai_client``, so that entry must be set before
    this function is called. The function is decorated with
    ``st.cache_resource``.
    """
    client = st.session_state.openai_client
    # NOTE(review): ``limit`` is passed as the string "20", not an integer —
    # confirm the client accepts that.
    return client.beta.assistants.list(order="desc", limit="20")
# The client has to be stored first: init_openai() reads it from session state.
st.session_state["openai_client"] = oai.get_client()
assistants = init_openai()

# run_query() uses the module-level ``retriever``; it is also kept in session state.
retriever, tokenizer = init_models()
st.session_state["retriever"] = retriever
# AVATAR_PATHS = {"assistant": st.image("resources/raized_logo.png"),
# "user": "👩‍⚖️"}
#st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
def card(company_id, name, description, score, data_type, region, country, metadata, is_debug):
    """Build the HTML row for one company and record it in the global ``carddict``.

    Args:
        company_id: Company identifier; used as the host part of the website
            link (``https://<company_id>``).
        name: Company display name.
        description: Fallback description, used only when ``metadata`` has no
            ``'Summary'`` entry.
        score: Relevance score, shown only in debug mode.
        data_type: Record type label, shown only in debug mode.
        region: Accepted for interface compatibility; not used here.
        country: Country shown next to the company.
        metadata: Mapping with the optional keys ``'Summary'``,
            ``'Customer problem'``, ``'Target customer'`` and
            ``'Business model'``.
        is_debug: When true, the like/dislike buttons and the score are
            appended to the row.

    Returns:
        The HTML markup of the row, as a string.

    Side effects:
        Appends one value to every column list of the module-level ``carddict``.
    """
    # The summary stored in the metadata wins over the raw description.
    description = metadata.get('Summary', description)
    customer_problem = metadata.get('Customer problem', "")
    target_customer = metadata.get('Target customer', "")
    business_model = metadata.get('Business model', "")

    # A plain string must be kept whole: ", ".join("B2B") would yield
    # "B, 2, B". Only real collections are joined.
    if isinstance(business_model, str):
        business_model_str = business_model
    else:
        business_model_str = ", ".join(str(item) for item in business_model)

    company_id_url = f"https://{company_id}"

    markdown = f"""
<div class="row align-items-start" style="padding-bottom:10px;">
<div class="col-md-8 col-sm-8">
<b>{name} (<a href='{company_id_url}'>website</a>).</b>
<p style="">{description}</p>
</div>
<div class="col-md-1 col-sm-1"><span>{country}</span></div>
<div class="col-md-1 col-sm-1"><span>{customer_problem}</span></div>
<div class="col-md-1 col-sm-1"><span>{target_customer}</span></div>
<div class="col-md-1 col-sm-1"><span>{business_model_str}</span></div>
"""

    carddict["name"].append(name)
    carddict["company_id"].append(company_id_url)
    carddict["description"].append(description)
    carddict["country"].append(country)
    carddict["customer_problem"].append(customer_problem)
    carddict["target_customer"].append(target_customer)
    carddict["business_model"].append(business_model_str)

    if is_debug:
        # The id is quoted so the handlers receive a JavaScript string; an
        # unquoted domain-like id is parsed as an expression and fails.
        markdown += f"""
<div class="col-md-1 col-sm-1" style="display:none;">
<button type='button' onclick="like_company('{company_id}');">Like</button>
<button type='button' onclick="dislike_company('{company_id}');">DisLike</button>
</div>
<div class="col-md-1 col-sm-1">
<span>{data_type}</span>
<span>[Score: {score}]</span>
</div>
"""
    #print(f" markdown for {company_id}\n{markdown}")
    # Close the outer "row" div opened at the top of the markup.
    return markdown + "</div>"
def _describe_companies(results, limit=20):
    """Join the summaries of the first ``limit`` search hits into prompt text."""
    return "\n".join(
        f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n"
        for res in results[:limit]
        if 'Summary' in res['data']
    )


def _render_assistant_thread(messages, container):
    """Write every message that has a role into ``container`` as a chat bubble."""
    with container:
        # The list is replayed in reverse of the order it was returned in
        # (presumably newest-first, which would make this chronological).
        for message in reversed(list(messages)):
            if not hasattr(message, 'role'):
                continue
            # Only assistant messages get the custom avatar.
            extra = {"avatar": assistant_avatar} if message.role == "assistant" else {}
            with st.chat_message(name=message.role, **extra):
                st.write(message.content[0].text.value)


def run_query(query, report_type, top_k, regions, countries, is_debug, index_namespace, openai_model):
    """Answer a user query and render the matching companies.

    ``report_type`` selects the flow:

    * ``"guided"`` - the model first rewrites the query into keywords, the
      index is searched with them, and the hits are summarised.
    * ``"company_list"`` - the index is searched and only the table is shown.
    * ``"assistant"`` - the query goes to ``oai.call_assistant``; results are
      read from ``st.session_state.db_search_results`` and the returned
      messages are rendered as a chat.
    * anything else - the index is searched and the hits are summarised with
      the clustering prompt (``"clustered"``) or the default prompt.

    Afterwards the results are de-duplicated by company name, an optional map
    of their countries is drawn, and the table collected in ``carddict`` is
    displayed.

    Args:
        query: Free-text user query.
        report_type: Flow selector, see above.
        top_k: Passed to ``utils.search_index``.
        regions: Region filter passed to ``utils.search_index``.
        countries: Country filter passed to ``utils.search_index``.
        is_debug: Forwarded to ``card`` to include debug markup.
        index_namespace: Passed to ``utils.search_index``.
        openai_model: Model/engine name for the OpenAI helpers.

    Side effects:
        Writes Streamlit elements and updates ``st.session_state.messages``
        and ``st.session_state.last_user_query``.
    """
    content_container = st.container()

    if report_type == "guided":
        # Step 1: turn the free-text query into search keywords.
        prompt_txt = utils.query_finetune_prompt + "\nUser query: {query}\n"
        prompt_template = PromptTemplate(template=prompt_txt, input_variables=["query"])
        prompt = prompt_template.format(query=query)
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message=False)
        print(f"Keywords: {m_text}")

        # Step 2: search with the keywords and summarise the hits.
        results = utils.search_index(m_text, top_k, regions, countries, retriever, index_namespace)
        descriptions = _describe_companies(results)
        ntokens = len(descriptions.split(" "))
        print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
        prompt_txt = utils.summarization_prompt + "\nUser query: {query}\nCompany descriptions: {descriptions}\n"
        prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
        prompt = prompt_template.format(descriptions=descriptions, query=query)
        print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
        # Bare expression kept on purpose: presumably shown through Streamlit's
        # "magic" rendering of standalone expressions — confirm before removing.
        m_text
    elif report_type == "company_list":
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
    elif report_type == "assistant":
        messages = oai.call_assistant(query, engine=openai_model)
        st.session_state.messages = messages
        # Read only after the assistant call has returned.
        results = st.session_state.db_search_results
        if messages is not None:
            _render_assistant_thread(messages, content_container)
    else:
        st.session_state.new_conversation = False
        results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
        descriptions = _describe_companies(results)
        ntokens = len(descriptions.split(" "))
        print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
        prompt = utils.clustering_prompt if report_type == "clustered" else utils.default_prompt
        prompt_txt = prompt + "\nUser query: {query}\nCompany descriptions: {descriptions}\n"
        prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
        prompt = prompt_template.format(descriptions=descriptions, query=query)
        print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
        # Bare expression kept on purpose (see the note in the guided branch).
        m_text

        st.session_state.messages.append({"role": "user", "content": query})
        # Only the text before the "-----" marker is stored in the history.
        # NOTE(review): when the marker is absent this stores an empty string,
        # not the whole reply — confirm that is intended.
        i = m_text.find("-----")
        i = 0 if i < 0 else i
        st.session_state.messages.append({"role": "system", "content": m_text[:i]})

    sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)

    seen_names = set()  # set membership instead of a list scan per result
    rows = []
    locations = set()
    for r in sorted_results:
        company_name = r["name"]
        if company_name in seen_names:
            continue
        seen_names.add(company_name)

        description = r["description"]
        # Entries without a meaningful description are skipped.
        if description is None or len(description.strip()) < 10:
            continue

        score = round(r["score"], 4)
        data_type = r["metadata"]["type"] if "type" in r["metadata"] else ""
        region = r["metadata"]["region"]
        country = r["metadata"]["country"]
        company_id = r["metadata"]["company_id"]
        locations.add(country)
        # card() also appends the company to the global ``carddict``.
        rows.append(card(company_id, company_name, description, score, data_type, region, country, r['data'], is_debug))

    # The HTML list is assembled but not rendered; the table below is shown instead.
    list_html = "<div class='container-fluid'>" + "".join(rows) + '</div>'
    #st.markdown(list_html, unsafe_allow_html=True)

    pins = country_geo[country_geo['name'].isin(locations)].loc[:, ['latitude', 'longitude']]
    if len(pins) > 0:
        with st.expander("Map view"):
            st.map(pins)

    df = pd.DataFrame.from_dict(carddict, orient="columns")
    if len(df) > 0:
        df.index += 1  # number the rows from 1
        with content_container:
            st.dataframe(
                df,
                hide_index=False,
                column_config={
                    "name": st.column_config.TextColumn("Name"),
                    "company_id": st.column_config.LinkColumn("Link"),
                    "description": st.column_config.TextColumn("Description"),
                    "country": st.column_config.TextColumn("Country", width="small"),
                    "customer_problem": st.column_config.TextColumn("Customer problem"),
                    "target_customer": st.column_config.TextColumn(label="Target customer", width="small"),
                    "business_model": st.column_config.TextColumn(label="Business model"),
                },
                use_container_width=True,
            )

    st.session_state.last_user_query = query
def query_sent():
    """Reset the ``user_query`` session-state entry to an empty string.

    ``user_query`` is the key of the query text input.
    """
    st.session_state["user_query"] = ""
def render_history():
    """Render ``st.session_state.messages`` as a scrollable HTML chat log.

    Each message is expected to be a mapping with ``'role'`` and ``'content'``
    entries. The log is written into ``st.session_state.history_container``
    through ``components.html`` and scrolled to its last entry.
    """
    from html import escape  # local import keeps the block self-contained

    # Message text includes what the user typed, so it is escaped before being
    # placed into markup. A single join replaces repeated concatenation.
    entries = "".join(
        f"<div class='chat_message'><b>{escape(str(m['role']))}</b>: {escape(str(m['content']))}</div>"
        for m in st.session_state.messages
    )
    s = (
        """
<div style='overflow: hidden; padding:10px 0px;'>
<div id="chat_history" style='overflow-y: scroll;height: 200px;'>
"""
        + entries
        + """</div>
</div>
<script>
var el = document.getElementById("chat_history");
el.scrollTop = el.scrollHeight;
</script>
"""
    )
    with st.session_state.history_container:
        components.html(s, height=220)
        #st.markdown(s, unsafe_allow_html=True)
# Seed the session-state entries that the rest of the script reads.
# ``new_conversation`` is included because the query handling further down
# reads it, and nothing else guarantees it exists on a fresh session.
for _state_key, _state_default in (
    ('submitted_query', ''),
    ('messages', []),
    ('last_user_query', ''),
    ('new_conversation', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Main page: everything below is rendered only once the password check passes.
if utils.check_password():
    st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)

    # "New Conversation" starts a fresh assistant thread and clears the history.
    if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
        st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
        st.session_state.new_conversation = True
        st.session_state.messages = []

    st.markdown("<h1 style='text-align: center; color: red; position: relative; top: -3rem;'>Raized.AI – Startups discovery demo</h1>", unsafe_allow_html=True)

    # The package specifier is "bootstrap@4.0.0"; the integrity hash below is
    # the one published for that release. The markup is kept flush-left
    # because Markdown treats indented text as a code block.
    st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)

    # Explicit encoding so country names decode the same on every platform.
    with open("data/countries.json", "r", encoding="utf-8") as f:
        countries = json.load(f)['countries']

    # --- Sidebar filters ---------------------------------------------------
    header = st.sidebar.markdown("Filters")
    countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
    all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
    region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
    all_bizmodels = ('B2B', 'B2C', 'eCommerce & Marketplace', 'Manufacturing', 'SaaS', 'Advertising', 'Commission', 'Subscription')
    bizmodel_selectbox = st.sidebar.multiselect("Business Model", all_bizmodels, default=all_bizmodels)

    # This is a plain (non-f) string, so CSS braces are written singly; doubled
    # braces would be emitted literally and invalidate the rule.
    st.markdown(
        '''
<script>
function like_company(company_id) {
console.log("Company " + company_id + " Liked!");
}
function dislike_company(company_id) {
console.log("Company " + company_id + " Disliked!");
}
</script>
<style>
.sidebar .sidebar-content {
width: 375px;
}
</style>
''',
        unsafe_allow_html=True
    )

    # --- Query input ---------------------------------------------------------
    tab_search = st.container()
    with tab_search:
        st.session_state.history_container = st.container()
        # The CSS pins the text input to the bottom of the page.
        with stylable_container(
            key="query_panel",
            css_styles="""
.stTextInput {
position: fixed;
bottom: 0px;
background: white;
z-index: 1000;
padding-bottom: 2rem;
padding-left: 1rem;
padding-right: 1rem;
padding-top: 1rem;
border-top: 1px solid whitesmoke;
height: 8rem;
border-radius: 8px 8px 8px 8px;
box-shadow: 0 -3px 3px whitesmoke;
}
""",
        ):
            query = st.text_input(key="user_query",
                                  label="Enter your query",
                                  placeholder="Tell me what startups you are looking for", label_visibility="collapsed")

    # --- Advanced settings ---------------------------------------------------
    tab_advanced = st.sidebar.expander("Settings")
    with tab_advanced:
        report_type = st.selectbox(label="Response Type", options=["assistant", "standard", "guided", "company_list", "clustered"], index=0)
        # Each option is "<assistant id>|||<assistant name>".
        assistant_id = st.selectbox(label="Assistant", options=[f"{a.id}|||{a.name}" for a in assistants])
        top_k = st.number_input('# Top Results', value=30)
        is_debug = st.checkbox("Debug output", value=False, key="debug")
        openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
        index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
        liked_companies = st.text_input(label="liked companies", key='liked_companies')
        disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
        default_prompt = st.text_area("Default Prompt", value=utils.default_prompt, height=400, key="advanced_default_prompt_content")
        clustering_prompt = st.text_area("Clustering Prompt", value=utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")

    if "assistant_thread" not in st.session_state:
        st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()

    # A run that has just started a new conversation is not answered; the flag
    # is cleared instead. The flag is read with a default because it may not
    # have been set yet in this session.
    if query != "" and not st.session_state.get("new_conversation", False):
        # Keep only the id part of "<id>|||<name>".
        st.session_state.assistant_id = assistant_id.partition("|||")[0]
        st.session_state.report_type = report_type
        st.session_state.top_k = top_k
        st.session_state.index_namespace = index_namespace
        st.session_state.region = region_selectbox
        st.session_state.country = countries_selectbox
        run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)
    else:
        st.session_state.new_conversation = False