# semsearch/app.py
import json
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Pinecone
from streamlit.runtime.state import session_state
import openai
import pinecone
import streamlit as st
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import streamlit.components.v1 as components
import utils
import openai_utils as oai
PINECONE_KEY = st.secrets["PINECONE_API_KEY"] # get a free API key from app.pinecone.io
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] # from platform.openai.com
PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"] # Pinecone environment hosting the index
model_name = 'text-embedding-ada-002'
embed = OpenAIEmbeddings(
model=model_name,
openai_api_key=OPENAI_API_KEY
)
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
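# Cached with st.cache_resource so the Pinecone connection is created once and reused across Streamlit reruns.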
@st.cache_resource
def init_pinecone():
pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT) # get a free api key from app.pinecone.io
return pinecone.Index("company-description")
st.session_state.index = init_pinecone()
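# Cached loading of the sentence-transformers retriever and its tokenizer, used to embed search queries before hitting the Pinecone index.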
@st.cache_resource
def init_models():
#retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
model_name = "sentence-transformers/all-MiniLM-L6-v2"
retriever = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#reader = pipeline(tokenizer=model_name, model=model_name, task='question-answering')
tokenizer = AutoTokenizer.from_pretrained(model_name)
#vectorstore = Pinecone(st.session_state.index, embed.embed_query, text_field)
# client.beta.assistants.create(
# instructions=utils.assistant_instructions,
# model="gpt-4-1106-preview",
# tools=[{"type": "code_interpreter"}])
return retriever, tokenizer#, vectorstore
st.session_state.openai_client = oai.get_client()
retriever, tokenizer = init_models()
#st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}]
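# Render one search result as a Bootstrap grid row: company name and website link, description, country,
# and the structured fields extracted from the summary (customer problem, target customer, business model).
# In debug mode the data type and similarity score are appended as extra columns.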
def card(company_id, name, description, score, data_type, region, country, metadata, is_debug):
if 'Summary' in metadata:
description = metadata['Summary']
customer_problem = metadata.get('Customer problem', "")
target_customer = metadata.get('Target customer', "")
business_model = ""
if 'Business model' in metadata:
try:
business_model = metadata['Business model']
#business_model = json.loads(metadata['Business model'])
except Exception as e:
print(f"An error occurred: {str(e)}")
markdown = f"""
<div class="row align-items-start" style="padding-bottom:10px;">
<div class="col-md-8 col-sm-8">
<b>{name} (<a href='https://{company_id}'>website</a>).</b>
<p style="">{description}</p>
</div>
<div class="col-md-1 col-sm-1"><span>{country}</span></div>
<div class="col-md-1 col-sm-1"><span>{customer_problem}</span></div>
<div class="col-md-1 col-sm-1"><span>{target_customer}</span></div>
<div class="col-md-1 col-sm-1"><span>{business_model}</span></div>
"""
if is_debug:
markdown = markdown + f"""
<div class="col-md-1 col-sm-1" style="display:none;">
<button type='button' onclick="like_company({company_id});">Like</button>
<button type='button' onclick="dislike_company({company_id});">DisLike</button>
</div>
<div class="col-md-1 col-sm-1">
<span>{data_type}</span>
<span>[Score: {score}]</span>
</div>
"""
markdown = markdown + "</div>"
#print(f" markdown for {company_id}\n{markdown}")
return markdown
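# Query the Pinecone index in the given namespace, optionally restricting the matches with metadata filters on region and country.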
def index_query(xq, top_k, regions=[], countries=[], index_namespace="websummarized"):
#st.write(f"Regions: {regions}")
filters = []
if len(regions)>0:
filters.append({'region': {"$in": regions}})
if len(countries)>0:
filters.append({'country': {"$in": countries}})
if len(filters)==1:
filter = filters[0]
elif len(filters)>1:
filter = {"$and": filters}
else:
filter = {}
#st.write(filter)
xc = st.session_state.index.query(xq, namespace=index_namespace, top_k=top_k, filter = filter, include_metadata=True, include_vectors = False)
#xc = st.session_state.index.query(xq, top_k=top_k, include_metadata=True, include_vectors = True)
return xc
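# Embed the free-text query with the SentenceTransformer retriever, run the vector search,
# and normalize each match into a dict with id, name, description, score, and the parsed summary data.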
def search_index(query, top_k, regions, countries):
xq = retriever.encode([query]).tolist()
try:
xc = index_query(xq, top_k, regions, countries)
except Exception:
# force a reconnect to Pinecone and retry once
pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
st.session_state.index = pinecone.Index("company-description")
xc = index_query(xq, top_k, regions, countries)
results = []
for match in xc['matches']:
#answer = reader(question=query, context=match["metadata"]['context'])
score = match['score']
# if 'type' in match['metadata'] and match['metadata']['type']!='description-webcontent' and scrape_boost>0:
# score = score / scrape_boost
answer = {'score': score, 'metadata': match['metadata']}
if match['id'].endswith("_description"):
answer['id'] = match['id'][:-12]
elif match['id'].endswith("_webcontent"):
answer['id'] = match['id'][:-11]
else:
answer['id'] = match['id']
answer["name"] = match["metadata"]['company_name']
answer["description"] = match["metadata"]['description'] if "description" in match['metadata'] else ""
data = {}
if 'summary' in match['metadata']:
# the summary may be a JSON string with structured fields; fall back to the raw text
data = {"Summary": match["metadata"]["summary"]}
try:
data = json.loads(match["metadata"]["summary"])
except Exception:
pass
answer['data'] = data
results.append(answer)
return results
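# Run one search/report cycle for the user query. Depending on report_type the query is either rewritten
# into keywords first ("guided"), listed without an LLM summary ("company_list"), sent to the model as-is
# ("assistant"), or summarized with the default/clustering prompt. The chat history and the result cards
# are rendered at the end.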
def run_query(query, report_type, top_k , regions, countries, is_debug, index_namespace, openai_model):
#Summarize the results
# prompt_txt = """
# You are a venture capitalist analyst. Below are descriptions of startup companies that are relevant to the user with their relevancy score.
# Create a summarized report focusing on the top3 companies.
# For every company find its uniqueness over the other companies. Use only information from the descriptions.
# """
if report_type=="guided":
prompt_txt = utils.query_finetune_prompt + """
User query: {query}
"""
prompt_template = PromptTemplate(template=prompt_txt, input_variables=["query"])
prompt = prompt_template.format(query = query)
m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message = False)
print(f"Keywords: {m_text}")
results = search_index(m_text, top_k, regions, countries)
descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
ntokens = len(descriptions.split(" "))
print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
prompt_txt = utils.summarization_prompt + """
User query: {query}
Company descriptions: {descriptions}
"""
prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
prompt = prompt_template.format(descriptions = descriptions, query = query)
print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
st.write(m_text)
elif report_type=="company_list": # or st.session_state.new_conversation:
results = search_index(query, top_k, regions, countries)
descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
elif report_type=="assistant":
results = search_index(query, top_k, regions, countries)
descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
ntokens = len(descriptions.split(" "))
# prompt = utils.clustering_prompt if report_type=="clustered" else utils.default_prompt
# prompt_txt = prompt + """
# User query: {query}
# Company descriptions: {descriptions}
# """
# prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
# prompt = prompt_template.format(descriptions = descriptions, query = query)
#print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
prompt = query
m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
st.write(m_text)
st.session_state.messages.append({"role": "user", "content": query})
# keep only the text before the "-----" separator in the history; keep everything if it is absent
i = m_text.find("-----")
i = len(m_text) if i < 0 else i
st.session_state.messages.append({"role": "system", "content": m_text[:i]})
else:
st.session_state.new_conversation = False
results = search_index(query, top_k, regions, countries)
descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
ntokens = len(descriptions.split(" "))
print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
prompt = utils.clustering_prompt if report_type=="clustered" else utils.default_prompt
prompt_txt = prompt + """
User query: {query}
Company descriptions: {descriptions}
"""
prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
prompt = prompt_template.format(descriptions = descriptions, query = query)
print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
st.write(m_text)
st.session_state.messages.append({"role": "user", "content": query})
# keep only the text before the "-----" separator in the history; keep everything if it is absent
i = m_text.find("-----")
i = len(m_text) if i < 0 else i
st.session_state.messages.append({"role": "system", "content": m_text[:i]})
render_history()
# for message in st.session_state.messages:
# with st.chat_message(message["role"]):
# st.markdown(message["content"])
# print(f"History: \n {st.session_state.messages}")
sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)
names = []
# list_html = """
# <h2>Companies list</h2>
# <div class="container-fluid">
# <div class="row align-items-start" style="padding-bottom:10px;">
# <div class="col-md-8 col-sm-8">
# <span>Company</span>
# </div>
# <div class="col-md-1 col-sm-1">
# <span>Country</span>
# </div>
# <div class="col-md-1 col-sm-1">
# <span>Customer Problem</span>
# </div>
# <div class="col-md-1 col-sm-1">
# <span>Business Model</span>
# </div>
# <div class="col-md-1 col-sm-1">
# Actions
# </div>
# </div>
# """
list_html = "<div class='container-fluid'>"
for r in sorted_results:
company_name = r["name"]
if company_name in names:
continue
else:
names.append(company_name)
description = r["description"] #.replace(company_name, f"<mark>{company_name}</mark>")
if description is None or len(description.strip())<10:
continue
score = round(r["score"], 4)
data_type = r["metadata"]["type"] if "type" in r["metadata"] else ""
region = r["metadata"]["region"]
country = r["metadata"]["country"]
company_id = r["metadata"]["company_id"]
list_html = list_html + card(company_id, company_name, description, score, data_type, region, country, r['data'], is_debug)
list_html = list_html + '</div>'
st.markdown(list_html, unsafe_allow_html=True)
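# Show the conversation so far inside a scrollable HTML block and auto-scroll it to the latest message.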
def render_history():
with st.session_state.history_container:
s = f"""
<div style='overflow: hidden; padding:10px 0px;'>
<div id="chat_history" style='overflow-y: scroll;height: 200px;'>
"""
for m in st.session_state.messages:
#print(f"Printing message\t {m['role']}: {m['content']}")
s = s + f"<div class='chat_message'><b>{m['role']}</b>: {m['content']}</div>"
s = s + f"""</div>
</div>
<script>
var el = document.getElementById("chat_history");
el.scrollTop = el.scrollHeight;
</script>
"""
components.html(s, height=220)
#st.markdown(s, unsafe_allow_html=True)
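# Main app: everything below is gated behind the shared-password check in utils.check_password().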
if utils.check_password():
st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)
if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
st.session_state.new_conversation = True
st.session_state.messages = [{"role":"system", "content":"Hello. I'm your startups discovery assistant."}]
st.title("Raized- Startups discovery demo")
#st.write("Search for a company in free text. Describe the type of company you are looking for, the problem they solve and the solution they provide. You can also copy in the description of a similar company to kick off the search.")
st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)
with open("data/countries.json", "r") as f:
countries = json.load(f)['countries']
header = st.sidebar.markdown("Filters")
#new_conversation = st.sidebar.button("New Conversation", key="new_conversation")
countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
all_bizmodels = ('B2B', 'B2C', 'eCommerce & Marketplace', 'Manufacturing', 'SaaS', 'Advertising', 'Commission', 'Subscription')
bizmodel_selectbox = st.sidebar.multiselect("Business Model", all_bizmodels, default=all_bizmodels)
# with st.container():
# col1, col2, col3, col4 = st.columns(4)
# with col1:
# scrape_boost = st.number_input('webcontent_boost', value=2.)
# with col2:
# top_k = st.number_input('Top K Results', value=20)
# with col3:
# regions = st.number_input('Region', value=20)
# with col4:
# countries = st.number_input('Country', value=20)
st.markdown(
'''
<script>
function like_company(company_id) {
console.log("Company " + company_id + " Liked!");
}
function dislike_company(company_id) {
console.log("Company " + company_id + " Disliked!");
}
</script>
<style>
.sidebar .sidebar-content {
width: 375px;
}
</style>
''',
unsafe_allow_html=True
)
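# Two tabs: "Search" holds the chat history and the query box, "Settings" exposes the prompts,
# model choice, namespace, and other retrieval parameters.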
tab_search, tab_advanced = st.tabs(["Search", "Settings"])
with tab_advanced:
#prompt_title = st.selectbox("Report Type", index = 0, options = utils.get_prompts(), on_change=on_prompt_selected, key="advanced_prompts_select", )
#prompt_title_editable = st.text_input("Title", key="prompt_title_editable")
report_type = st.selectbox(label="Response Type", options=["assistant", "standard", "guided", "company_list", "clustered"], index=0)
assistant_id = st.text_input(label="Assistant ID", key="assistant_id", value="asst_fkZtxo127nxKOCcwrwznuCs2")
default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
#prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
#prompt_delete = st.button("Del", on_click = utils.del_prompt(prompt_title_editable))
#prompt_save = st.button("Save", on_click = utils.save_prompt(prompt_title_editable, prompt))
#scrape_boost = st.number_input('Web to API content ratio', value=1.)
top_k = st.number_input('# Top Results', value=20)
is_debug = st.checkbox("Debug output", value = False, key="debug")
openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
liked_companies = st.text_input(label="liked companies", key='liked_companies')
disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
with tab_search:
#report_type = st.multiselect("Report Type", utils.get_prompts(), key="search_prompts_multiselect")
st.session_state.history_container = st.container()
query = st.text_input("Search!", "")
#cluster = st.checkbox("Cluster the results", value = False, key = "cluster")
#prompt_new = st.button("New", on_click = _prompt(prompt_title, prompt))
if query != "":
# if report_type=="standard":
# prompt = default_prompt
# elif report_type=="clustered":
# prompt = clustering_prompt
# elif report_type=="guided":
# prompt = "guided"
# else:
# prompt = ""
oai.start_conversation()
#st.session_state.assistant_id = assistant_id
st.session_state.report_type = report_type
run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)