Update app.py

Added e5 embedding model

app.py CHANGED
@@ -137,8 +137,15 @@ def bi_encode(bi_enc,passages):
 
     #Compute the embeddings using the multi-process pool
     with st.spinner('Encoding passages into a vector space...'):
-
-
+        if bi_enc == 'intfloat/e5-base':
+            corpus_embeddings = bi_encoder.encode(['passage: ' + sentence for sentence in passages], convert_to_tensor=True)
+        else:
+            corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True)
 
     st.success(f"Embeddings computed. Shape: {corpus_embeddings.shape}")
 
@@ -178,7 +185,7 @@ def bm25_api(passages):
 
     return bm25
 
-bi_enc_options = ["multi-qa-mpnet-base-dot-v1","all-mpnet-base-v2","multi-qa-MiniLM-L6-cos-v1","neeva/query2query"]
+bi_enc_options = ["multi-qa-mpnet-base-dot-v1","all-mpnet-base-v2","multi-qa-MiniLM-L6-cos-v1",'intfloat/e5-base',"neeva/query2query"]
 
 def display_df_as_table(model,top_k,score='score'):
     # Display the df with text and scores as a table
 
@@ -204,7 +211,7 @@ top_k = st.sidebar.slider("Number of Top Hits Generated",min_value=1,max_value=5
 
 # This function will search all wikipedia articles for passages that
 # answer the query
-def search_func(query, top_k=top_k):
+def search_func(query, bi_encoder_type, top_k=top_k):
 
     global bi_encoder, cross_encoder
 
@@ -229,6 +236,8 @@ def search_func(query, top_k=top_k):
     bm25_df = display_df_as_table(bm25_hits,top_k)
     st.write(bm25_df.to_html(index=False), unsafe_allow_html=True)
 
+    if bi_encoder_type == 'intfloat/e5-base':
+        query = 'query: ' + query
     ##### Sematic Search #####
     # Encode the query using the bi-encoder and find potentially relevant passages
     question_embedding = bi_encoder.encode(query, convert_to_tensor=True)