File size: 3,646 Bytes
1c4216d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import logging
from typing import Any
import numpy as np
from rag_utils.rag_utils import OpenAI, do_1_embed
# Module-level logger named after this module; used by the retrieval helpers in this file.
logger = logging.getLogger(__name__)
def do_sort(
embed_q: np.ndarray, embed_talks: np.ndarray, list_talk_ids: list[str]
) -> list[dict[str, str | float]]:
"""
Sort documents based on their cosine similarity to the query embedding.
Args:
embed_dict (dict[str, np.ndarray]): Dictionary containing document embeddings.
arr_q (np.ndarray): Query embedding.
Returns:
pd.DataFrame: Sorted dataframe containing document IDs and similarity scores.
"""
# Calculate cosine similarities between query embedding and document embeddings
cos_sims = np.dot(embed_talks, embed_q)
# Get the indices of the best matching video IDs
best_match_video_ids = np.argsort(-cos_sims)
# Get the sorted video IDs based on the best match indices
sorted_vids = [
{"id0": list_talk_ids[i], "score": -cs}
for i, cs in zip(best_match_video_ids, np.sort(-cos_sims))
]
return sorted_vids
def limit_docs(
sorted_vids: list[dict[str, str | float]], talk_info: dict[str, str | int], n_results: int
) -> list[dict[str, Any]]:
"""
Limit the retrieved documents based on a score threshold and return the top documents.
Args:
df_sorted (pd.DataFrame): Sorted dataframe containing document IDs and similarity scores.
df_talks (pd.DataFrame): Dataframe containing talk information.
n_results (int): Number of top documents to retrieve.
transcript_dicts (dict[str, dict]): Dictionary containing transcript text for each document ID.
Returns:
dict[str, dict]: Dictionary containing the top documents with their IDs, scores, and text.
"""
# Get the top n_results documents
top_vids = sorted_vids[:n_results]
# Get the top score and calculate the score threshold
top_score = top_vids[0]["score"]
score_thresh = max(min(0.6, top_score - 0.05), 0.2)
# Filter the top documents based on the score threshold
keep_texts = []
for my_vid in top_vids:
if my_vid["score"] >= score_thresh:
vid_data = talk_info[my_vid["id0"]]
vid_data = {**vid_data, **my_vid}
keep_texts.append(vid_data)
logger.info(f"{len(keep_texts)} videos kept")
return keep_texts
def do_retrieval(
    query0: str,
    n_results: int,
    api_client: OpenAI,
    talk_ids: list[str],
    embeds: np.ndarray,
    talk_info: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
    """
    Retrieve the talks most relevant to the user's query.

    Embeds the query with ``do_1_embed``, ranks all talks with ``do_sort``, and
    filters the ranking with ``limit_docs``.

    Args:
        query0 (str): The user's query.
        n_results (int): Maximum number of ranked talks to consider.
        api_client (OpenAI): Client passed to ``do_1_embed`` to embed the query.
        talk_ids (list[str]): Talk IDs, passed to ``do_sort`` alongside ``embeds``.
        embeds (np.ndarray): Talk embeddings, passed to ``do_sort`` as ``embed_talks``.
        talk_info (dict[str, dict[str, Any]]): Mapping from talk ID to that talk's
            metadata mapping.

    Returns:
        list[dict[str, Any]]: The kept talks as returned by ``limit_docs``.

    Raises:
        Exception: Any error raised while embedding, sorting or filtering is
            logged with its traceback and re-raised unchanged.
    """
    logger.info("Starting document retrieval for query: %s", query0)
    try:
        arr_q = do_1_embed(query0, api_client)
        sorted_vids = do_sort(embed_q=arr_q, embed_talks=embeds, list_talk_ids=talk_ids)
        keep_texts = limit_docs(sorted_vids=sorted_vids, talk_info=talk_info, n_results=n_results)
    except Exception as e:
        # logger.exception records the traceback; bare `raise` re-raises the
        # original exception as-is.
        logger.exception("Error during document retrieval for query: %s, Error: %s", query0, e)
        raise
    logger.info("Retrieved %d documents for query: %s", len(keep_texts), query0)
    return keep_texts
|