# Cicero_Synthesizer_Space/src/classes/semantic_search_engine.py
# MIT License
#
# Copyright (c) 2023 Victor Calderon
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import logging
from typing import Dict
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from src.utils import default_variables as dv

__author__ = ["Victor Calderon"]
__copyright__ = ["Copyright 2023 Victor Calderon"]
__all__ = ["SemanticSearchEngine"]

logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s]: %(message)s",
)
logger.setLevel(logging.INFO)
# --------------------------- CLASS DEFINITIONS -------------------------------
class SemanticSearchEngine:
    """
    Class object for running Semantic Search on the input dataset.
    """

    def __init__(self, **kwargs):
        """
        Initializes the Semantic Search engine.

        Accepted keywords: ``source_colname`` (default ``"summary"``),
        ``embeddings_colname`` (default ``dv.embeddings_colname``), and
        ``corpus_dataset_with_faiss_index`` (a corpus dataset with a FAISS
        index attached, i.e. the output of
        ``generate_corpus_index_and_embeddings``).
        """
# --- Defining variables
# Device to use, i.e. CPU or GPU
self.device = self._get_device()
# Embedder model to use
self.model = "paraphrase-mpnet-base-v2"
# Defining the embedder
self.embedder = self._get_embedder()
        # Column names of the source text and its embeddings
self.source_colname = kwargs.get(
"source_colname",
"summary",
)
self.embeddings_colname = kwargs.get(
"embeddings_colname",
dv.embeddings_colname,
)
# Variables used for running semantic search
self.corpus_dataset_with_faiss_index = kwargs.get(
"corpus_dataset_with_faiss_index"
)
def _get_device(self) -> str:
"""
Method for determining the device to use.
Returns
----------
device_type : str
Type of device to use (e.g. 'cpu' or 'cuda').
Options:
- ``cpu`` : Uses a CPU.
- ``cuda`` : Uses a GPU.
"""
# Determining the type of device to use
device_type = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f">> Running on a '{device_type.upper()}' device")
return device_type
def _get_embedder(self):
"""
Method for extracting the Embedder model.
        Returns
        ---------
        embedder : sentence_transformers.SentenceTransformer
            Embedding model, moved to the selected device.
"""
embedder = SentenceTransformer(self.model)
embedder.to(self.device)
return embedder
def generate_corpus_index_and_embeddings(
self,
corpus_dataset: Dataset,
) -> Dataset:
"""
Method for generating the Text Embeddings and FAISS indices from
the input dataset.
Parameters
------------
corpus_dataset : datasets.Dataset
Dataset containing the text to use to create the text
embeddings and FAISS indices.
Returns
----------
corpus_dataset_with_embeddings : datasets.Dataset
            Dataset containing the original data from ``corpus_dataset``
plus the corresponding text embeddings of the ``source_colname``
column.
"""
torch.set_grad_enabled(False)
# --- Generate text embeddings for the source column
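        # NOTE: With ``batched=True``, ``corpus[self.source_colname]`` is a
        # list of strings, so ``encode`` returns a 2-D array with one
        # embedding per input row.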
corpus_dataset_with_embeddings = corpus_dataset.map(
lambda corpus: {
self.embeddings_colname: self.embedder.encode(
corpus[self.source_colname]
)
},
batched=True,
desc="Computing Semantic Search Embeddings",
)
# --- Adding FAISS index
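        # NOTE: The index is registered under the embeddings column name,
        # which is the same index name that 'get_nearest_examples' uses in
        # 'run_semantic_search' below.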
corpus_dataset_with_embeddings.add_faiss_index(
column=self.embeddings_colname,
faiss_verbose=True,
            # Use the first GPU (index '0') when running on CUDA
            device=None if self.device == "cpu" else 0,
)
return corpus_dataset_with_embeddings
def run_semantic_search(
self,
query: str,
        top_n: int = 5,
) -> Dict: # sourcery skip: extract-duplicate-method
"""
        Method for running a semantic search on a query, after the corpus
        text embeddings and FAISS index have been created.
Parameters
--------------
query : str
Text query to use for searching the database.
top_n : int, optional
Variable corresponding to the 'Top N' values to return based on the
similarity score between the input query and the corpus. This
            variable is set to ``5`` by default.
Returns
---------
        match_results : dict
            Dictionary containing the metadata of the ``top_n`` articles
            that are most similar to the input ``query``.
"""
# --- Checking input parameters
# 'query' - Type
query_type_arr = (str,)
if not isinstance(query, query_type_arr):
msg = ">> 'query' ({}) is not a valid input type ({})".format(
type(query), query_type_arr
)
logger.error(msg)
raise TypeError(msg)
# 'top_n' - Type
top_n_type_arr = (int,)
if not isinstance(top_n, top_n_type_arr):
msg = ">> 'top_n' ({}) is not a valid input type ({})".format(
type(top_n), top_n_type_arr
)
logger.error(msg)
raise TypeError(msg)
# 'top_n' - Value
if top_n <= 0:
msg = f">> 'top_n' ({top_n}) must be larger than '0'!"
logger.error(msg)
raise ValueError(msg)
# --- Checking that the encoder has been indexed correctly
if self.corpus_dataset_with_faiss_index is None:
msg = ">>> The FAISS index was not properly set!"
logger.error(msg)
raise ValueError(msg)
# --- Encode the input query and extract the embedding
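        # NOTE: Encoding a single string yields a 1-D vector with the same
        # dimensionality as the indexed corpus embeddings.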
query_embedding = self.embedder.encode(query)
# --- Extracting the top-N results
(
scores,
results,
) = self.corpus_dataset_with_faiss_index.get_nearest_examples(
self.embeddings_colname,
query_embedding,
k=top_n,
)
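        # NOTE: 'scores' and 'results' are aligned: 'results' maps each
        # column name to a list of the 'top_n' matching values, and
        # 'scores' holds the corresponding FAISS scores.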
        # --- Sorting from highest to lowest
        # NOTE: The 'results' dictionary is converted into a DataFrame so
        # that the matches can be sorted by their relevance scores.
parsed_results = pd.DataFrame.from_dict(
data=results,
orient="columns",
)
parsed_results.loc[:, "relevance"] = scores
# Sorting in descending order
parsed_results = parsed_results.sort_values(
by=["relevance"],
ascending=False,
).reset_index(drop=True)
# Casting data type for the 'relevance'
parsed_results.loc[:, "relevance"] = parsed_results["relevance"].apply(
lambda x: str(np.round(x, 5))
)
# Only keeping certain columns
columns_to_keep = ["_id", "title", "relevance", "content"]
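        # NOTE: 'orient="index"' maps each row position to its metadata, e.g.
        # {0: {"_id": ..., "title": ..., "relevance": ..., "content": ...}}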
return parsed_results[columns_to_keep].to_dict(orient="index")
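

# ------------------------------ USAGE EXAMPLE --------------------------------
# Minimal, illustrative sketch of the end-to-end flow. The toy corpus below is
# made up; the only assumption is that the corpus carries the source column
# ('summary') plus the metadata columns that 'run_semantic_search' returns
# ('_id', 'title', 'content').
if __name__ == "__main__":
    engine = SemanticSearchEngine()
    # Toy corpus with the columns expected downstream
    toy_corpus = Dataset.from_dict(
        {
            "_id": ["a1", "b2"],
            "title": ["Transformers overview", "FAISS indexing guide"],
            "summary": [
                "An introduction to transformer-based language models.",
                "How to build and query FAISS similarity indices.",
            ],
            "content": ["...", "..."],
        }
    )
    # Embed the 'summary' column and attach the FAISS index
    engine.corpus_dataset_with_faiss_index = (
        engine.generate_corpus_index_and_embeddings(toy_corpus)
    )
    # Query the indexed corpus
    matches = engine.run_semantic_search("language models", top_n=2)
    logger.info(matches)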