Spaces:
Running
Running
import json | |
from langchain.chains import RetrievalQA | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.prompts import PromptTemplate | |
from langchain.vectorstores import Pinecone | |
from streamlit.runtime.state import session_state | |
import openai | |
import pinecone | |
import streamlit as st | |
from transformers import AutoTokenizer | |
from sentence_transformers import SentenceTransformer | |
import streamlit.components.v1 as components | |
import utils | |
import openai_utils as oai | |
# Service credentials are read from Streamlit's secrets store.
PINECONE_KEY = st.secrets["PINECONE_API_KEY"]  # app.pinecone.io
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"]  # app.pinecone.io

# OpenAI embedding client (LangChain wrapper) built with the key above.
model_name = 'text-embedding-ada-002'
embed = OpenAIEmbeddings(model=model_name, openai_api_key=OPENAI_API_KEY)

# Wide layout with the sidebar (filters) collapsed on first load.
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
def init_pinecone(index_name="company-description"):
    """Initialise the Pinecone client and return a handle to the company index.

    Args:
        index_name: Name of the Pinecone index to open. Defaults to
            "company-description", the same name the retry path in
            ``search_index`` uses (the previous literal here was the typo
            "dompany-description").

    Returns:
        A ``pinecone.Index`` handle for ``index_name``.
    """
    # A free API key can be obtained from app.pinecone.io.
    pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
    return pinecone.Index(index_name)
# Keep the index handle in session state; index_query() reads it from there.
st.session_state["index"] = init_pinecone()
def init_models():
    """Load the sentence-embedding model and the tokenizer of the same name.

    Returns:
        tuple: ``(retriever, tokenizer)`` where ``retriever`` is a
        ``SentenceTransformer`` and ``tokenizer`` comes from
        ``AutoTokenizer.from_pretrained``; both are loaded for
        "sentence-transformers/all-MiniLM-L6-v2".
    """
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    # Same evaluation order as before: the encoder first, then the tokenizer.
    return SentenceTransformer(model_name), AutoTokenizer.from_pretrained(model_name)
# OpenAI client for this session, plus the embedding model used by search_index().
st.session_state["openai_client"] = oai.get_client()
retriever, tokenizer = init_models()
#st.session_state.messages = [{"role":"system", "content":"You are an assistant who helps users find startups to invest in."}] | |
def card(company_id, name, description, score, data_type, region, country, metadata, is_debug):
    """Render one company as a Bootstrap ``row`` HTML fragment.

    Args:
        company_id: Company identifier; it is used as the host part of the
            "website" link (``https://<company_id>``).
        name: Company display name.
        description: Fallback description, used when ``metadata`` has no
            'Summary' entry.
        score: Match score; shown only in debug mode.
        data_type: Record type label; shown only in debug mode.
        region: Accepted for interface compatibility; not rendered.
        country: Country shown in its own column.
        metadata: Mapping that may contain 'Summary', 'Customer problem',
            'Target customer' and 'Business model'. Anything that is not a
            dict is treated as empty.
        is_debug: When true, append the like/dislike buttons and the
            type/score column.

    Returns:
        str: The HTML for the row. Every interpolated value is HTML-escaped
        because the caller renders the result with ``unsafe_allow_html=True``.
    """
    import html  # function-scope import keeps this block self-contained

    def esc(value):
        # Escape for both element content and quoted attribute values.
        return html.escape(str(value), quote=True)

    if not isinstance(metadata, dict):
        metadata = {}

    # A structured summary, when present, replaces the raw description.
    if 'Summary' in metadata:
        description = metadata['Summary']
    customer_problem = metadata.get('Customer problem', "")
    target_customer = metadata.get('Target customer', "")
    business_model = metadata.get('Business model', "")

    # The HTML lines are kept unindented on purpose: the fragment is passed
    # through a Markdown renderer by the caller.
    parts = [
        '\n<div class="row align-items-start" style="padding-bottom:10px;">\n',
        '<div class="col-md-8 col-sm-8">\n',
        f"<b>{esc(name)} (<a href='https://{esc(company_id)}'>website</a>).</b>\n",
        f'<p style="">{esc(description)}</p>\n',
        '</div>\n',
        f'<div class="col-md-1 col-sm-1"><span>{esc(country)}</span></div>\n',
        f'<div class="col-md-1 col-sm-1"><span>{esc(customer_problem)}</span></div>\n',
        f'<div class="col-md-1 col-sm-1"><span>{esc(target_customer)}</span></div>\n',
        f'<div class="col-md-1 col-sm-1"><span>{esc(business_model)}</span></div>\n',
    ]

    if is_debug:
        # Pass the id to JavaScript as a quoted string literal (JSON-encoded,
        # then attribute-escaped); unquoted, an id such as a domain name is
        # not a valid JS expression.
        js_id = esc(json.dumps(str(company_id)))
        parts += [
            '<div class="col-md-1 col-sm-1" style="display:none;">\n',
            f'<button type=\'button\' onclick="like_company({js_id});">Like</button>\n',
            f'<button type=\'button\' onclick="dislike_company({js_id});">DisLike</button>\n',
            '</div>\n',
            '<div class="col-md-1 col-sm-1">\n',
            f'<span>{esc(data_type)}</span>\n',
            f'<span>[Score: {esc(score)}]</span>\n',
            '</div>\n',
        ]

    # Close the outer "row" div.
    parts.append("</div>")
    return "".join(parts)
def index_query(xq, top_k, regions=None, countries=None, index_namespace="websummarized"):
    """Query the Pinecone index held in ``st.session_state.index``.

    Args:
        xq: Query vector(s), passed to the index unchanged.
        top_k: Maximum number of matches to return. (Previously this argument
            was ignored and 20 was always requested.)
        regions: Optional collection of region names; when non-empty, matches
            are restricted to records whose 'region' is one of them.
        countries: Optional collection of country names; when non-empty,
            matches are restricted to records whose 'country' is one of them.
        index_namespace: Namespace to search.

    Returns:
        The response of ``Index.query`` with metadata included and vectors
        excluded.
    """
    # Build one metadata clause per active filter.
    clauses = []
    if regions:
        clauses.append({'region': {"$in": list(regions)}})
    if countries:
        clauses.append({'country': {"$in": list(countries)}})

    if not clauses:
        metadata_filter = {}
    elif len(clauses) == 1:
        metadata_filter = clauses[0]
    else:
        metadata_filter = {"$and": clauses}

    return st.session_state.index.query(
        xq,
        namespace=index_namespace,
        top_k=int(top_k),
        filter=metadata_filter,
        include_metadata=True,
        include_vectors=False,
    )
def search_index(query, top_k, regions, countries, index_namespace="websummarized"):
    """Embed ``query``, search the company index and normalise the matches.

    Args:
        query: Free-text search string.
        top_k: Number of matches requested from the index.
        regions: Region names to filter on (may be empty).
        countries: Country names to filter on (may be empty).
        index_namespace: Namespace to search; used for both the first attempt
            and the retry.

    Returns:
        list[dict]: One entry per match with the keys 'score', 'metadata',
        'id' (with a trailing "_description"/"_webcontent" removed), 'name',
        'description' and 'data'. 'data' is the JSON-decoded 'summary' when
        that decodes to a dict, ``{"Summary": <raw summary>}`` when a summary
        exists but is not a JSON object, and ``{}`` when there is no summary.
    """
    xq = retriever.encode([query]).tolist()
    try:
        xc = index_query(xq, top_k, regions, countries, index_namespace)
    except Exception as err:
        # Force a reload of the Pinecone connection and retry once; a second
        # failure propagates to the caller.
        print(f"Index query failed ({err}); re-initialising Pinecone and retrying.")
        pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
        st.session_state.index = pinecone.Index("company-description")
        xc = index_query(xq, top_k, regions, countries, index_namespace)

    results = []
    for match in xc['matches']:
        metadata = match['metadata']

        # Record ids may carry a suffix naming the source of the text.
        match_id = match['id']
        for suffix in ("_description", "_webcontent"):
            if match_id.endswith(suffix):
                match_id = match_id[:-len(suffix)]
                break

        # The summary is usually a JSON object; fall back to the raw text
        # when it is not, and to an empty mapping when it is missing.
        raw_summary = metadata.get('summary')
        data = {}
        if raw_summary is not None:
            data = {"Summary": raw_summary}
            try:
                parsed = json.loads(raw_summary)
            except (TypeError, ValueError):
                parsed = None
            if isinstance(parsed, dict):
                data = parsed

        results.append({
            'score': match['score'],
            'metadata': metadata,
            'id': match_id,
            'name': metadata.get('company_name', match_id),
            'description': metadata.get('description', ""),
            'data': data,
        })
    return results
def _describe_companies(results, limit=20):
    """Join the 'Summary' of the first ``limit`` results into one prompt-ready text."""
    return "\n".join(
        f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n"
        for res in results[:limit]
        if isinstance(res['data'], dict) and 'Summary' in res['data']
    )


def _build_report_prompt(instructions, query, descriptions):
    """Append the user query and company descriptions to ``instructions`` and format it."""
    template = instructions + "\nUser query: {query}\nCompany descriptions: {descriptions}\n"
    prompt_template = PromptTemplate(template=template, input_variables=["descriptions", "query"])
    return prompt_template.format(descriptions=descriptions, query=query)


def _remember_exchange(query, m_text):
    """Append the user query and the part of the reply before "-----" to the chat history."""
    st.session_state.messages.append({"role": "user", "content": query})
    cut = m_text.find("-----")
    # NOTE(review): when the reply has no "-----" marker the stored reply is
    # empty (cut == 0). Behaviour kept as-is; confirm whether the whole reply
    # should be stored in that case.
    cut = 0 if cut < 0 else cut
    st.session_state.messages.append({"role": "system", "content": m_text[:cut]})


def _render_company_cards(results, is_debug):
    """Render the results, best score first, one card per distinct company name."""
    seen_names = set()
    cards = []
    for r in sorted(results, key=lambda x: x['score'], reverse=True):
        company_name = r["name"]
        if company_name in seen_names:
            continue
        seen_names.add(company_name)

        # Skip entries without a meaningful description.
        description = r["description"]
        if description is None or len(description.strip()) < 10:
            continue

        metadata = r["metadata"]
        cards.append(card(
            metadata.get("company_id", ""),
            company_name,
            description,
            round(r["score"], 4),
            metadata.get("type", ""),
            metadata.get("region", ""),
            metadata.get("country", ""),
            r['data'],
            is_debug,
        ))
    list_html = "<div class='container-fluid'>" + "".join(cards) + '</div>'
    st.markdown(list_html, unsafe_allow_html=True)


def run_query(query, report_type, top_k, regions, countries, is_debug, index_namespace, openai_model):
    """Answer ``query`` in the selected mode, then render history and company cards.

    Modes (``report_type``):
        "guided": ask the model for search keywords, search with them, then
            ask for a summary of the matches and display it.
        "company_list": search only; no model call.
        "assistant": search, send the raw query to the model, display the
            reply and record the exchange in the chat history.
        anything else ("standard", "clustered", ...): search, build a report
            prompt from ``utils.clustering_prompt`` (for "clustered") or
            ``utils.default_prompt``, display the reply and record the
            exchange.

    Args:
        query: The user's search text.
        report_type: One of the modes above.
        top_k: Number of results to request from the index.
        regions: Region filter values.
        countries: Country filter values.
        is_debug: Forwarded to ``card`` to show debug columns.
        index_namespace: Accepted for interface compatibility.
            NOTE(review): it is not forwarded to ``search_index``, so the
            search uses that function's default namespace.
        openai_model: Model name passed to ``oai.call_openai`` as ``engine``.
    """
    if report_type == "guided":
        # Step 1: turn the free-text request into search keywords.
        keyword_template = PromptTemplate(
            template=utils.query_finetune_prompt + "\nUser query: {query}\n",
            input_variables=["query"],
        )
        prompt = keyword_template.format(query=query)
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message=False)
        print(f"Keywords: {m_text}")

        # Step 2: search with the keywords and summarise the matches.
        results = search_index(m_text, top_k, regions, countries)
        descriptions = _describe_companies(results)
        ntokens = len(descriptions.split(" "))
        print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
        prompt = _build_report_prompt(utils.summarization_prompt, query, descriptions)
        print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
        # Explicit form of the bare `m_text` expression that Streamlit magic
        # rendered via st.write.
        st.write(m_text)
    elif report_type == "company_list":
        results = search_index(query, top_k, regions, countries)
    elif report_type == "assistant":
        results = search_index(query, top_k, regions, countries)
        # The model receives the raw query only; the matches are shown as cards.
        m_text = oai.call_openai(query, engine=openai_model, temp=0, top_p=1.0)
        st.write(m_text)
        _remember_exchange(query, m_text)
    else:
        st.session_state.new_conversation = False
        results = search_index(query, top_k, regions, countries)
        descriptions = _describe_companies(results)
        ntokens = len(descriptions.split(" "))
        print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
        instructions = utils.clustering_prompt if report_type == "clustered" else utils.default_prompt
        prompt = _build_report_prompt(instructions, query, descriptions)
        print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
        m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
        st.write(m_text)
        _remember_exchange(query, m_text)

    render_history()
    _render_company_cards(results, is_debug)
def render_history():
    """Draw the chat history inside ``st.session_state.history_container``.

    Each entry of ``st.session_state.messages`` becomes one ``chat_message``
    div inside a 200px scrollable box; an inline script scrolls the box to
    its bottom. The HTML is rendered with ``components.html``.
    """
    opening = (
        "\n"
        "<div style='overflow: hidden; padding:10px 0px;'>\n"
        "<div id=\"chat_history\" style='overflow-y: scroll;height: 200px;'>\n"
    )
    rows = "".join(
        f"<div class='chat_message'><b>{msg['role']}</b>: {msg['content']}</div>"
        for msg in st.session_state.messages
    )
    closing = (
        "</div>\n"
        "</div>\n"
        "<script>\n"
        "var el = document.getElementById(\"chat_history\");\n"
        "el.scrollTop = el.scrollHeight;\n"
        "</script>\n"
    )
    with st.session_state.history_container:
        components.html(opening + rows + closing, height=220)
# Main page: everything below is rendered only after the password check passes.
if utils.check_password():
    st.markdown("<script language='javascript'>console.log('scrolling');</script>", unsafe_allow_html=True)

    # Start a fresh history on first load or when the user asks for one.
    if st.sidebar.button("New Conversation") or "messages" not in st.session_state:
        st.session_state.new_conversation = True
        st.session_state.messages = [{"role":"system", "content":"Hello. I'm your startups discovery assistant."}]

    st.title("Raized- Startups discovery demo")

    # Bootstrap CSS provides the grid classes used by card(). The package
    # spec in the URL is "bootstrap@4.0.0" (it had been garbled to an
    # "[email protected]" placeholder).
    st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)

    with open("data/countries.json", "r", encoding="utf-8") as f:
        countries = json.load(f)['countries']

    # Sidebar filters.
    header = st.sidebar.markdown("Filters")
    countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
    all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
    region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
    all_bizmodels = ('B2B', 'B2C', 'eCommerce & Marketplace', 'Manufacturing', 'SaaS', 'Advertising', 'Commission', 'Subscription')
    bizmodel_selectbox = st.sidebar.multiselect("Business Model", all_bizmodels, default=all_bizmodels)

    # Like/dislike handlers referenced by card() and the sidebar width rule.
    # This is a plain (non-f) string, so CSS braces must be single; the
    # previous "{{ ... }}" was emitted literally and broke the rule.
    st.markdown(
        '''
<script>
function like_company(company_id) {
console.log("Company " + company_id + " Liked!");
}
function dislike_company(company_id) {
console.log("Company " + company_id + " Disliked!");
}
</script>
<style>
.sidebar .sidebar-content {
width: 375px;
}
</style>
''',
        unsafe_allow_html=True
    )

    tab_search, tab_advanced = st.tabs(["Search", "Settings"])

    with tab_advanced:
        report_type = st.selectbox(label="Response Type", options=["assistant", "standard", "guided", "company_list", "clustered"], index=0)
        assistant_id = st.text_input(label="Assistant ID", key="assistant_id", value="asst_fkZtxo127nxKOCcwrwznuCs2")
        default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
        clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
        top_k = st.number_input('# Top Results', value=20)
        is_debug = st.checkbox("Debug output", value = False, key="debug")
        openai_model = st.selectbox(label="Model", options=["gpt-4-1106-preview", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k"], index=0, key="openai_model")
        index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
        liked_companies = st.text_input(label="liked companies", key='liked_companies')
        disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')

    with tab_search:
        # render_history() draws into this container, so it must exist
        # before run_query() is called.
        st.session_state.history_container = st.container()
        query = st.text_input("Search!", "")
        if query != "":
            oai.start_conversation()
            st.session_state.report_type = report_type
            run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model)