# MIT License
#
# Copyright (c) 2023 Victor Calderon
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import logging
from typing import Dict, Optional

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sentence_transformers import SentenceTransformer

from src.utils import default_variables as dv

__author__ = ["Victor Calderon"]
__copyright__ = ["Copyright 2023 Victor Calderon"]
__all__ = ["SemanticSearchEngine"]

# Module-level logger, configured for INFO-level output.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s]: %(message)s",
)
logger.setLevel(logging.INFO)
# --------------------------- CLASS DEFINITIONS -------------------------------
class SemanticSearchEngine(object):
    """
    Class object for running Semantic Search on the input dataset.

    The engine encodes a text corpus with a SentenceTransformer model,
    adds a FAISS index on top of the embeddings, and resolves free-text
    queries to the Top-N most similar corpus entries.
    """

    def __init__(self, **kwargs):
        """
        Class object for running Semantic Search on the input dataset.

        Keyword Parameters
        -------------------
        source_colname : str, optional
            Name of the dataset column that contains the text to embed.
            This variable is set to ``summary`` by default.

        embeddings_colname : str, optional
            Name of the dataset column that stores the text embeddings.
            This variable is set to ``dv.embeddings_colname`` by default.

        corpus_dataset_with_faiss_index : datasets.Dataset, optional
            Dataset that already contains text embeddings and a FAISS
            index (as produced by
            ``generate_corpus_index_and_embeddings``). This variable is
            set to ``None`` by default.
        """
        # --- Defining variables
        # Device to use, i.e. CPU or GPU
        self.device = self._get_device()
        # Embedder model to use
        self.model = "paraphrase-mpnet-base-v2"
        # Defining the embedder
        self.embedder = self._get_embedder()
        # Column containing the source text of the corpus
        self.source_colname = kwargs.get(
            "source_colname",
            "summary",
        )
        # Column under which the corpus embeddings are stored
        self.embeddings_colname = kwargs.get(
            "embeddings_colname",
            dv.embeddings_colname,
        )
        # Variables used for running semantic search; stays 'None' until
        # a pre-indexed corpus is supplied here or assigned later.
        self.corpus_dataset_with_faiss_index = kwargs.get(
            "corpus_dataset_with_faiss_index"
        )

    def _get_device(self) -> str:
        """
        Method for determining the device to use.

        Returns
        ----------
        device_type : str
            Type of device to use (e.g. 'cpu' or 'cuda').

            Options:
                - ``cpu`` : Uses a CPU.
                - ``cuda`` : Uses a GPU.
        """
        # Determining the type of device to use
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f">> Running on a '{device_type.upper()}' device")

        return device_type

    def _get_embedder(self):
        """
        Method for extracting the Embedder model.

        Returns
        ---------
        embedder : model
            Variable corresponding to the Embeddings model, moved onto
            the device selected by ``_get_device``.
        """
        embedder = SentenceTransformer(self.model)
        embedder.to(self.device)

        return embedder

    def generate_corpus_index_and_embeddings(
        self,
        corpus_dataset: Dataset,
    ) -> Dataset:
        """
        Method for generating the Text Embeddings and FAISS indices from
        the input dataset.

        Parameters
        ------------
        corpus_dataset : datasets.Dataset
            Dataset containing the text to use to create the text
            embeddings and FAISS indices.

        Returns
        ----------
        corpus_dataset_with_embeddings : datasets.Dataset
            Dataset containing the original data from ``corpus_dataset``
            plus the corresponding text embeddings of the
            ``source_colname`` column.
        """
        # Gradients are not needed for inference-only encoding.
        torch.set_grad_enabled(False)
        # --- Generate text embeddings for the source column
        corpus_dataset_with_embeddings = corpus_dataset.map(
            lambda corpus: {
                self.embeddings_colname: self.embedder.encode(
                    corpus[self.source_colname]
                )
            },
            batched=True,
            desc="Computing Semantic Search Embeddings",
        )
        # --- Adding FAISS index
        corpus_dataset_with_embeddings.add_faiss_index(
            column=self.embeddings_colname,
            faiss_verbose=True,
            # FAISS GPU devices are 0-indexed, so the first GPU is '0';
            # the previous value '1' would require a second GPU.
            device=None if self.device == "cpu" else 0,
        )

        return corpus_dataset_with_embeddings

    def run_semantic_search(
        self,
        query: str,
        top_n: int = 5,
    ) -> Dict:  # sourcery skip: extract-duplicate-method
        """
        Method for running a semantic search on a query after having
        created the corpus of the text embeddings.

        Parameters
        --------------
        query : str
            Text query to use for searching the database.

        top_n : int, optional
            Variable corresponding to the 'Top N' values to return based
            on the similarity score between the input query and the
            corpus. This variable is set to ``5`` by default.

        Returns
        ---------
        match_results : dict
            Dictionary containing the metadata of each of the articles
            that were in the Top-N in terms of being most similar to the
            input query ``query``.

        Raises
        ---------
        TypeError
            If ``query`` or ``top_n`` is not of the expected type.

        ValueError
            If ``top_n`` is not positive, or if the FAISS-indexed corpus
            has not been set.
        """
        # --- Checking input parameters
        # 'query' - Type
        query_type_arr = (str,)

        if not isinstance(query, query_type_arr):
            msg = ">> 'query' ({}) is not a valid input type ({})".format(
                type(query), query_type_arr
            )
            logger.error(msg)
            raise TypeError(msg)
        # 'top_n' - Type
        top_n_type_arr = (int,)

        if not isinstance(top_n, top_n_type_arr):
            msg = ">> 'top_n' ({}) is not a valid input type ({})".format(
                type(top_n), top_n_type_arr
            )
            logger.error(msg)
            raise TypeError(msg)
        # 'top_n' - Value
        if top_n <= 0:
            msg = f">> 'top_n' ({top_n}) must be larger than '0'!"
            logger.error(msg)
            raise ValueError(msg)
        # --- Checking that the encoder has been indexed correctly
        if self.corpus_dataset_with_faiss_index is None:
            msg = ">>> The FAISS index was not properly set!"
            logger.error(msg)
            raise ValueError(msg)
        # --- Encode the input query and extract the embedding
        query_embedding = self.embedder.encode(query)
        # --- Extracting the top-N results
        (
            scores,
            results,
        ) = self.corpus_dataset_with_faiss_index.get_nearest_examples(
            self.embeddings_colname,
            query_embedding,
            k=top_n,
        )
        # --- Sorting from highest to lowest
        # NOTE: We need to deconstruct the 'results' to be able to
        # organize the results.
        # NOTE(review): whether a larger score means "more similar"
        # depends on the FAISS metric backing the index (inner product
        # vs. L2 distance) — confirm the descending sort matches the
        # index that was built.
        parsed_results = pd.DataFrame.from_dict(
            data=results,
            orient="columns",
        )
        parsed_results.loc[:, "relevance"] = scores
        # Sorting in descending order
        parsed_results = parsed_results.sort_values(
            by=["relevance"],
            ascending=False,
        ).reset_index(drop=True)
        # Casting data type for the 'relevance'
        parsed_results.loc[:, "relevance"] = parsed_results["relevance"].apply(
            lambda x: str(np.round(x, 5))
        )
        # Only keeping certain columns
        # NOTE(review): assumes the corpus dataset provides these
        # metadata columns — verify against the dataset schema.
        columns_to_keep = ["_id", "title", "relevance", "content"]

        return parsed_results[columns_to_keep].to_dict(orient="index")