Commit a95ef9f (1 parent: 2393537)
General code improvements and refinements.
Files changed:
- Dockerfile +0 -2
- app.py +36 -44
- requirements.txt +2 -3
- requirements_gpu.txt +3 -3
- search_funcs/bm25_functions.py +200 -77
- search_funcs/helper_functions.py +35 -6
- search_funcs/semantic_functions.py +108 -396
- search_funcs/spacy_search_funcs.py +6 -1
Dockerfile
CHANGED
@@ -58,7 +58,5 @@ WORKDIR $HOME/app
 
 # Copy the current directory contents into the container at $HOME/app setting the owner to the user
 COPY --chown=user . $HOME/app
-#COPY . $HOME/app
-
 
 CMD ["python", "app.py"]
app.py
CHANGED
@@ -7,7 +7,7 @@ PandasDataFrame = Type[pd.DataFrame]
 
 from search_funcs.bm25_functions import prepare_bm25_input_data, prepare_bm25, bm25_search
 from search_funcs.semantic_ingest_functions import csv_excel_text_to_docs
-from search_funcs.semantic_functions import docs_to_bge_embed_np_array,
+from search_funcs.semantic_functions import docs_to_bge_embed_np_array, bge_semantic_search
 from search_funcs.helper_functions import display_info, initial_data_load, put_columns_in_join_df, get_temp_folder_path, empty_folder, get_connection_params, output_folder
 from search_funcs.spacy_search_funcs import spacy_fuzzy_search
 from search_funcs.aws_functions import load_data_from_aws
@@ -17,39 +17,33 @@ temp_folder_path = get_temp_folder_path()
 empty_folder(temp_folder_path)
 
 ## Gradio app - BM25 search
-
+app = gr.Blocks(theme = gr.themes.Base()) # , css="theme.css"
 
-
-with block:
+with app:
     print("Please don't close this window! Open the below link in the web browser of your choice.")
 
-
-
-
-
-
-
-
-    bm25_search_object_state = gr.State()
-
-    k_val = gr.State(9999)
-    out_passages = gr.State(9999)
-    vec_weight = gr.State(1)
-
-    corpus_state = gr.State()
-    keyword_data_list_state = gr.State([])
-    join_data_state = gr.State(pd.DataFrame())
-    output_file_state = gr.State([])
-
-    orig_keyword_data_state = gr.State(pd.DataFrame())
-    keyword_data_state = gr.State(pd.DataFrame())
-
-    orig_semantic_data_state = gr.State(pd.DataFrame())
-    semantic_data_state = gr.State(pd.DataFrame())
+    # BM25 state objects
+    orig_keyword_data_state = gr.State(pd.DataFrame()) # Original data that is not changed #gr.Dataframe(pd.DataFrame(),visible=False) #gr.State(pd.DataFrame())
+    prepared_keyword_data_state = gr.State(pd.DataFrame()) # Data frame the contains modified data #gr.Dataframe(pd.DataFrame(),visible=False) #gr.State(pd.DataFrame())
+    #tokenised_prepared_keyword_data_state = gr.State([]) # This is data that has been loaded in as tokens #gr.Dataframe(pd.DataFrame(),visible=False) #gr.State()
+    tokenised_prepared_keyword_data_state = gr.State([]) # Data that has been prepared for search (tokenised) #gr.Dataframe(np.array([]), type="array", visible=False) #gr.State([])
+    bm25_search_index_state = gr.State()
 
+    # Semantic search state objects
+    orig_semantic_data_state = gr.State(pd.DataFrame()) #gr.Dataframe(pd.DataFrame(),visible=False) # gr.State(pd.DataFrame())
+    semantic_data_state = gr.State(pd.DataFrame()) #gr.Dataframe(pd.DataFrame(),visible=False) # gr.State(pd.DataFrame())
+    semantic_input_document_format = gr.State([])
+    embeddings_state = gr.State(np.array([])) #gr.Dataframe(np.array([]), type="numpy", visible=False) #gr.State(np.array([])) # globals()["embeddings"]
+    semantic_k_val = gr.Number(9999, visible=False)
+
+    # State objects for app in general
     session_hash_state = gr.State("")
     s3_output_folder_state = gr.State("")
+    join_data_state = gr.State(pd.DataFrame()) #gr.Dataframe(pd.DataFrame(), visible=False) #gr.State(pd.DataFrame())
+    output_file_state = gr.Dropdown([], visible=False, allow_custom_value=True) #gr.Dataframe(type="array", visible=False) #gr.State([])
 
+    # Informational state objects
     in_k1_info = gr.State("""k1: Constant used for influencing the term frequency saturation. After saturation is reached, additional
 presence for the term adds a significantly less additional score. According to [1]_, experiments suggest
 that 1.2 < k1 < 2 yields reasonably good results, although the optimal value depends on factors such as
@@ -167,7 +161,7 @@ depends on factors such as the type of documents or queries. Information taken f
     out_aws_data_message = gr.Textbox(label="AWS data load progress")
 
     # Changing search parameters button
-    in_search_param_button.click(fn=prepare_bm25, inputs=[
+    in_search_param_button.click(fn=prepare_bm25, inputs=[tokenised_prepared_keyword_data_state, in_bm25_file, in_bm25_column, bm25_search_index_state, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message])
 
     # ---
     in_k1_button.click(display_info, inputs=in_k1_info)
@@ -178,43 +172,41 @@ depends on factors such as the type of documents or queries. Information taken f
     ### Loading AWS data ###
     load_aws_keyword_data_button.click(fn=load_data_from_aws, inputs=[in_aws_keyword_file, aws_password_box], outputs=[in_bm25_file, out_aws_data_message])
     load_aws_semantic_data_button.click(fn=load_data_from_aws, inputs=[in_aws_semantic_file, aws_password_box], outputs=[in_semantic_file, out_aws_data_message])
-
 
     ### BM25 SEARCH ###
     # Update dropdowns upon initial file load
-    in_bm25_file.change(initial_data_load, inputs=[in_bm25_file], outputs=[in_bm25_column, search_df_join_column,
+    in_bm25_file.change(initial_data_load, inputs=[in_bm25_file], outputs=[in_bm25_column, search_df_join_column, prepared_keyword_data_state, orig_keyword_data_state, bm25_search_index_state, embeddings_state, tokenised_prepared_keyword_data_state, load_finished_message, current_source], api_name="initial_load")
     in_join_file.change(put_columns_in_join_df, inputs=[in_join_file], outputs=[in_join_column, join_data_state, in_join_message])
 
     # Load in BM25 data
-    load_bm25_data_button.click(fn=prepare_bm25_input_data, inputs=[in_bm25_file, in_bm25_column,
-    then(fn=prepare_bm25, inputs=[
-
+    load_bm25_data_button.click(fn=prepare_bm25_input_data, inputs=[in_bm25_file, in_bm25_column, prepared_keyword_data_state, tokenised_prepared_keyword_data_state, in_clean_data, return_intermediate_files], outputs=[tokenised_prepared_keyword_data_state, load_finished_message, prepared_keyword_data_state, output_file, output_file, in_bm25_column], api_name="load_keyword").\
+    then(fn=prepare_bm25, inputs=[tokenised_prepared_keyword_data_state, in_bm25_file, in_bm25_column, bm25_search_index_state, in_clean_data, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message, output_file, bm25_search_index_state, tokenised_prepared_keyword_data_state], api_name="prepare_keyword") # keyword_data_list_state
 
     # BM25 search functions on click or enter
-    keyword_search_button.click(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state,
-    keyword_query.submit(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state,
+    keyword_search_button.click(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, prepared_keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_index_state, tokenised_prepared_keyword_data_state, in_join_column, search_df_join_column, in_k1, in_b, in_alpha], outputs=[output_single_text, output_file], api_name="keyword_search")
+    keyword_query.submit(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, prepared_keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_index_state, tokenised_prepared_keyword_data_state, in_join_column, search_df_join_column, in_k1, in_b, in_alpha], outputs=[output_single_text, output_file])
 
     # Fuzzy search functions on click
-    fuzzy_search_button.click(fn=spacy_fuzzy_search, inputs=[keyword_query,
+    fuzzy_search_button.click(fn=spacy_fuzzy_search, inputs=[keyword_query, tokenised_prepared_keyword_data_state, prepared_keyword_data_state, in_bm25_column, join_data_state, search_df_join_column, in_join_column, no_spelling_mistakes], outputs=[output_single_text, output_file], api_name="fuzzy_search")
 
     ### SEMANTIC SEARCH ###
 
     # Load in a csv/excel file for semantic search
-    in_semantic_file.change(initial_data_load, inputs=[in_semantic_file], outputs=[in_semantic_column, search_df_join_column, semantic_data_state, orig_semantic_data_state,
+    in_semantic_file.change(initial_data_load, inputs=[in_semantic_file], outputs=[in_semantic_column, search_df_join_column, semantic_data_state, orig_semantic_data_state, bm25_search_index_state, embeddings_state, tokenised_prepared_keyword_data_state, semantic_load_progress, current_source_semantic])
     load_semantic_data_button.click(
-        csv_excel_text_to_docs, inputs=[semantic_data_state, in_semantic_file, in_semantic_column, in_clean_data, return_intermediate_files], outputs=[
-        then(docs_to_bge_embed_np_array, inputs=[
+        csv_excel_text_to_docs, inputs=[semantic_data_state, in_semantic_file, in_semantic_column, in_clean_data, return_intermediate_files], outputs=[semantic_input_document_format, semantic_load_progress, output_file_state]).\
+        then(docs_to_bge_embed_np_array, inputs=[semantic_input_document_format, in_semantic_file, embeddings_state, output_file_state, in_clean_data, return_intermediate_files, embedding_super_compress], outputs=[semantic_load_progress, embeddings_state, semantic_output_file, output_file_state]) # vectorstore_state
 
     # Semantic search query
-    semantic_submit.click(
-    semantic_query.submit(
+    semantic_submit.click(bge_semantic_search, inputs=[semantic_query, embeddings_state, semantic_input_document_format, semantic_k_val, semantic_min_distance, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file], api_name="semantic_search")
+    semantic_query.submit(bge_semantic_search, inputs=[semantic_query, embeddings_state, semantic_input_document_format, semantic_k_val, semantic_min_distance, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file])
 
-
+    app.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state])
 
 # Launch the Gradio app
 if __name__ == "__main__":
-
+    app.queue().launch(show_error=True) # root_path="/data-text-search" # server_name="0.0.0.0",
 
     # Running on local server with https: https://discuss.huggingface.co/t/how-to-run-gradio-with-0-0-0-0-and-https/38003 or https://dev.to/rajshirolkar/fastapi-over-https-for-development-on-windows-2p7d # Need to download OpenSSL and create own keys
-    #
+    # app.queue().launch(ssl_verify=False, share=False, debug=False, server_name="0.0.0.0",server_port=443,
     # ssl_certfile="cert.pem", ssl_keyfile="key.pem") # port 443 for https. Certificates currently not valid
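The keyword-load flow above chains two events: the first `.click(...)` writes its outputs into `gr.State` components, and the chained `.then(...)` consumes that state. Below is a minimal, self-contained sketch of this pattern, assuming a recent Gradio 4.x API; `load_data` and `build_index` are hypothetical stand-ins, not the app's real functions.

import gradio as gr

def load_data(raw_text: str) -> list:
    # Stand-in for prepare_bm25_input_data: tokenise and stash in state
    return raw_text.lower().split()

def build_index(tokens: list) -> str:
    # Stand-in for prepare_bm25: consume the state written by load_data
    return f"Index built over {len(tokens)} tokens"

with gr.Blocks() as demo:
    tokens_state = gr.State([])  # mirrors tokenised_prepared_keyword_data_state
    text_in = gr.Textbox(label="Text")
    message = gr.Textbox(label="Status")
    load_btn = gr.Button("Load and index")

    # First event fills the state; the chained event reads it back
    load_btn.click(fn=load_data, inputs=text_in, outputs=tokens_state).\
        then(fn=build_index, inputs=tokens_state, outputs=message)

if __name__ == "__main__":
    demo.launch()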
requirements.txt
CHANGED
@@ -1,12 +1,11 @@
 pandas==2.2.2
 polars==0.20.3
 pyarrow==14.0.2
-openpyxl==3.1.
+openpyxl==3.1.3
 torch==2.3.1
-transformers==4.41.2
 spacy
 en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1.tar.gz
 gradio
 sentence_transformers==3.0.1
-lxml==5.
+lxml==5.2.2
 boto3==1.34.103
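Note that both requirements files pin the spaCy model wheel by direct URL (the `en_core_web_sm @ https://...` line), so no separate `python -m spacy download` step should be needed. A quick sanity check, assuming the requirements have been installed:

import spacy

# The en_core_web_sm model is pinned by direct URL in requirements.txt,
# so spacy.load should resolve without a separate download step.
nlp = spacy.load("en_core_web_sm")
doc = nlp("The cat sat on the mat.")
print([token.text for token in doc])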
requirements_gpu.txt
CHANGED
@@ -1,11 +1,11 @@
 pandas==2.2.2
 polars==0.20.3
 pyarrow==14.0.2
-openpyxl==3.1.
+openpyxl==3.1.3
 torch==2.3.1 --index-url https://download.pytorch.org/whl/cu121
 spacy
 en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1.tar.gz
 gradio
-sentence_transformers==
-lxml==5.
+sentence_transformers==3.0.1
+lxml==5.2.2
 boto3==1.34.103
search_funcs/bm25_functions.py
CHANGED
@@ -8,6 +8,7 @@ import time
 import pandas as pd
 from numpy import inf
 import gradio as gr
+from typing import List
 
 from datetime import datetime
 
@@ -165,7 +166,7 @@ class BM25:
         return [documents[i] for i in heapq.nlargest(n, scores.keys(), key=scores.__getitem__)]
 
 
-    def get_top_n_with_score(self, query, documents, n=5):
+    def get_top_n_with_score(self, query:str, documents:List[str], n=5):
         """
         Retrieve the top n documents for the query along with their scores.
 
@@ -229,15 +230,47 @@ class BM25:
         with open(f"{output_folder}{filename}.pkl", "rb") as fsave:
             return pickle.load(fsave)
 
-
-def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, clean="No", return_intermediate_files = "No", progress=gr.Progress(track_tqdm=True)):
-    #print(in_file)
+def prepare_bm25_input_data(
+    in_file: list,
+    text_column: str,
+    data_state: pd.DataFrame,
+    tokenised_state: list,
+    clean: str = "No",
+    return_intermediate_files: str = "No",
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> tuple:
+    """
+    Prepare BM25 input data by loading, cleaning, and tokenizing the text data.
+
+    Parameters
+    ----------
+    in_file: list
+        List of input files to be processed.
+    text_column: str
+        The name of the text column in the data file to search.
+    data_state: pd.DataFrame
+        The current state of the data.
+    tokenised_state: list
+        The current state of the tokenized data.
+    clean: str, optional
+        Whether to clean the text data (default is "No").
+    return_intermediate_files: str, optional
+        Whether to return intermediate processing files (default is "No").
+    progress: gr.Progress, optional
+        Progress tracker for the function (default is gr.Progress(track_tqdm=True)).
+
+    Returns
+    -------
+    tuple
+        A tuple containing the prepared search text list, a message, the updated data state,
+        the tokenized data, the search index, and a dropdown component for the text column.
+    """
+
     ensure_output_folder_exists(output_folder)
 
     if not in_file:
         print("No input file found. Please load in at least one file.")
-        return None, "No input file found. Please load in at least one file.", data_state, None, None,
+        return None, "No input file found. Please load in at least one file.", data_state, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
     progress(0, desc = "Loading in data")
     file_list = [string.name for string in in_file]
@@ -247,25 +280,24 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
     data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower() and "gz" not in string.lower()]
 
     if not data_file_names:
-        return None, "Please load in at least one csv/Excel/parquet data file.", data_state, None, None,
+        return None, "Please load in at least one csv/Excel/parquet data file.", data_state, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
     if not text_column:
-        return None, "Please enter a column name to search.", data_state, None, None,
+        return None, "Please enter a column name to search.", data_state, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
     data_file_name = data_file_names[0]
 
     df = data_state #read_file(data_file_name)
-    data_file_out_name = get_file_path_end_with_ext(data_file_name)
+    #data_file_out_name = get_file_path_end_with_ext(data_file_name)
     data_file_out_name_no_ext = get_file_path_end(data_file_name)
 
-    ## Load in pre-tokenised
-    tokenised_df = pd.DataFrame()
+    ## Load in pre-tokenised prepared_search_text_list if exists
+    #tokenised_df = pd.DataFrame()
 
-    tokenised_file_names = [string for string in file_list if "tokenised" in string.lower()]
+    #tokenised_file_names = [string for string in file_list if "tokenised" in string.lower()]
     search_index_file_names = [string for string in file_list if "gz" in string.lower()]
 
-
-
+    # Set all search text to lower case
     df[text_column] = df[text_column].astype(str).str.lower()
 
     if "copy_of_case_note_id" in df.columns:
@@ -273,10 +305,10 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
         df.loc[~df["copy_of_case_note_id"].isna(), text_column] = ""
 
     if search_index_file_names:
-
+        prepared_search_text_list = list(df[text_column])
         message = "Tokenisation skipped - loading search index from file."
         print(message)
-        return
+        return prepared_search_text_list, message, df, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
 
     if clean == "Yes":
@@ -285,11 +317,11 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
         print("Starting data clean.")
 
         #df = df.drop_duplicates(text_column)
-
-
+        prepared_text_as_list = list(df[text_column])
+        prepared_text_as_list = initial_clean(prepared_text_as_list)
 
         # Save to file if you have cleaned the data
-        out_file_name, text_column, df = save_prepared_bm25_data(data_file_name,
+        out_file_name, text_column, df = save_prepared_bm25_data(data_file_name, prepared_text_as_list, df, text_column)
 
         clean_toc = time.perf_counter()
         clean_time_out = f"Cleaning the text took {clean_toc - clean_tic:0.1f} seconds."
@@ -297,7 +329,7 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
 
     else:
         # Don't clean or save file to disk
-
+        prepared_text_as_list = list(df[text_column])
         print("No data cleaning performed")
         out_file_name = None
 
@@ -305,24 +337,27 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
 
     progress(0.4, desc = "Tokenising text")
 
+    print("Tokenised state:", tokenised_state)
+
     if tokenised_state:
-
-        corpus = tokenised_df.iloc[:,0].tolist()
+        prepared_search_text_list = tokenised_state.iloc[:,0].tolist()
         print("Tokenised data loaded from file")
 
+        #print("prepared_search_text_list is: ", prepared_search_text_list[0:5])
+
     else:
         tokeniser_tic = time.perf_counter()
-
+        prepared_search_text_list = []
         batch_size = 256
-        for doc in tokenizer.pipe(progress.tqdm(
-
+        for doc in tokenizer.pipe(progress.tqdm(prepared_text_as_list, desc = "Tokenising text", unit = "rows"), batch_size=batch_size):
+            prepared_search_text_list.append([token.text for token in doc])
 
         tokeniser_toc = time.perf_counter()
         tokenizer_time_out = f"Tokenising the text took {tokeniser_toc - tokeniser_tic:0.1f} seconds."
         print(tokenizer_time_out)
+        #print("prepared_search_text_list is: ", prepared_search_text_list[0:5])
 
-    if len(
+    if len(prepared_text_as_list) >= 20:
         message = "Data loaded"
     else:
         message = "Data loaded. Warning: dataset may be too short to get consistent search results."
@@ -334,13 +369,29 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
     else:
         tokenised_data_file_name = output_folder + data_file_out_name_no_ext + "_tokenised.parquet"
 
-        pd.DataFrame(data={"
-
-
-
+        pd.DataFrame(data={"prepared_search_text_list":prepared_search_text_list}).to_parquet(tokenised_data_file_name)
+
+        return prepared_search_text_list, message, df, out_file_name, tokenised_data_file_name, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list()) # prepared_text_as_list,
+
+    return prepared_search_text_list, message, df, out_file_name, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list()) # prepared_text_as_list,
+
+def save_prepared_bm25_data(in_file_name: str, prepared_text_list: list, in_df: pd.DataFrame, in_bm25_column: str, progress: gr.Progress = gr.Progress(track_tqdm=True)) -> tuple:
+    """
+    Save the prepared BM25 data to a file.
+
+    This function ensures the output folder exists, checks if the length of the prepared text list matches the input dataframe,
+    and saves the prepared data to a file in the specified format. The original column in the input dataframe is dropped to reduce file size.
+
+    Parameters:
+    - in_file_name (str): The name of the input file.
+    - prepared_text_list (list): The list of prepared text.
+    - in_df (pd.DataFrame): The input dataframe.
+    - in_bm25_column (str): The name of the column to be processed.
+    - progress (gr.Progress, optional): The progress tracker for the operation.
+
+    Returns:
+    - tuple: A tuple containing the file name, new text column name, and the prepared dataframe.
+    """
 
     ensure_output_folder_exists(output_folder)
 
@@ -368,26 +419,54 @@ def save_prepared_bm25_data(in_file_name, prepared_text_list, in_df, in_bm25_col
 
     return file_name, new_text_column, prepared_df
 
-def prepare_bm25(
-
-
-
-
+def prepare_bm25(
+    prepared_search_text_list: List[str],
+    in_file: List[gr.File],
+    text_column: str,
+    search_index: BM25,
+    clean: str,
+    return_intermediate_files: str,
+    k1: float = 1.5,
+    b: float = 0.75,
+    alpha: float = -5,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> tuple:
+    """
+    Prepare the BM25 search index.
+
+    This function prepares the BM25 search index from the provided text list and input file. It ensures the necessary
+    files and columns are present, processes the data, and optionally saves intermediate files.
+
+    Parameters:
+    - prepared_search_text_list (List[str]): The list of prepared search text.
+    - in_file (List[gr.File]): The list of input files.
+    - text_column (str): The name of the column to search.
+    - search_index (BM25): The BM25 search index.
+    - clean (str): Indicates whether to clean the data.
+    - return_intermediate_files (str): Indicates whether to return intermediate files.
+    - k1 (float, optional): The k1 parameter for BM25. Default is 1.5.
+    - b (float, optional): The b parameter for BM25. Default is 0.75.
+    - alpha (float, optional): The alpha parameter for BM25. Default is -5.
+    - progress (gr.Progress, optional): The progress tracker for the operation.
+
+    Returns:
+    - tuple: A tuple containing the output message, BM25 search index, and other relevant information.
+    """
 
     if not in_file:
         out_message ="No input file found. Please load in at least one file."
         print(out_message)
-        return out_message, None
+        return out_message, None, None
 
-    if not
+    if not prepared_search_text_list:
         out_message = "No data file found. Please load in at least one csv/Excel/Parquet file."
         print(out_message)
-        return out_message, None
+        return out_message, None, None, None
 
     if not text_column:
         out_message = "Please enter a column name to search."
         print(out_message)
-        return out_message, None
+        return out_message, None, None, None
 
     file_list = [string.name for string in in_file]
 
@@ -397,36 +476,23 @@ def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_inter
     data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower() and "gz" not in string.lower()]
 
     if not data_file_names:
-        return "Please load in at least one csv/Excel/parquet data file.", None
+        return "Please load in at least one csv/Excel/parquet data file.", None, None, None
 
     data_file_name = data_file_names[0]
     data_file_out_name = get_file_path_end_with_ext(data_file_name)
     data_file_name_no_ext = get_file_path_end(data_file_name)
 
-    # Check if there is a search index file already
-    #index_file_names = [string for string in file_list if "gz" in string.lower()]
-
     progress(0.6, desc = "Preparing search index")
 
-    #if index_file_names:
     if search_index:
-
-
-        #print(index_file_name)
-
-        bm25_load = search_index
-
-
-        #index_file_out_name = get_file_path_end_with_ext(index_file_name)
-        #index_file_name_no_ext = get_file_path_end(index_file_name)
-
+        bm25 = search_index
     else:
-        print("Preparing BM25 corpus")
+        print("Preparing BM25 search corpus")
 
-
+        bm25 = BM25(prepared_search_text_list, k1=k1, b=b, alpha=alpha)
 
-    global bm25
-    bm25 = bm25_load
+    #global bm25
+    #bm25 = bm25_load
 
     if return_intermediate_files == "Yes":
         print("Saving search index file")
@@ -451,7 +517,7 @@ def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_inter
 
     print(message)
 
-    return message, None, bm25
+    return message, None, bm25, prepared_search_text_list
 
 def convert_bm25_query_to_tokens(free_text_query, clean="No"):
     '''
@@ -474,9 +540,75 @@ def convert_bm25_query_to_tokens(free_text_query, clean="No"):
 
     return out_query
 
-def bm25_search(
+def bm25_search(
+    free_text_query: str,
+    in_no_search_results: int,
+    original_data: pd.DataFrame,
+    searched_data: pd.DataFrame,
+    text_column: str,
+    in_join_file: str,
+    clean: str,
+    bm25: BM25,
+    prepared_search_text_list_state: list,
+    in_join_column: str = "",
+    search_df_join_column: str = "",
+    k1: float = 1.5,
+    b: float = 0.75,
+    alpha: float = -5,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> tuple:
+    """
+    Perform a BM25 search on the provided text data.
+
+    Parameters
+    ----------
+    free_text_query : str
+        The query text to search for.
+    in_no_search_results : int
+        The number of search results to return.
+    original_data : pd.DataFrame
+        The original data containing the text to be searched.
+    searched_data : pd.DataFrame
+        The data that has been prepared for searching.
+    text_column : str
+        The name of the column in the data to search.
+    in_join_file : str
+        The file to join the search results with.
+    clean : str
+        Whether to clean the text data.
+    bm25 : BM25
+        The BM25 object used for searching.
+    prepared_search_text_list_state : list
+        The state of the prepared search text list.
+    in_join_column : str, optional
+        The column to join on in the input file (default is "").
+    search_df_join_column : str, optional
+        The column to join on in the search dataframe (default is "").
+    k1 : float, optional
+        The k1 parameter for BM25 (default is 1.5).
+    b : float, optional
+        The b parameter for BM25 (default is 0.75).
+    alpha : float, optional
+        The alpha parameter for BM25 (default is -5).
+    progress : gr.Progress, optional
+        Progress tracker for the function (default is gr.Progress(track_tqdm=True)).
+
+    Returns
+    -------
+    tuple
+        A tuple containing a message, the search results file name (if any), the BM25 object, and the prepared search text list.
+    """
 
     progress(0, desc = "Conducting keyword search")
+
+    print("in_join_file at start of bm25_search:", in_join_file)
+
+    if not bm25:
+        print("Preparing BM25 search corpus")
+
+        bm25 = BM25(prepared_search_text_list_state, k1=k1, b=b, alpha=alpha)
+
+    # print("bm25:", bm25)
 
     # Prepare query
     if (clean == "Yes") | (text_column.endswith("_cleaned")):
@@ -484,8 +616,6 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
     else:
         token_query = convert_bm25_query_to_tokens(free_text_query, clean="No")
 
-    #print(token_query)
-
     # Perform search
     print("Searching")
 
@@ -504,7 +634,6 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
 
     # Join scores onto searched data
     results_df_out = results_df[['index', 'search_text', 'search_score_abs']].merge(searched_data,left_on="index", right_index=True, how="left", suffixes = ("", "_y")).drop("index_y", axis=1, errors="ignore")
-
 
 
     # Join on data from duplicate case notes
@@ -516,33 +645,27 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
         print("Clean is yes")
         orig_text_column = text_column.replace("_cleaned", "")
 
-        #print(orig_text_column)
-        #print(original_data.columns)
-
        original_data["original_note_id"] = original_data["copy_of_case_note_id"]
        original_data["original_note_id"] = original_data["original_note_id"].combine_first(original_data["note_id"])
 
        results_df_out = results_df_out.merge(original_data[["original_note_id", "note_id", "copy_of_case_note_id", "person_id"]],left_on="note_id", right_on="original_note_id", how="left", suffixes=("_primary", "")) # .drop(orig_text_column, axis = 1)
        results_df_out.loc[~results_df_out["copy_of_case_note_id"].isnull(), "search_text"] = ""
        results_df_out.loc[~results_df_out["copy_of_case_note_id"].isnull(), text_column] = ""
-
-        #results_df_out = pd.concat([results_df_out, original_data[~original_data["copy_of_case_note_id"].isna()][["copy_of_case_note_id", "person_id"]]])
-        # Replace NaN with an empty string
-        # results_df_out.fillna('', inplace=True)
-
 
+    print("in_join_file:", in_join_file)
+
     # Join on additional files
     if not in_join_file.empty:
         progress(0.5, desc = "Joining on additional data file")
-        join_df = in_join_file
-
+        #join_df = in_join_file
+        # Prepare join columns as string and remove .0 at end of stringified numbers
+        in_join_file[in_join_column] = in_join_file[in_join_column].astype(str).str.replace("\.0$","", regex=True)
         results_df_out[search_df_join_column] = results_df_out[search_df_join_column].astype(str).str.replace("\.0$","", regex=True)
 
         # Duplicates dropped so as not to expand out dataframe
-
+        in_join_file = in_join_file.drop_duplicates(in_join_column)
 
-        results_df_out = results_df_out.merge(
+        results_df_out = results_df_out.merge(in_join_file,left_on=search_df_join_column, right_on=in_join_column, how="left", suffixes=('','_y'))#.drop(in_join_column, axis=1)
 
     # Reorder results by score, and whether there is text
     results_df_out = results_df_out.sort_values(['search_score_abs', "search_text"], ascending=False)
@@ -559,7 +682,7 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
     # Highlight found text and save to file
     results_df_out_wb = create_highlighted_excel_wb(results_df_out, free_text_query, "search_text")
     results_df_out_wb.save(results_df_name)
 
     results_first_text = results_df_out[text_column].iloc[0]
 
     print("Returning results")
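prepare_bm25 now builds the index with BM25(prepared_search_text_list, k1=k1, b=b, alpha=alpha), exposing the k1 and b parameters described in the UI info text. For reference, here is a minimal, independent sketch of standard Okapi BM25 scoring over tokenised documents, such as the per-row token lists produced above by spaCy's tokenizer.pipe. This is illustrative only, not the repository's BM25 class, and its alpha filter is not reproduced here.

import math
from collections import Counter

def bm25_scores(query_tokens, corpus, k1=1.5, b=0.75):
    """Score every tokenised document in `corpus` against `query_tokens`."""
    N = len(corpus)
    avgdl = sum(len(doc) for doc in corpus) / N
    # Document frequency per term
    df = Counter()
    for doc in corpus:
        df.update(set(doc))
    scores = []
    for doc in corpus:
        tf = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if term not in tf:
                continue
            idf = math.log(1 + (N - df[term] + 0.5) / (df[term] + 0.5))
            # k1 caps term-frequency saturation; b scales document-length normalisation
            denom = tf[term] + k1 * (1 - b + b * len(doc) / avgdl)
            score += idf * tf[term] * (k1 + 1) / denom
        scores.append(score)
    return scores

corpus = [["the", "cat", "sat"], ["the", "dog", "sat", "down"], ["cat", "cat", "cat"]]
print(bm25_scores(["cat"], corpus, k1=1.5, b=0.75))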
search_funcs/helper_functions.py
CHANGED
@@ -9,6 +9,8 @@ import gzip
 import pickle
 import numpy as np
 
+from typing import List
+
 # Openpyxl functions for output
 from openpyxl import Workbook
 from openpyxl.cell.text import InlineFont
@@ -175,15 +177,15 @@ def read_file(filename):
 
     return file
 
-def initial_data_load(in_file):
+def initial_data_load(in_file:List[str]):
     '''
-    When file is loaded, update the column dropdown choices
+    When file is loaded, update the column dropdown choices and relevant state variables
     '''
     new_choices = []
     concat_choices = []
     index_load = None
     embed_load = np.array([])
-    tokenised_load =[]
+    tokenised_load = []
     out_message = ""
     current_source = ""
     df = pd.DataFrame()
@@ -257,7 +259,7 @@
 
     return gr.Dropdown(choices=concat_choices), gr.Dropdown(choices=concat_choices), df, df, index_load, embed_load, tokenised_load, out_message, current_source
 
-def put_columns_in_join_df(in_file):
+def put_columns_in_join_df(in_file:str):
     '''
     When file is loaded, update the column dropdown choices
     '''
@@ -354,7 +356,20 @@ def highlight_found_text(search_text: str, full_text: str) -> str:
 
     return "".join(pos_tokens), combined_positions
 
-def create_rich_text_cell_from_positions(full_text, combined_positions):
+def create_rich_text_cell_from_positions(full_text: str, combined_positions: list[tuple[int, int]]) -> CellRichText:
+    """
+    Create a rich text cell with highlighted positions.
+
+    This function takes the full text and a list of combined positions, and creates a rich text cell
+    with the specified positions highlighted in red.
+
+    Parameters:
+    full_text (str): The full text to be processed.
+    combined_positions (list[tuple[int, int]]): A list of tuples representing the start and end positions to be highlighted.
+
+    Returns:
+    CellRichText: The created rich text cell with highlighted positions.
+    """
     # Construct pos_tokens
     red = InlineFont(color='00FF0000')
     rich_text_cell = CellRichText()
@@ -369,7 +384,21 @@ def create_rich_text_cell_from_positions(full_text, combined_positions):
 
     return rich_text_cell
 
-def create_highlighted_excel_wb(df, search_text, column_to_highlight):
+def create_highlighted_excel_wb(df: pd.DataFrame, search_text: str, column_to_highlight: str) -> Workbook:
+    """
+    Create a new Excel workbook with highlighted search text.
+
+    This function takes a DataFrame, a search text, and a column name to highlight. It creates a new Excel workbook,
+    highlights the occurrences of the search text in the specified column, and returns the workbook.
+
+    Parameters:
+    df (pd.DataFrame): The DataFrame containing the data to be written to the Excel workbook.
+    search_text (str): The text to search for and highlight in the specified column.
+    column_to_highlight (str): The name of the column in which to highlight the search text.
+
+    Returns:
+    Workbook: The created Excel workbook with highlighted search text.
+    """
 
     # Create a new Excel workbook
     wb = Workbook()
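create_rich_text_cell_from_positions and create_highlighted_excel_wb rely on openpyxl's rich-text support (the InlineFont import shown above). A minimal sketch of the underlying idea, assuming openpyxl >= 3.1; the run-splitting logic here is illustrative, not copied from the repository:

from openpyxl import Workbook
from openpyxl.cell.text import InlineFont
from openpyxl.cell.rich_text import CellRichText, TextBlock

full_text = "the cat sat on the mat"
positions = [(4, 7)]  # character span of the match, e.g. "cat"

red = InlineFont(color='00FF0000')
cell = CellRichText()
cursor = 0
for start, end in positions:
    if start > cursor:
        cell.append(full_text[cursor:start])           # plain run before the match
    cell.append(TextBlock(red, full_text[start:end]))  # highlighted run in red
    cursor = end
if cursor < len(full_text):
    cell.append(full_text[cursor:])                    # trailing plain run

wb = Workbook()
wb.active["A1"] = cell
wb.save("highlighted.xlsx")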
search_funcs/semantic_functions.py
CHANGED
|
@@ -5,11 +5,10 @@ from typing import Type
|
|
| 5 |
import gradio as gr
|
| 6 |
import numpy as np
|
| 7 |
from datetime import datetime
|
| 8 |
-
|
| 9 |
-
from
|
| 10 |
-
#import torch
|
| 11 |
-
from torch import cuda, backends#, tensor, mm, utils
|
| 12 |
from sentence_transformers import SentenceTransformer
|
|
|
|
| 13 |
|
| 14 |
today_rev = datetime.now().strftime("%Y%m%d")
|
| 15 |
|
|
@@ -25,22 +24,6 @@ else:
|
|
| 25 |
|
| 26 |
print("Device used is: ", torch_device)
|
| 27 |
|
| 28 |
-
from search_funcs.helper_functions import create_highlighted_excel_wb, ensure_output_folder_exists, output_folder
|
| 29 |
-
|
| 30 |
-
PandasDataFrame = Type[pd.DataFrame]
|
| 31 |
-
|
| 32 |
-
# Load embeddings - Jina - deprecated
|
| 33 |
-
# Pinning a Jina revision for security purposes: https://www.baseten.co/blog/pinning-ml-model-revisions-for-compatibility-and-security/
|
| 34 |
-
# Save Jina model locally as described here: https://huggingface.co/jinaai/jina-embeddings-v2-base-en/discussions/29
|
| 35 |
-
# embeddings_name = "jinaai/jina-embeddings-v2-small-en"
|
| 36 |
-
# local_embeddings_location = "model/jina/"
|
| 37 |
-
# revision_choice = "b811f03af3d4d7ea72a7c25c802b21fc675a5d99"
|
| 38 |
-
|
| 39 |
-
# try:
|
| 40 |
-
# embeddings_model = AutoModel.from_pretrained(local_embeddings_location, revision = revision_choice, trust_remote_code=True,local_files_only=True, device_map="auto")
|
| 41 |
-
# except:
|
| 42 |
-
# embeddings_model = AutoModel.from_pretrained(embeddings_name, revision = revision_choice, trust_remote_code=True, device_map="auto")
|
| 43 |
-
|
| 44 |
# Load embeddings
|
| 45 |
embeddings_name = "BAAI/bge-small-en-v1.5"
|
| 46 |
|
|
@@ -65,32 +48,53 @@ else:
|
|
| 65 |
embeddings_model = SentenceTransformer(embeddings_name)
|
| 66 |
print("Could not find local model installation. Downloading from Huggingface")
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
ensure_output_folder_exists(output_folder)
|
| 74 |
|
| 75 |
if not in_file:
|
| 76 |
out_message = "No input file found. Please load in at least one file."
|
| 77 |
print(out_message)
|
| 78 |
-
return out_message, None, None, output_file_state
|
| 79 |
-
|
| 80 |
|
| 81 |
progress(0.6, desc = "Loading/creating embeddings")
|
| 82 |
|
| 83 |
print(f"> Total split documents: {len(docs_out)}")
|
| 84 |
|
| 85 |
-
#print(docs_out)
|
| 86 |
-
|
| 87 |
page_contents = [doc.page_content for doc in docs_out]
|
| 88 |
|
| 89 |
## Load in pre-embedded file if exists
|
| 90 |
file_list = [string.name for string in in_file]
|
| 91 |
|
| 92 |
-
#print(file_list)
|
| 93 |
-
|
| 94 |
embeddings_file_names = [string for string in file_list if "embedding" in string.lower()]
|
| 95 |
data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower()]# and "gz" not in string.lower()]
|
| 96 |
data_file_name = data_file_names[0]
|
|
@@ -98,22 +102,12 @@ def docs_to_bge_embed_np_array(docs_out, in_file, embeddings_state, output_file_
|
|
| 98 |
|
| 99 |
out_message = "Document processing complete. Ready to search."
|
| 100 |
|
| 101 |
-
# print("embeddings loaded: ", embeddings_out)
|
| 102 |
|
| 103 |
if embeddings_state.size == 0:
|
| 104 |
tic = time.perf_counter()
|
| 105 |
print("Starting to embed documents.")
|
| 106 |
-
#embeddings_list = []
|
| 107 |
-
#for page in progress.tqdm(page_contents, desc = "Preparing search index", unit = "rows"):
|
| 108 |
-
# embeddings_list.append(embeddings.encode(sentences=page, max_length=1024).tolist())
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
#embeddings_out = calc_bge_norm_embeddings(page_contents, embeddings_model, tokenizer)
|
| 113 |
|
| 114 |
embeddings_out = embeddings_model.encode(sentences=page_contents, show_progress_bar = True, batch_size = 32, normalize_embeddings=True) # For BGE
|
| 115 |
-
#embeddings_list = embeddings.encode(sentences=page_contents, normalize_embeddings=True).tolist() # For BGE embeddings
|
| 116 |
-
#embeddings_list = embeddings.encode(sentences=page_contents).tolist() # For minilm
|
| 117 |
|
| 118 |
toc = time.perf_counter()
|
| 119 |
time_out = f"The embedding took {toc - tic:0.1f} seconds"
|
|
@@ -147,31 +141,43 @@ def docs_to_bge_embed_np_array(docs_out, in_file, embeddings_state, output_file_

    return out_message, embeddings_out, output_file_state, output_file_state

-def process_data_from_scores_df(
+def process_data_from_scores_df(
+    df_docs: pd.DataFrame,
+    in_join_file: pd.DataFrame,
+    vec_score_cut_off: float,
+    in_join_column: str,
+    search_df_join_column: str,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> pd.DataFrame:
+    """
+    Process the data from the scores DataFrame by filtering based on score cutoff and document length,
+    and optionally joining with an additional file.
+
+    Parameters
+    ----------
+    df_docs : pd.DataFrame
+        DataFrame containing document scores and metadata.
+    in_join_file : pd.DataFrame
+        DataFrame to join with the results based on specified columns.
+    vec_score_cut_off : float
+        Cutoff value for the vector similarity score.
+    in_join_column : str
+        Column name in the join file to join on.
+    search_df_join_column : str
+        Column name in the search DataFrame to join on.
+    progress : gr.Progress, optional
+        Progress tracker for the function (default is gr.Progress(track_tqdm=True)).
+
+    Returns
+    -------
+    pd.DataFrame
+        Processed DataFrame with filtered and joined data.
+    """
+
    docs_scores = df_docs["distances"] #.astype(float)

    # Only keep sources that are sufficiently relevant (i.e. similarity search score below threshold below)
    score_more_limit = df_docs.loc[docs_scores > vec_score_cut_off, :]
-    #docs_keep = create_docs_keep_from_df(score_more_limit) #list(compress(docs, score_more_limit))
-
-    #print(docs_keep)

    if score_more_limit.empty:
        return pd.DataFrame()
@@ -179,26 +185,17 @@ def process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_c

    # Only keep sources that are at least 100 characters long
    docs_len = score_more_limit["documents"].str.len() >= 100

-    #print(docs_len)
-
    length_more_limit = score_more_limit.loc[docs_len == True, :] #pd.Series(docs_len) >= 100
-    #docs_keep = create_docs_keep_from_df(length_more_limit) #list(compress(docs_keep, length_more_limit))
-
-    #print(length_more_limit)

    if length_more_limit.empty:
        return pd.DataFrame()

    length_more_limit['ids'] = length_more_limit['ids'].astype(int)

-    #length_more_limit.to_csv("length_more_limit.csv", index = None)

    # Explode the 'metadatas' dictionary into separate columns
    df_metadata_expanded = length_more_limit['metadatas'].apply(pd.Series)

-    #print(length_more_limit)
-    #print(df_metadata_expanded)
-
    # Concatenate the original DataFrame with the expanded metadata DataFrame
    results_df_out = pd.concat([length_more_limit.drop('metadatas', axis=1), df_metadata_expanded], axis=1)
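The apply(pd.Series) step above turns a column of metadata dicts into one column per key. A toy example, with made-up column names for illustration:

import pandas as pd

df = pd.DataFrame({
    "documents": ["some text", "other text"],
    "metadatas": [{"row": 0, "source": "a.csv"}, {"row": 1, "source": "b.csv"}],
})

df_metadata_expanded = df["metadatas"].apply(pd.Series)
results_df_out = pd.concat([df.drop("metadatas", axis=1), df_metadata_expanded], axis=1)
print(results_df_out.columns.tolist())  # ['documents', 'row', 'source']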
@@ -208,9 +205,6 @@ def process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_c
    results_df_out['distances'] = round(results_df_out['distances'].astype(float), 3)


-    # Join back to original df
-    # results_df_out = orig_df.merge(length_more_limit[['ids', 'distances']], left_index = True, right_on = "ids", how="inner").sort_values("distances")
-
    # Join on additional files
    if not in_join_file.empty:
        progress(0.5, desc = "Joining on additional data file")
@@ -227,68 +221,73 @@ def process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_c

    return results_df_out

-def bge_simple_retrieval(
+def bge_semantic_search(
+    query_str: str,
+    embeddings: np.ndarray,
+    documents: list,
+    k_val: int,
+    vec_score_cut_off: float,
+    in_join_file: pd.DataFrame,
+    in_join_column: str = None,
+    search_df_join_column: str = None,
+    device: str = torch_device,
+    embeddings_model: SentenceTransformer = embeddings_model,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> pd.DataFrame:
+    """
+    Perform a semantic search using the BGE model.
+
+    Parameters:
+    - query_str (str): The query string to search for.
+    - embeddings (np.ndarray): The embeddings to search within.
+    - documents (list): The list of documents to search.
+    - k_val (int): The number of top results to return.
+    - vec_score_cut_off (float): The score cutoff for filtering results.
+    - in_join_file (pd.DataFrame): The DataFrame to join with the search results.
+    - in_join_column (str, optional): The column name in the join DataFrame to join on. Default is None.
+    - search_df_join_column (str, optional): The column name in the search DataFrame to join on. Default is None.
+    - device (str, optional): The device to run the model on. Default is torch_device.
+    - embeddings_model (SentenceTransformer, optional): The embeddings model to use. Default is embeddings_model.
+    - progress (gr.Progress, optional): Progress tracker for the function. Default is gr.Progress(track_tqdm=True).
+
+    Returns:
+    - pd.DataFrame: The DataFrame containing the search results.
+    """

-    # print("vectorstore loaded: ", vectorstore)
    progress(0, desc = "Conducting semantic search")

    ensure_output_folder_exists(output_folder)

    print("Searching")

-    # Convert it to a PyTorch tensor and transfer to GPU
-    #vectorstore_tensor = tensor(vectorstore).to(device)
-
    # Load the sentence transformer model and move it to GPU
+    embeddings_model = embeddings_model.to(device)

    # Encode the query using the sentence transformer and convert to a PyTorch tensor
-    # query = calc_bge_norm_embeddings(query_str, embeddings_model=embeddings_model, tokenizer=tokenizer)
-
-    #query_tensor = tensor(query).to(device)
-
-    # if query_tensor.dim() == 1:
-    #     query_tensor = query_tensor.unsqueeze(0) # Reshape to 2D with one row
+    query = embeddings_model.encode(query_str, normalize_embeddings=True)

    # Sentence transformers method, not used:
-    #cosine_similarities = util.cos_sim(query_tensor, vectorstore_tensor)[0]
-    #top_results = torch.topk(cos_scores, k=top_k)
-
-    # Normalize the query tensor and vectorstore tensor
-    #query_norm = query_tensor / query_tensor.norm(dim=1, keepdim=True)
-    #vectorstore_norm = vectorstore_tensor / vectorstore_tensor.norm(dim=1, keepdim=True)
-
-    # Calculate cosine similarities (batch processing)
-    #cosine_similarities = mm(query_norm, vectorstore_norm.T)
-    #cosine_similarities = mm(query_tensor, vectorstore_tensor.T)
+    cosine_similarities = query @ embeddings.T

    # Flatten the tensor to a 1D array
    cosine_similarities = cosine_similarities.flatten()

-    # Convert to a NumPy array if it's still a PyTorch tensor
-    #cosine_similarities = cosine_similarities.cpu().numpy()
-
    # Create a Pandas Series
    cosine_similarities_series = pd.Series(cosine_similarities)

+    # Pull out relevent info from documents
+    page_contents = [doc.page_content for doc in documents]
+    page_meta = [doc.metadata for doc in documents]
    ids_range = range(0,len(page_contents))
    ids = [str(element) for element in ids_range]

+    df_documents = pd.DataFrame(data={"ids": ids,
                                "documents": page_contents,
                                "metadatas":page_meta,
                                "distances":cosine_similarities_series}).sort_values("distances", ascending=False).iloc[0:k_val,:]


+    results_df_out = process_data_from_scores_df(df_documents, in_join_file, vec_score_cut_off, in_join_column, search_df_join_column)

    print("Search complete")
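Condensed, the scoring path of the new bge_semantic_search is: encode the query as a normalised vector, dot it against the (n_docs, dim) embeddings matrix, and keep the k_val best rows. A self-contained sketch of just that core, using random unit vectors in place of real embeddings:

import numpy as np
import pandas as pd

def top_k_similarities(query_vec: np.ndarray, embeddings: np.ndarray, k_val: int) -> pd.Series:
    # With normalised vectors the dot products are cosine similarities
    cosine_similarities = (query_vec @ embeddings.T).flatten()
    return pd.Series(cosine_similarities).sort_values(ascending=False).iloc[0:k_val]

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(100, 384))
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
print(top_k_similarities(embeddings[0], embeddings, k_val=5))  # top hit is the query itself, score ~1.0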
@@ -312,291 +311,4 @@ def bge_simple_retrieval(query_str:str, vectorstore, docs, orig_df_col:str, k_va

    print("Returning results")

-    return results_first_text, results_df_name
+    return results_first_text, results_df_name
-
-
-def docs_to_jina_embed_np_array_deprecated(docs_out, in_file, embeddings_state, return_intermediate_files = "No", embeddings_super_compress = "No", embeddings = embeddings_model, progress=gr.Progress(track_tqdm=True)):
-    '''
-    Takes a Langchain document class and saves it into a Chroma sqlite file.
-    '''
-    if not in_file:
-        out_message = "No input file found. Please load in at least one file."
-        print(out_message)
-        return out_message, None, None
-
-    progress(0.6, desc = "Loading/creating embeddings")
-
-    print(f"> Total split documents: {len(docs_out)}")
-
-    #print(docs_out)
-
-    page_contents = [doc.page_content for doc in docs_out]
-
-    ## Load in pre-embedded file if exists
-    file_list = [string.name for string in in_file]
-
-    #print(file_list)
-
-    embeddings_file_names = [string for string in file_list if "embedding" in string.lower()]
-    data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower()]# and "gz" not in string.lower()]
-    data_file_name = data_file_names[0]
-    data_file_name_no_ext = get_file_path_end(data_file_name)
-
-    out_message = "Document processing complete. Ready to search."
-
-    # print("embeddings loaded: ", embeddings_out)
-
-    if embeddings_state.size == 0:
-        tic = time.perf_counter()
-        print("Starting to embed documents.")
-        #embeddings_list = []
-        #for page in progress.tqdm(page_contents, desc = "Preparing search index", unit = "rows"):
-        #    embeddings_list.append(embeddings.encode(sentences=page, max_length=1024).tolist())
-
-        embeddings_out = embeddings.encode(sentences=page_contents, max_length=1024, show_progress_bar = True, batch_size = 32) # For Jina embeddings
-        #embeddings_list = embeddings.encode(sentences=page_contents, normalize_embeddings=True).tolist() # For BGE embeddings
-        #embeddings_list = embeddings.encode(sentences=page_contents).tolist() # For minilm
-
-        toc = time.perf_counter()
-        time_out = f"The embedding took {toc - tic:0.1f} seconds"
-        print(time_out)
-
-        # If you want to save your files for next time
-        if return_intermediate_files == "Yes":
-            progress(0.9, desc = "Saving embeddings to file")
-            if embeddings_super_compress == "No":
-                semantic_search_file_name = data_file_name_no_ext + '_' + 'embeddings.npz'
-                np.savez_compressed(semantic_search_file_name, embeddings_out)
-            else:
-                semantic_search_file_name = data_file_name_no_ext + '_' + 'embedding_compress.npz'
-                embeddings_out_round = np.round(embeddings_out, 3)
-                embeddings_out_round *= 100 # Rounding not currently used
-                np.savez_compressed(semantic_search_file_name, embeddings_out_round)
-
-            return out_message, embeddings_out, semantic_search_file_name
-
-        return out_message, embeddings_out, None
-    else:
-        # Just return existing embeddings if already exist
-        embeddings_out = embeddings_state
-
-    print(out_message)
-
-    return out_message, embeddings_out, None#, None
-
-def jina_simple_retrieval_deprecated(query_str:str, vectorstore, docs, orig_df_col:str, k_val:int, out_passages:int,
-                           vec_score_cut_off:float, vec_weight:float, in_join_file, in_join_column = None, search_df_join_column = None, device = torch_device, embeddings = embeddings_model, progress=gr.Progress(track_tqdm=True)): # ,vectorstore, embeddings
-
-    # print("vectorstore loaded: ", vectorstore)
-    progress(0, desc = "Conducting semantic search")
-
-    print("Searching")
-
-    # Convert it to a PyTorch tensor and transfer to GPU
-    vectorstore_tensor = tensor(vectorstore).to(device)
-
-    # Load the sentence transformer model and move it to GPU
-    embeddings = embeddings.to(device)
-
-    # Encode the query using the sentence transformer and convert to a PyTorch tensor
-    query = embeddings.encode(query_str)
-    query_tensor = tensor(query).to(device)
-
-    if query_tensor.dim() == 1:
-        query_tensor = query_tensor.unsqueeze(0) # Reshape to 2D with one row
-
-    # Normalize the query tensor and vectorstore tensor
-    query_norm = query_tensor / query_tensor.norm(dim=1, keepdim=True)
-    vectorstore_norm = vectorstore_tensor / vectorstore_tensor.norm(dim=1, keepdim=True)
-
-    # Calculate cosine similarities (batch processing)
-    cosine_similarities = mm(query_norm, vectorstore_norm.T)
-
-    # Flatten the tensor to a 1D array
-    cosine_similarities = cosine_similarities.flatten()
-
-    # Convert to a NumPy array if it's still a PyTorch tensor
-    cosine_similarities = cosine_similarities.cpu().numpy()
-
-    # Create a Pandas Series
-    cosine_similarities_series = pd.Series(cosine_similarities)
-
-    # Pull out relevent info from docs
-    page_contents = [doc.page_content for doc in docs]
-    page_meta = [doc.metadata for doc in docs]
-    ids_range = range(0,len(page_contents))
-    ids = [str(element) for element in ids_range]
-
-    df_docs = pd.DataFrame(data={"ids": ids,
-                                 "documents": page_contents,
-                                 "metadatas":page_meta,
-                                 "distances":cosine_similarities_series}).sort_values("distances", ascending=False).iloc[0:k_val,:]
-
-    results_df_out = process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_cut_off, vec_weight, orig_df_col, in_join_column, search_df_join_column)
-
-    print("Search complete")
-
-    # If nothing found, return error message
-    if results_df_out.empty:
-        return 'No result found!', None
-
-    query_str_file = query_str.replace(" ", "_")
-
-    results_df_name = "semantic_search_result_" + today_rev + "_" + query_str_file + ".xlsx"
-
-    print("Saving search output to file")
-    progress(0.7, desc = "Saving search output to file")
-
-    results_df_out.to_excel(results_df_name, index= None)
-    results_first_text = results_df_out.iloc[0, 1]
-
-    print("Returning results")
-
-    return results_first_text, results_df_name
-
-# Deprecated Chroma functions - kept just in case needed in future.
-# Chroma support is currently deprecated
-# Import Chroma and instantiate a client. The default Chroma client is ephemeral, meaning it will not save to disk.
-#import chromadb
-#from chromadb.config import Settings
-#from typing_extensions import Protocol
-#from chromadb import Documents, EmbeddingFunction, Embeddings
-
-# Remove Chroma database file. If it exists as it can cause issues
-#chromadb_file = "chroma.sqlite3"
-
-#if os.path.isfile(chromadb_file):
-#    os.remove(chromadb_file)
-
-def docs_to_chroma_save_deprecated(docs_out, embeddings = embeddings_model, progress=gr.Progress()):
-    '''
-    Takes a Langchain document class and saves it into a Chroma sqlite file. Not currently used.
-    '''
-
-    print(f"> Total split documents: {len(docs_out)}")
-
-    #print(docs_out)
-
-    page_contents = [doc.page_content for doc in docs_out]
-    page_meta = [doc.metadata for doc in docs_out]
-    ids_range = range(0,len(page_contents))
-    ids = [str(element) for element in ids_range]
-
-    tic = time.perf_counter()
-    #embeddings_list = []
-    #for page in progress.tqdm(page_contents, desc = "Preparing search index", unit = "rows"):
-    #    embeddings_list.append(embeddings.encode(sentences=page, max_length=1024).tolist())
-
-    embeddings_list = embeddings.encode(sentences=page_contents, max_length=256, show_progress_bar = True, batch_size = 32).tolist() # For Jina embeddings
-    #embeddings_list = embeddings.encode(sentences=page_contents, normalize_embeddings=True).tolist() # For BGE embeddings
-    #embeddings_list = embeddings.encode(sentences=page_contents).tolist() # For minilm
-
-    toc = time.perf_counter()
-    time_out = f"The embedding took {toc - tic:0.1f} seconds"
-
-    #pd.Series(embeddings_list).to_csv("embeddings_out.csv")
-
-    # Jina tiny
-    # This takes about 300 seconds for 240,000 records = 800 / second, 1024 max length
-    # For 50k records:
-    # 61 seconds at 1024 max length
-    # 55 seconds at 512 max length
-    # 43 seconds at 256 max length
-    # 31 seconds at 128 max length
-
-    # The embedding took 1372.5 seconds at 256 max length for 655,020 case notes
-
-    # BGE small
-    # 96 seconds for 50k records at 512 length
-
-    # all-MiniLM-L6-v2
-    # 42.5 seconds at (256?) max length
-
-    # paraphrase-MiniLM-L3-v2
-    # 22 seconds for 128 max length
-
-    print(time_out)
-
-    chroma_tic = time.perf_counter()
-
-    # Create a new Chroma collection to store the documents and metadata. We don't need to specify an embedding fuction, and the default will be used.
-    client = chromadb.PersistentClient(path="./last_year", settings=Settings(
-            anonymized_telemetry=False))
-
-    try:
-        print("Deleting existing collection.")
-        #collection = client.get_collection(name="my_collection")
-        client.delete_collection(name="my_collection")
-        print("Creating new collection.")
-        collection = client.create_collection(name="my_collection")
-    except:
-        print("Creating new collection.")
-        collection = client.create_collection(name="my_collection")
-
-    # Match batch size is about 40,000, so add that amount in a loop
-    def create_batch_ranges(in_list, batch_size=40000):
-        total_rows = len(in_list)
-        ranges = []
-
-        for start in range(0, total_rows, batch_size):
-            end = min(start + batch_size, total_rows)
-            ranges.append(range(start, end))
-
-        return ranges
-
-    batch_ranges = create_batch_ranges(embeddings_list)
-    print(batch_ranges)
-
-    for row_range in progress.tqdm(batch_ranges, desc = "Creating vector database", unit = "batches of 40,000 rows"):
-
-        collection.add(
-            documents = page_contents[row_range[0]:row_range[-1]],
-            embeddings = embeddings_list[row_range[0]:row_range[-1]],
-            metadatas = page_meta[row_range[0]:row_range[-1]],
-            ids = ids[row_range[0]:row_range[-1]])
-        #print("Here")
-
-    # print(collection.count())
-
-    #chatf.vectorstore = vectorstore_func
-
-    chroma_toc = time.perf_counter()
-
-    chroma_time_out = f"Loading to Chroma db took {chroma_toc - chroma_tic:0.1f} seconds"
-    print(chroma_time_out)
-
-    out_message = "Document processing complete"
-
-    return out_message, collection
-
-def chroma_retrieval_deprecated(query_str:str, vectorstore, docs, orig_df_col:str, k_val:int, out_passages:int,
-                    vec_score_cut_off:float, vec_weight:float, in_join_file = None, in_join_column = None, search_df_join_column = None, embeddings = embeddings_model): # ,vectorstore, embeddings
-
-    query = embeddings.encode(query_str).tolist()
-
-    docs = vectorstore.query(
-        query_embeddings=query,
-        n_results= k_val # No practical limit on number of responses returned
-        #where={"metadata_field": "is_equal_to_this"},
-        #where_document={"$contains":"search_string"}
-    )
-
-    df_docs = pd.DataFrame(data={'ids': docs['ids'][0],
-                                 'documents': docs['documents'][0],
-                                 'metadatas':docs['metadatas'][0],
-                                 'distances':docs['distances'][0]#,
-                                 #'embeddings': docs['embeddings']
-                                 })
-
-    results_df_out = process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_cut_off, vec_weight, orig_df_col, in_join_column, search_df_join_column)
-
-    results_df_name = output_folder + "semantic_search_result.csv"
-    results_df_out.to_csv(results_df_name, index= None)
-    results_first_text = results_df_out[orig_df_col].iloc[0]
-
-    return results_first_text, results_df_name
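The deprecated Jina function above also illustrates the intermediate-file mechanism the live code keeps: embeddings are written with np.savez_compressed, and the "super compress" branch additionally rounds to 3 d.p. and scales by 100 (the source notes the rounding is not currently used; reloaded values would presumably need dividing by 100 again). A sketch of the plain round trip, with a made-up file name:

import numpy as np

embeddings_out = np.random.rand(1000, 384).astype(np.float32)

np.savez_compressed("example_embeddings.npz", embeddings_out)
restored = np.load("example_embeddings.npz")["arr_0"]
assert np.allclose(restored, embeddings_out)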
 import gradio as gr
 import numpy as np
 from datetime import datetime
+from search_funcs.helper_functions import get_file_path_end, create_highlighted_excel_wb, ensure_output_folder_exists, output_folder
+from torch import cuda, backends
 from sentence_transformers import SentenceTransformer
+PandasDataFrame = Type[pd.DataFrame]

 today_rev = datetime.now().strftime("%Y%m%d")

 print("Device used is: ", torch_device)
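The new `from torch import cuda, backends` import supports the torch_device detection whose result is printed above. The exact selection logic lives earlier in semantic_functions.py and is not shown in this diff; one plausible sketch:

from torch import cuda, backends

if cuda.is_available():
    torch_device = "cuda"
elif backends.mps.is_available():
    torch_device = "mps"  # Apple Silicon
else:
    torch_device = "cpu"

print("Device used is: ", torch_device)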
search_funcs/spacy_search_funcs.py
CHANGED
@@ -27,9 +27,14 @@ except:
    nlp = spacy.load("en_core_web_sm")
    print("Successfully imported spaCy model")

-def spacy_fuzzy_search(string_query:str,
+def spacy_fuzzy_search(string_query:str, tokenised_data: List[List[str]], original_data: PandasDataFrame, text_column:str, in_join_file: PandasDataFrame, search_df_join_column:str, in_join_column:str, no_spelling_mistakes:int = 1, progress=gr.Progress(track_tqdm=True)):
    ''' Conduct fuzzy match on a list of data.'''

+    #print("df_list:", df_list)
+
+    # Convert tokenised data back into a list of strings
+    df_list = list(map(" ".join, tokenised_data))
+
    if len(df_list) > 10000:
        out_message = "Your data has more than 10,000 rows and will take more than three minutes to do a fuzzy search. Please try keyword or semantic search for data of this size."
        return out_message, None
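The new detokenisation line rebuilds plain strings from the tokenised rows before the fuzzy match runs. For example:

tokenised_data = [["the", "quick", "fox"], ["lazy", "dog"]]
df_list = list(map(" ".join, tokenised_data))
print(df_list)  # ['the quick fox', 'lazy dog']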