Spaces:
Sleeping
Sleeping
hanoch.rahimi@gmail
committed on
Commit
·
19e6802
1
Parent(s):
0b0b1a7
Added filters
Browse files
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
import
|
|
|
2 |
import streamlit as st
|
3 |
from transformers import pipeline, AutoTokenizer
|
4 |
from sentence_transformers import SentenceTransformer
|
@@ -6,6 +7,8 @@ from sentence_transformers import SentenceTransformer
|
|
6 |
PINECONE_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
|
7 |
PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"] # app.pinecone.io
|
8 |
|
|
|
|
|
9 |
@st.cache_resource
|
10 |
def init_pinecone():
|
11 |
pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT) # get a free api key from app.pinecone.io
|
@@ -24,7 +27,7 @@ st.session_state.index = init_pinecone()
|
|
24 |
retriever, reader, tokenizer = init_models()
|
25 |
|
26 |
|
27 |
-
def card(name, description, score):
|
28 |
return st.markdown(f"""
|
29 |
<div class="container-fluid">
|
30 |
<div class="row align-items-start">
|
@@ -36,31 +39,57 @@ def card(name, description, score):
|
|
36 |
[<b>Score: </b>{score}]
|
37 |
</span>
|
38 |
</div>
|
39 |
-
<div class="col-md-
|
40 |
-
<small>{
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
</div>
|
42 |
</div>
|
43 |
</div>
|
44 |
""", unsafe_allow_html=True)
|
45 |
|
46 |
-
st.title("")
|
47 |
|
48 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
xq = retriever.encode([query]).tolist()
|
50 |
try:
|
51 |
-
xc =
|
52 |
except:
|
53 |
# force reload
|
54 |
pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
|
55 |
st.session_state.index = pinecone.Index("company-description")
|
56 |
-
xc =
|
57 |
|
58 |
results = []
|
59 |
for match in xc['matches']:
|
60 |
#answer = reader(question=query, context=match["metadata"]['context'])
|
61 |
-
|
|
|
|
|
|
|
62 |
answer["name"] = match["metadata"]['company_name'].strip('_description')
|
63 |
-
answer["description"] = match["metadata"]['description']
|
64 |
answer["metadata"] = match["metadata"]
|
65 |
results.append(answer)
|
66 |
|
@@ -70,7 +99,10 @@ def run_query(query):
|
|
70 |
company_name = r["name"]
|
71 |
description = r["description"].replace(company_name, f"<mark>{company_name}</mark>")
|
72 |
score = round(r["score"], 4)
|
73 |
-
|
|
|
|
|
|
|
74 |
|
75 |
def check_password():
|
76 |
"""Returns `True` if the user had the correct password."""
|
@@ -100,7 +132,9 @@ def check_password():
|
|
100 |
# Password correct.
|
101 |
return True
|
102 |
|
103 |
-
if check_password():
|
|
|
|
|
104 |
st.write("""
|
105 |
Search for a company in free text
|
106 |
""")
|
@@ -108,9 +142,27 @@ if check_password():
|
|
108 |
st.markdown("""
|
109 |
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
|
110 |
""", unsafe_allow_html=True)
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
query = st.text_input("Search!", "")
|
113 |
|
114 |
if query != "":
|
115 |
-
run_query(query)
|
116 |
|
|
|
1 |
+
import json
|
2 |
+
import pinecone
|
3 |
import streamlit as st
|
4 |
from transformers import pipeline, AutoTokenizer
|
5 |
from sentence_transformers import SentenceTransformer
|
|
|
7 |
PINECONE_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
|
8 |
PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"] # app.pinecone.io
|
9 |
|
10 |
+
st.set_page_config(layout="wide")
|
11 |
+
|
12 |
@st.cache_resource
|
13 |
def init_pinecone():
|
14 |
pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT) # get a free api key from app.pinecone.io
|
|
|
27 |
retriever, reader, tokenizer = init_models()
|
28 |
|
29 |
|
30 |
+
def card(name, description, score, data_type, region, country):
|
31 |
return st.markdown(f"""
|
32 |
<div class="container-fluid">
|
33 |
<div class="row align-items-start">
|
|
|
39 |
[<b>Score: </b>{score}]
|
40 |
</span>
|
41 |
</div>
|
42 |
+
<div class="col-md-1 col-sm-1">
|
43 |
+
<small>{data_type}</metadata>
|
44 |
+
</div>
|
45 |
+
<div class="col-md-1 col-sm-1">
|
46 |
+
<small>{region}</metadata>
|
47 |
+
</div>
|
48 |
+
<div class="col-md-1 col-sm-1">
|
49 |
+
<small>{country}</metadata>
|
50 |
</div>
|
51 |
</div>
|
52 |
</div>
|
53 |
""", unsafe_allow_html=True)
|
54 |
|
|
|
55 |
|
56 |
+
def index_query(xq, top_k, regions=None, countries=None):
    """Query the index stored in ``st.session_state.index`` with optional metadata filters.

    Args:
        xq: Query vector(s), forwarded unchanged as the first positional
            argument of ``index.query``.
        top_k: Maximum number of matches to request from the index.
        regions: Optional collection of region names. When non-empty, only
            matches whose ``region`` metadata is in this collection are kept.
        countries: Optional collection of country names. When non-empty, only
            matches whose ``country`` metadata is in this collection are kept.

    Returns:
        Whatever ``st.session_state.index.query`` returns for the request.
    """
    # Build one clause per active filter; empty or missing selections add nothing.
    clauses = []
    if regions:
        clauses.append({'region': {"$in": list(regions)}})
    if countries:
        clauses.append({'country': {"$in": list(countries)}})

    if len(clauses) == 1:
        query_filter = clauses[0]
    elif clauses:
        # Both filters are active: a match must satisfy all of them.
        query_filter = {"$and": clauses}
    else:
        query_filter = {}

    # Honour the caller-supplied top_k; it was previously hard-coded to 20,
    # which made the parameter (and the UI control feeding it) a no-op.
    # NOTE(review): `include_vectors` is forwarded exactly as before; confirm
    # the installed client recognises it (the documented flag may be
    # `include_values`).
    return st.session_state.index.query(
        xq,
        top_k=top_k,
        filter=query_filter,
        include_metadata=True,
        include_vectors=True,
    )
|
73 |
+
|
74 |
+
def run_query(query, scrape_boost, top_k , regions, countries):
|
75 |
xq = retriever.encode([query]).tolist()
|
76 |
try:
|
77 |
+
xc = index_query(xq, top_k, regions, countries)
|
78 |
except:
|
79 |
# force reload
|
80 |
pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
|
81 |
st.session_state.index = pinecone.Index("company-description")
|
82 |
+
xc = index_query(xq, top_k, regions, countries)
|
83 |
|
84 |
results = []
|
85 |
for match in xc['matches']:
|
86 |
#answer = reader(question=query, context=match["metadata"]['context'])
|
87 |
+
score = match['score']
|
88 |
+
if 'type' in match['metadata'] and match['metadata']['type']=='description-webcontent':
|
89 |
+
score = score * scrape_boost
|
90 |
+
answer = {'score': score}
|
91 |
answer["name"] = match["metadata"]['company_name'].strip('_description')
|
92 |
+
answer["description"] = match["metadata"]['description'] if "description" in match['metadata'] else ""
|
93 |
answer["metadata"] = match["metadata"]
|
94 |
results.append(answer)
|
95 |
|
|
|
99 |
company_name = r["name"]
|
100 |
description = r["description"].replace(company_name, f"<mark>{company_name}</mark>")
|
101 |
score = round(r["score"], 4)
|
102 |
+
data_type = r["metadata"]["type"] if "type" in r["metadata"] else ""
|
103 |
+
region = r["metadata"]["region"]
|
104 |
+
country = r["metadata"]["country"]
|
105 |
+
card(company_name, description, score, data_type, region, country)
|
106 |
|
107 |
def check_password():
|
108 |
"""Returns `True` if the user had the correct password."""
|
|
|
132 |
# Password correct.
|
133 |
return True
|
134 |
|
135 |
+
if True or check_password():
|
136 |
+
st.title("")
|
137 |
+
|
138 |
st.write("""
|
139 |
Search for a company in free text
|
140 |
""")
|
|
|
142 |
st.markdown("""
|
143 |
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
|
144 |
""", unsafe_allow_html=True)
|
145 |
+
with open("data/countries.json", "r") as f:
|
146 |
+
countries = json.load(f)['countries']
|
147 |
+
countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
|
148 |
+
all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
|
149 |
+
region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
|
150 |
+
scrape_boost = st.sidebar.number_input('webcontent_boost', value=2.)
|
151 |
+
top_k = st.sidebar.number_input('Top K Results', value=20)
|
152 |
+
|
153 |
+
# with st.container():
|
154 |
+
# col1, col2, col3, col4 = st.columns(4)
|
155 |
+
# with col1:
|
156 |
+
# scrape_boost = st.number_input('webcontent_boost', value=2.)
|
157 |
+
# with col2:
|
158 |
+
# top_k = st.number_input('Top K Results', value=20)
|
159 |
+
# with col3:
|
160 |
+
# regions = st.number_input('Region', value=20)
|
161 |
+
# with col4:
|
162 |
+
# countries = st.number_input('Country', value=20)
|
163 |
+
|
164 |
query = st.text_input("Search!", "")
|
165 |
|
166 |
if query != "":
|
167 |
+
run_query(query, scrape_boost, top_k, region_selectbox, countries_selectbox)
|
168 |
|