File size: 7,512 Bytes
c89fdea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any, Type

import pandas as pd
import pymongo
from fastapi.encoders import jsonable_encoder
from pyparsing import Optional

from buster.completers import UserInputs
from buster.completers.base import Completion

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig runs at import time and configures the *root* logger (INFO level,
# default stderr handler) for the whole process as a side effect of importing this module.
# It is a no-op if the root logger already has handlers. Consider moving this call to the
# application entry point instead of a library module.
logging.basicConfig(level=logging.INFO)


@dataclass
class StandardForm:
    """Base dataclass for forms that can be attached to an ``Interaction``.

    Subclasses declare their fields as dataclass fields; this base provides
    JSON encoding and dict-based construction shared by all of them.
    """

    def to_json(self) -> Any:
        """Return the form encoded as JSON-compatible primitives via ``jsonable_encoder``."""
        encoded = jsonable_encoder(self)
        return encoded

    @classmethod
    def from_dict(cls, interaction_dict: dict) -> StandardForm:
        """Build an instance of ``cls`` whose fields are taken from the keys of ``interaction_dict``."""
        form = cls(**interaction_dict)
        return form


@dataclass
class FeedbackForm(StandardForm):
    """Form on the original Buster app.

    Built from a stored dict via ``StandardForm.from_dict``. Field order defines the
    positional ``__init__`` generated by ``@dataclass``, so do not reorder fields.
    """

    # Overall experience
    overall_experience: str

    # Answer Quality
    clear_answer: str
    accurate_answer: str

    # Source Relevance
    relevant_sources: str
    relevant_sources_order: str
    # NOTE(review): element type is not declared; presumably the sources the user selected — confirm.
    relevant_sources_selection: list

    # beginner, intermediate, expert at AI policy?
    expertise: list[str]

    # Additional Feedback
    extra_info: str


@dataclass
class ComparisonForm(StandardForm):
    """Easily readable comparison result on the battle arena.

    Holds the question that was asked, the two models compared (left/right), the
    user's vote and any free-text feedback. Field order defines the positional
    ``__init__`` generated by ``@dataclass``.
    """

    question: str
    model_left: str
    model_right: str
    # NOTE(review): allowed values are not defined here; presumably identifies the preferred side — confirm.
    vote: str
    extra_info: str


@dataclass
class Interaction:
    """A logged user session: its completions plus optional metadata and form.

    Serialise with :meth:`to_json`, store with :meth:`send`, rebuild with
    :meth:`from_dict`, and flatten into a single-level dict with :meth:`flatten`.
    """

    user_completions: list[Completion]
    time: str
    session_id: str  # A unique identifier for each gradio session, e.g. UUID
    username: str | None = None
    instance_type: str | None = None  # Dev or prod
    instance_name: str | None = None  # Heroku, hf-space, etc.
    data_version: str | None = None  # Which version of the data collection was used
    form: StandardForm | None = None

    def send(self, mongo_db: pymongo.database.Database, collection: str):
        """Insert this interaction's JSON encoding into ``mongo_db[collection]``.

        Any exception raised by the insert is logged and re-raised unchanged.
        """
        feedback_json = self.to_json()
        logger.info(feedback_json)

        try:
            mongo_db[collection].insert_one(feedback_json)
        except Exception:
            logger.exception(f"Something went wrong logging to mongodb {collection=}")
            # Bare ``raise`` re-raises the active exception with its original traceback.
            raise
        logger.info(f"response logged to mongodb {collection=}")

    def flatten(self) -> dict:
        """Flatten the interaction into a single-level dict for easier reading.

        Only the most recent completion is kept. Its fields become ``completion_<key>``
        entries, form fields become ``form_<key>`` entries, and row ``i`` of its
        matched documents becomes ``matched_documents_<i>`` holding that row's values.
        Interactions without completions are supported, and the completion's
        ``matched_documents`` DataFrame is left unmodified.
        """
        interaction_dict = self.to_json()

        # Flatten user completions, only keep the most recent one.
        completions = interaction_dict.pop("user_completions")
        if completions:
            for key, value in completions[-1].items():
                interaction_dict[f"completion_{key}"] = value

        if self.form is not None:
            # to_json() only emits a "form" key when a form is attached.
            for key, value in interaction_dict.pop("form").items():
                interaction_dict[f"form_{key}"] = value

        # Flatten matched documents of the most recent completion, if there is one.
        if self.user_completions:
            # reset_index(drop=True) returns a renumbered copy; the previous inplace calls
            # mutated the completion's own DataFrame and could drop a real "index" column.
            matched_documents = self.user_completions[-1].matched_documents.reset_index(drop=True)
            transposed = matched_documents.T
            # len(transposed) is the number of original columns; nothing to emit without columns.
            if len(transposed) > 0:
                for row_idx in transposed.keys():
                    interaction_dict[f"matched_documents_{row_idx}"] = transposed[row_idx].values

        return interaction_dict

    def to_json(self) -> Any:
        """Encode the interaction into JSON-compatible primitives.

        Completions are encoded with ``Completion.to_json`` ignoring the ``embedding``
        and ``_id`` columns; the ``form`` key is only included when a form is set.
        """
        custom_encoder = {
            # Converts the matched_documents in the user_completions to json
            Completion: lambda completion: completion.to_json(columns_to_ignore=["embedding", "_id"]),
        }

        to_encode = {
            "username": self.username,
            "session_id": self.session_id,
            "user_completions": self.user_completions,
            "time": self.time,
            "instance_type": self.instance_type,
            "instance_name": self.instance_name,
            "data_version": self.data_version,
        }

        if self.form is not None:
            to_encode["form"] = self.form.to_json()

        return jsonable_encoder(to_encode, custom_encoder=custom_encoder)

    @classmethod
    def from_dict(cls, interaction_dict: dict, feedback_cls: Type[StandardForm] | None = None) -> Interaction:
        """Rebuild an ``Interaction`` from a dict such as a stored MongoDB document.

        A MongoDB ``_id`` key is ignored. Completions are parsed with
        ``Completion.from_dict`` and a ``form`` entry with ``feedback_cls.from_dict``.
        The input dict is not modified.

        Raises:
            AssertionError: if the dict contains a ``form`` but ``feedback_cls`` is None.
        """
        # Shallow copy: only top-level keys are replaced below, so the caller's dict stays intact.
        interaction_dict = dict(interaction_dict)

        # remove the _id from mongodb
        interaction_dict.pop("_id", None)

        interaction_dict["user_completions"] = [Completion.from_dict(r) for r in interaction_dict["user_completions"]]

        if "form" in interaction_dict:
            # The interaction contained a type of form, e.g. feedback form, parse it accordingly
            assert feedback_cls is not None, "You must specify which type of feedback it is"
            interaction_dict["form"] = feedback_cls.from_dict(interaction_dict["form"])

        return cls(**interaction_dict)


def read_collection(
    mongo_db: pymongo.database.Database,
    collection: str,
    feedback_cls: Type[StandardForm] | None = None,
    filters: dict | None = None,
) -> pd.DataFrame:
    """
    Retrieve data from a MongoDB collection and return it as a pandas DataFrame.

    Parameters:
    - mongo_db (pymongo.database.Database): The MongoDB database instance.
    - collection (str): The name of the MongoDB collection to read from.
    - feedback_cls (Optional[Type[StandardForm]]): The StandardForm subclass used to parse each
      document's "form" field. Not needed when documents carry no form; when a form is present
      and this is None, that document is skipped.
    - filters (Optional[dict]): A dictionary of filters to apply to the mongodb query. If not provided,
      all items in the collection are returned. E.g., to get interactions from a specific user,
      use `filters={"username": <username>}`.

    Returns:
    - pd.DataFrame: A DataFrame containing the retrieved data, one row per interaction, flattened
      with `Interaction.flatten`.

    Notes:
    - Legacy completions that store a single "user_input" key are converted to "user_inputs".
    - Interactions that cannot be processed are skipped (the reason is logged at DEBUG level), and
      a log message is generated with the count of retrieved and skipped entries.
    """
    flattened_interactions = []
    skipped_interactions = []
    for interaction in mongo_db[collection].find(filters):
        try:
            # We used to only have a single key for user input.
            # This changed when we introduced question reformulation.
            # Only useful to maintain backwards compatibility with data collected previously.
            for completion in interaction["user_completions"]:
                if user_input := completion.get("user_input"):
                    completion["user_inputs"] = UserInputs(user_input)
                    del completion["user_input"]

            flattened_interaction = Interaction.from_dict(interaction, feedback_cls=feedback_cls).flatten()
            flattened_interactions.append(flattened_interaction)
        except Exception:
            # Best-effort read: a malformed entry must not abort the whole query, but keep
            # the reason available for debugging instead of discarding it silently.
            logger.debug("Skipping interaction that could not be processed", exc_info=True)
            skipped_interactions.append(interaction)

    logger.info(f"Retrieved {len(flattened_interactions)} entries. Skipped {len(skipped_interactions)} entries")

    return pd.DataFrame(flattened_interactions)