# NOTE: removed non-source residue ("Spaces / Sleeping") left over from a
# HuggingFace Spaces web-page capture; it was not part of the program.
# MIT License
#
# Copyright (c) 2023 Victor Calderon
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import logging
from typing import Dict, Optional

from datasets import Dataset
from fastapi import Depends, FastAPI
from fastapi.responses import RedirectResponse
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

from src.classes import hugging_face_utils as hf
from src.classes import semantic_search_engine as ss
from src.utils import default_variables as dv
# Module-wide logging: timestamped INFO-and-above messages on the root
# handler, plus a named logger for this module.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s]: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# ------------------------------- VARIABLES -----------------------------------
APP_TITLE = "Cicero LLM Synthesizer"
# User-facing description shown on the auto-generated docs page.
# Fix: the original f-string read "...'is an app" (missing space after the
# closing quote around the title).
APP_DESCRIPTION = f"""
The '{APP_TITLE}' is an app that will identify the top-N articles from the
Cicero database that are most similar to the user's input query.
"""
APP_VERSION = "0.1"
# ----------------------------- APP-SPECIFIC ----------------------------------
# FastAPI application object; title/description/version feed the OpenAPI docs.
app = FastAPI(title=APP_TITLE, description=APP_DESCRIPTION, version=APP_VERSION)
# -------------------------------- CLASSES ------------------------------------
class QueryParams(BaseModel):
    """Request parameters accepted by the semantic-search endpoint."""

    # Free-text query to search the Cicero corpus with.
    input_query: str
    # How many of the most-similar articles to return (defaults to 5).
    number_articles: Optional[int] = 5
# ------------------------------- FUNCTIONS -----------------------------------
def download_dataset_and_faiss_index() -> Dataset:
    """
    Download the corpus dataset and its FAISS index from HuggingFace,
    then attach the index to the dataset.

    Returns
    -------------
    dataset_with_faiss_index : datasets.Dataset
        Dataset from HuggingFace with the FAISS index loaded.
    """
    # Helper wrapping the HuggingFace API (provides username and token).
    hf_helper = hf.HuggingFaceHelper()
    repository_id = f"{hf_helper.username}/{dv.hugging_face_repository_name}"
    # Fetch the serialized FAISS index file from the dataset repository.
    local_index_path = hf_hub_download(
        repo_id=repository_id,
        filename=f"{dv.faiss_index_name}.faiss",
        repo_type="dataset",
        token=hf_helper.api.token,
    )
    # Fetch the 'train' split of the embeddings dataset.
    corpus_dataset = hf_helper.get_dataset_from_hub(
        dataset_name=dv.dataset_faiss_embeddings_name,
        username=hf_helper.username,
        split="train",
    )
    # Attach the FAISS index to the embeddings column before returning.
    corpus_dataset.load_faiss_index(
        index_name=dv.embeddings_colname,
        file=local_index_path,
    )
    return corpus_dataset
def run_semantic_search_task(query: str, number_articles: int) -> Dict:
    """
    Run semantic search on an input query and return the 'Top-N'
    articles most similar to it.

    Parameters
    ------------
    query : str
        Input query to use when running the Semantic Search Engine.
    number_articles : int
        Number of articles to return from the Semantic Search.

    Returns
    ----------
    ranked_results : dict
        Dictionary containing the ranked results from the Semantic
        Search Engine.
    """
    # Dataset (with FAISS index attached) is re-downloaded per call.
    corpus = download_dataset_and_faiss_index()
    # Engine wraps the corpus and performs the nearest-neighbour lookup.
    engine = ss.SemanticSearchEngine(corpus_dataset_with_faiss_index=corpus)
    return engine.run_semantic_search(query=query, top_n=number_articles)
# -------------------------------- ROUTES -------------------------------------
async def docs_redirect():
    """Redirect the caller to the interactive API documentation page."""
    # NOTE(review): no route decorator (e.g. @app.get("/")) is visible on
    # this handler — it may have been lost when the file was extracted.
    # Confirm the route is actually registered with `app`.
    return RedirectResponse(url="/docs")
# ---- Semantic Search
async def run_semantic_search(query_params: QueryParams = Depends()):
    """
    Run semantic search on an input query.

    Parameters
    --------------
    query_params : QueryParams
        Request parameters: the input query and the number of articles
        to return from the Semantic Search.
    """
    # NOTE(review): no route decorator is visible on this handler — confirm
    # it is registered with `app` elsewhere.
    # Delegate to the synchronous task that performs the actual search.
    return run_semantic_search_task(
        query=query_params.input_query,
        number_articles=query_params.number_articles,
    )