File size: 3,646 Bytes
1c4216d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import logging
from typing import Any

import numpy as np

from rag_utils.rag_utils import OpenAI, do_1_embed

logger = logging.getLogger(__name__)


def do_sort(
    embed_q: np.ndarray, embed_talks: np.ndarray, list_talk_ids: list[str]
) -> list[dict[str, str | float]]:
    """
    Sort documents based on their cosine similarity to the query embedding.

    Args:
        embed_dict (dict[str, np.ndarray]): Dictionary containing document embeddings.
        arr_q (np.ndarray): Query embedding.

    Returns:
        pd.DataFrame: Sorted dataframe containing document IDs and similarity scores.
    """

    # Calculate cosine similarities between query embedding and document embeddings
    cos_sims = np.dot(embed_talks, embed_q)

    # Get the indices of the best matching video IDs
    best_match_video_ids = np.argsort(-cos_sims)

    # Get the sorted video IDs based on the best match indices
    sorted_vids = [
        {"id0": list_talk_ids[i], "score": -cs}
        for i, cs in zip(best_match_video_ids, np.sort(-cos_sims))
    ]

    return sorted_vids


def limit_docs(
    sorted_vids: list[dict[str, str | float]], talk_info: dict[str, str | int], n_results: int
) -> list[dict[str, Any]]:
    """
    Limit the retrieved documents based on a score threshold and return the top documents.

    Args:
        df_sorted (pd.DataFrame): Sorted dataframe containing document IDs and similarity scores.
        df_talks (pd.DataFrame): Dataframe containing talk information.
        n_results (int): Number of top documents to retrieve.
        transcript_dicts (dict[str, dict]): Dictionary containing transcript text for each document ID.

    Returns:
        dict[str, dict]: Dictionary containing the top documents with their IDs, scores, and text.
    """

    # Get the top n_results documents
    top_vids = sorted_vids[:n_results]

    # Get the top score and calculate the score threshold
    top_score = top_vids[0]["score"]
    score_thresh = max(min(0.6, top_score - 0.05), 0.2)

    # Filter the top documents based on the score threshold
    keep_texts = []
    for my_vid in top_vids:
        if my_vid["score"] >= score_thresh:
            vid_data = talk_info[my_vid["id0"]]
            vid_data = {**vid_data, **my_vid}
            keep_texts.append(vid_data)

    logger.info(f"{len(keep_texts)} videos kept")

    return keep_texts


def do_retrieval(
    query0: str,
    n_results: int,
    api_client: OpenAI,
    talk_ids: list[str],
    embeds: np.ndarray,
    talk_info: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
    """
    Retrieve relevant documents based on the user's query.

    The query is embedded with do_1_embed, the talks are ranked against it with
    do_sort, and the ranking is trimmed and enriched with metadata by limit_docs.

    Args:
        query0 (str): The user's query.
        n_results (int): The maximum number of documents to retrieve.
        api_client (OpenAI): The API client (OpenAI) for generating embeddings.
        talk_ids (list[str]): Talk IDs, passed to do_sort alongside embeds.
        embeds (np.ndarray): Talk embeddings, passed to do_sort alongside talk_ids.
        talk_info (dict[str, dict[str, Any]]): Metadata for each talk, keyed by talk ID.

    Returns:
        list[dict[str, Any]]: The retrieved documents as returned by limit_docs;
            may be empty.

    Raises:
        Exception: Any error raised while embedding, sorting or limiting is logged
            with its traceback and re-raised unchanged.
    """
    logger.info("Starting document retrieval for query: %s", query0)
    try:
        # Generate the embedding for the query
        arr_q = do_1_embed(query0, api_client)

        # Rank the talks by similarity to the query embedding
        sorted_vids = do_sort(embed_q=arr_q, embed_talks=embeds, list_talk_ids=talk_ids)

        # Keep only the top results that pass the score threshold
        keep_texts = limit_docs(sorted_vids=sorted_vids, talk_info=talk_info, n_results=n_results)
    except Exception as e:
        # logger.exception records the traceback; the bare raise re-raises the
        # original exception without altering it
        logger.exception("Error during document retrieval for query: %s, Error: %s", query0, e)
        raise

    logger.info("Retrieved %d documents for query: %s", len(keep_texts), query0)

    return keep_texts