hanoch.rahimi@gmail commited on
Commit
19e6802
·
1 Parent(s): 0b0b1a7

Added filters

Browse files
Files changed (1) hide show
  1. app.py +66 -14
app.py CHANGED
@@ -1,4 +1,5 @@
1
- import pinecone
 
2
  import streamlit as st
3
  from transformers import pipeline, AutoTokenizer
4
  from sentence_transformers import SentenceTransformer
@@ -6,6 +7,8 @@ from sentence_transformers import SentenceTransformer
6
  PINECONE_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
7
  PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"] # app.pinecone.io
8
 
 
 
9
  @st.cache_resource
10
  def init_pinecone():
11
  pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT) # get a free api key from app.pinecone.io
@@ -24,7 +27,7 @@ st.session_state.index = init_pinecone()
24
  retriever, reader, tokenizer = init_models()
25
 
26
 
27
- def card(name, description, score):
28
  return st.markdown(f"""
29
  <div class="container-fluid">
30
  <div class="row align-items-start">
@@ -36,31 +39,57 @@ def card(name, description, score):
36
  [<b>Score: </b>{score}]
37
  </span>
38
  </div>
39
- <div class="col-md-4 col-sm-4">
40
- <small>{metadata}</metadata>
 
 
 
 
 
 
41
  </div>
42
  </div>
43
  </div>
44
  """, unsafe_allow_html=True)
45
 
46
- st.title("")
47
 
48
- def run_query(query):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  xq = retriever.encode([query]).tolist()
50
  try:
51
- xc = st.session_state.index.query(xq, top_k=3, include_metadata=True, include_vectors = True)
52
  except:
53
  # force reload
54
  pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
55
  st.session_state.index = pinecone.Index("company-description")
56
- xc = st.session_state.index.query(xq, top_k=10, include_metadata=True, include_vectors = True)
57
 
58
  results = []
59
  for match in xc['matches']:
60
  #answer = reader(question=query, context=match["metadata"]['context'])
61
- answer = {'score': match['score']}
 
 
 
62
  answer["name"] = match["metadata"]['company_name'].strip('_description')
63
- answer["description"] = match["metadata"]['description']
64
  answer["metadata"] = match["metadata"]
65
  results.append(answer)
66
 
@@ -70,7 +99,10 @@ def run_query(query):
70
  company_name = r["name"]
71
  description = r["description"].replace(company_name, f"<mark>{company_name}</mark>")
72
  score = round(r["score"], 4)
73
- card(company_name, description, score)
 
 
 
74
 
75
  def check_password():
76
  """Returns `True` if the user had the correct password."""
@@ -100,7 +132,9 @@ def check_password():
100
  # Password correct.
101
  return True
102
 
103
- if check_password():
 
 
104
  st.write("""
105
  Search for a company in free text
106
  """)
@@ -108,9 +142,27 @@ if check_password():
108
  st.markdown("""
109
  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
110
  """, unsafe_allow_html=True)
111
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  query = st.text_input("Search!", "")
113
 
114
  if query != "":
115
- run_query(query)
116
 
 
1
+ import json
2
+ import pinecone
3
  import streamlit as st
4
  from transformers import pipeline, AutoTokenizer
5
  from sentence_transformers import SentenceTransformer
 
7
  PINECONE_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
8
  PINE_CONE_ENVIRONMENT = st.secrets["PINE_CONE_ENVIRONMENT"] # app.pinecone.io
9
 
10
+ st.set_page_config(layout="wide")
11
+
12
  @st.cache_resource
13
  def init_pinecone():
14
  pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT) # get a free api key from app.pinecone.io
 
27
  retriever, reader, tokenizer = init_models()
28
 
29
 
30
+ def card(name, description, score, data_type, region, country):
31
  return st.markdown(f"""
32
  <div class="container-fluid">
33
  <div class="row align-items-start">
 
39
  [<b>Score: </b>{score}]
40
  </span>
41
  </div>
42
+ <div class="col-md-1 col-sm-1">
43
+ <small>{data_type}</metadata>
44
+ </div>
45
+ <div class="col-md-1 col-sm-1">
46
+ <small>{region}</metadata>
47
+ </div>
48
+ <div class="col-md-1 col-sm-1">
49
+ <small>{country}</metadata>
50
  </div>
51
  </div>
52
  </div>
53
  """, unsafe_allow_html=True)
54
 
 
55
 
56
+ def index_query(xq, top_k, regions=[], countries=[]):
57
+ #st.write(f"Regions: {regions}")
58
+ filters = []
59
+ if len(regions)>0:
60
+ filters.append({'region': {"$in": regions}})
61
+ if len(countries)>0:
62
+ filters.append({'country': {"$in": countries}})
63
+ if len(filters)==1:
64
+ filter = filters[0]
65
+ elif len(filters)>1:
66
+ filter = {"$and": filters}
67
+ else:
68
+ filter = {}
69
+ #st.write(filter)
70
+ xc = st.session_state.index.query(xq, top_k=20, filter = filter, include_metadata=True, include_vectors = True)
71
+ #xc = st.session_state.index.query(xq, top_k=top_k, include_metadata=True, include_vectors = True)
72
+ return xc
73
+
74
+ def run_query(query, scrape_boost, top_k , regions, countries):
75
  xq = retriever.encode([query]).tolist()
76
  try:
77
+ xc = index_query(xq, top_k, regions, countries)
78
  except:
79
  # force reload
80
  pinecone.init(api_key=PINECONE_KEY, environment=PINE_CONE_ENVIRONMENT)
81
  st.session_state.index = pinecone.Index("company-description")
82
+ xc = index_query(xq, top_k, regions, countries)
83
 
84
  results = []
85
  for match in xc['matches']:
86
  #answer = reader(question=query, context=match["metadata"]['context'])
87
+ score = match['score']
88
+ if 'type' in match['metadata'] and match['metadata']['type']=='description-webcontent':
89
+ score = score * scrape_boost
90
+ answer = {'score': score}
91
  answer["name"] = match["metadata"]['company_name'].strip('_description')
92
+ answer["description"] = match["metadata"]['description'] if "description" in match['metadata'] else ""
93
  answer["metadata"] = match["metadata"]
94
  results.append(answer)
95
 
 
99
  company_name = r["name"]
100
  description = r["description"].replace(company_name, f"<mark>{company_name}</mark>")
101
  score = round(r["score"], 4)
102
+ data_type = r["metadata"]["type"] if "type" in r["metadata"] else ""
103
+ region = r["metadata"]["region"]
104
+ country = r["metadata"]["country"]
105
+ card(company_name, description, score, data_type, region, country)
106
 
107
  def check_password():
108
  """Returns `True` if the user had the correct password."""
 
132
  # Password correct.
133
  return True
134
 
135
+ if True or check_password():
136
+ st.title("")
137
+
138
  st.write("""
139
  Search for a company in free text
140
  """)
 
142
  st.markdown("""
143
  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
144
  """, unsafe_allow_html=True)
145
+ with open("data/countries.json", "r") as f:
146
+ countries = json.load(f)['countries']
147
+ countries_selectbox = st.sidebar.multiselect("Country", countries, default=[])
148
+ all_regions = ('Africa', 'Europe', 'Asia & Pacific', 'North America', 'South/Latin America')
149
+ region_selectbox = st.sidebar.multiselect("Region", all_regions, default=all_regions)
150
+ scrape_boost = st.sidebar.number_input('webcontent_boost', value=2.)
151
+ top_k = st.sidebar.number_input('Top K Results', value=20)
152
+
153
+ # with st.container():
154
+ # col1, col2, col3, col4 = st.columns(4)
155
+ # with col1:
156
+ # scrape_boost = st.number_input('webcontent_boost', value=2.)
157
+ # with col2:
158
+ # top_k = st.number_input('Top K Results', value=20)
159
+ # with col3:
160
+ # regions = st.number_input('Region', value=20)
161
+ # with col4:
162
+ # countries = st.number_input('Country', value=20)
163
+
164
  query = st.text_input("Search!", "")
165
 
166
  if query != "":
167
+ run_query(query, scrape_boost, top_k, region_selectbox, countries_selectbox)
168