Update app.py
Browse files
app.py
CHANGED
@@ -114,8 +114,9 @@ def preprocess_plain_text(text,window_size=3):
|
|
114 |
return passages
|
115 |
|
116 |
@st.cache(allow_output_mutation=True)
|
117 |
-
def
|
118 |
|
|
|
119 |
#We use the Bi-Encoder to encode all passages, so that we can use it with semantic search
|
120 |
bi_encoder = SentenceTransformer(bi_enc)
|
121 |
|
@@ -128,8 +129,9 @@ def bi_encoder(bi_enc,passages):
|
|
128 |
return corpus_embeddings
|
129 |
|
130 |
@st.cache(allow_output_mutation=True)
|
131 |
-
def
|
132 |
|
|
|
133 |
#The bi-encoder will retrieve 100 documents. We use a cross-encoder to re-rank the results list and improve its quality
|
134 |
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
135 |
return cross_encoder
|
@@ -291,8 +293,8 @@ if search:
|
|
291 |
with st.spinner(
|
292 |
text=f"Loading {bi_encoder_type} bi-encoder and embedding document into vector space. This might take a few seconds depending on the length of your document..."
|
293 |
):
|
294 |
-
corpus_embeddings =
|
295 |
-
|
296 |
bm25 = bm25_api(passages)
|
297 |
|
298 |
with st.spinner(
|
|
|
114 |
return passages
|
115 |
|
116 |
@st.cache(allow_output_mutation=True)
|
117 |
+
def bi_encode(bi_enc,passages):
|
118 |
|
119 |
+
global bi_encoder
|
120 |
#We use the Bi-Encoder to encode all passages, so that we can use it with semantic search
|
121 |
bi_encoder = SentenceTransformer(bi_enc)
|
122 |
|
|
|
129 |
return corpus_embeddings
|
130 |
|
131 |
@st.cache(allow_output_mutation=True)
|
132 |
+
def cross_encode():
|
133 |
|
134 |
+
global cross_encoder
|
135 |
#The bi-encoder will retrieve 100 documents. We use a cross-encoder to re-rank the results list and improve its quality
|
136 |
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
137 |
return cross_encoder
|
|
|
293 |
with st.spinner(
|
294 |
text=f"Loading {bi_encoder_type} bi-encoder and embedding document into vector space. This might take a few seconds depending on the length of your document..."
|
295 |
):
|
296 |
+
corpus_embeddings = bi_encode(bi_encoder_type,passages)
|
297 |
+
cross_enc = cross_encode()
|
298 |
bm25 = bm25_api(passages)
|
299 |
|
300 |
with st.spinner(
|