update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
- src/analysis/combine_job_returns.py +207 -21
- src/analysis/common.py +27 -6
- src/analysis/compare_job_returns.py +1 -1
- src/analysis/format_metric_results.py +269 -0
- src/analysis/get_json_field_as_string.py +55 -0
- src/analysis/show_inference_params_on_quality_and_throughput.py +485 -0
- src/datamodules/__init__.py +1 -1
- src/datamodules/datamodule_with_sampler.py +59 -0
- src/dataset/processing.py +88 -3
- src/demo/annotation_utils.py +6 -56
- src/demo/backend_utils.py +50 -12
- src/demo/retrieve_and_dump_all_relevant.py +82 -38
- src/document/processing.py +300 -1
- src/evaluate.py +3 -3
- src/evaluate_documents.py +1 -1
- src/hydra_callbacks/save_job_return_value.py +67 -4
- src/langchain_modules/basic_pie_document_store.py +3 -1
- src/langchain_modules/datasets_pie_document_store.py +1 -1
- src/metrics/__init__.py +7 -1
- src/metrics/connected_component_sizes.py +43 -0
- src/metrics/coref.py +223 -0
- src/metrics/coref_sklearn.py +158 -43
- src/metrics/f1_with_bootstrapping.py +103 -0
- src/metrics/f1_with_threshold.py +33 -0
- src/metrics/ranking_sklearn.py +193 -0
- src/metrics/score_distribution.py +13 -4
- src/metrics/semantically_same_ranking.py +448 -0
- src/metrics/tpfpfn.py +193 -0
- src/models/__init__.py +2 -1
- src/models/sequence_classification_with_pooler.py +65 -30
- src/predict.py +2 -2
- src/serializer/__init__.py +4 -1
- src/serializer/interface.py +1 -2
- src/serializer/json.py +7 -121
- src/start_demo.py +3 -2
- src/train.py +2 -3
- src/utils/graph_utils.py +47 -0
- src/utils/inference_utils.py +4 -1
- src/utils/pdf_utils/process_pdf.py +1 -1
src/analysis/combine_job_returns.py
CHANGED
  import pyrootutils
+ from pandas import MultiIndex

  root = pyrootutils.setup_root(
      search_from=__file__,
      pythonpath=True,
      dotenv=False,
  )
  import argparse
+ import logging
  import os
+ from typing import Iterable, List, Optional, Tuple

+ import numpy as np
  import pandas as pd

  from src.analysis.common import read_nested_jsons

+ logger = logging.getLogger(__name__)


  def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]:
      parts = path_and_maybe_id.split(separator, 1)
      ...
      return parts[0], parts[1]


+ def get_file_paths(paths_file: str, file_name: str, use_aggregated: bool) -> List[Tuple[str, str]]:
      with open(paths_file, "r") as f:
          paths_maybe_with_ids = f.readlines()
      ids, paths = zip(*[separate_path_and_id(path.strip()) for path in paths_maybe_with_ids])
      ...
      file_base_name, ext = os.path.splitext(file_name)
      file_name = f"{file_base_name}.aggregated{ext}"
      file_paths = [os.path.join(path, file_name) for path in paths]
+     return [
+         (id if id is not None else f"idx={idx}", path)
+         for idx, (id, path) in enumerate(zip(ids, file_paths))
+     ]


+ def get_job_id_col(index: pd.MultiIndex) -> Optional[Tuple]:
+     for idx in index:
+         if "job_id" in idx:
+             return idx
+     return None


+ def stringify(value: str | int | float | None | tuple | list) -> str:
+     if isinstance(value, str):
+         return value
+     if value is None:
+         return ""
+     if isinstance(value, float) and np.isnan(value):
+         return ""
+     if isinstance(value, (int, float)):
+         return str(value)
+     if isinstance(value, Iterable):
+         entries = [stringify(v) for v in value]
+         return "/".join(v for v in entries if v)
+     return value


+ def remove_part_from_multi_index(index: pd.MultiIndex, part: str) -> pd.MultiIndex:
+     new_index = []
+     for idx in index:
+         new_idx = tuple([i for i in idx if i != part])
+         new_index.append(new_idx)
+     return MultiIndex.from_tuples(new_index)


  def main(
      ...
      format: str,
      transpose: bool = False,
      unpack_multirun_results: bool = False,
+     unpack_multirun_results_with_job_id: bool = False,
      in_percent: bool = False,
      reset_index: bool = False,
+     sort_columns: bool = False,
+     stringify_column_names: bool = False,
+     column_regex_blacklist: Optional[List[str]] = None,
+     column_regex_whitelist: Optional[List[str]] = None,
+     replace_in_col_names: Optional[List[Tuple[str, str]]] = None,
  ):
      file_paths = get_file_paths(
          paths_file=paths_file, file_name=file_name, use_aggregated=use_aggregated
      )
      data = read_nested_jsons(json_paths=file_paths)

+     job_id_col = get_job_id_col(data.columns)

      if columns is not None:
+         columns_multi_index = [
+             tuple([part or np.nan for part in col.split("/")]) for col in columns
+         ]
+         if unpack_multirun_results_with_job_id:
+             if job_id_col is None:
+                 raise ValueError("Job ID column not found in the data.")
+             if job_id_col not in columns_multi_index:
+                 columns_multi_index.append(job_id_col)
          try:
+             available_cols = data.columns.tolist()
+             for col in columns_multi_index:
+                 if col not in available_cols:
+                     raise KeyError(f"Column {col} not found in the data.")
              data_series = [data[col] for col in columns_multi_index]
          except KeyError as e:
              print(
      ...
      for level in sorted(unique_column_levels, reverse=True):
          data.columns = data.columns.droplevel(level)

+     if unpack_multirun_results or unpack_multirun_results_with_job_id:
          index_names = list(data.index.names)
+         data_series_lists = data.copy()
+         job_ids = None
+         if job_id_col in data_series_lists.columns:
+             job_ids_series = data_series_lists.pop(job_id_col)
+             job_ids_frame = pd.DataFrame(pd.DataFrame.from_records(job_ids_series.values))
+             job_ids_frame.index = job_ids_series.index
+             # check that all rows are identical
+             if job_ids_frame.nunique().max():
+                 job_ids = job_ids_frame.iloc[0]
+             else:
+                 logger.warning(
+                     "Job IDs are not identical across all rows. Cannot unpack "
+                     "multirun results with job ids as columns."
+                 )

+         while not isinstance(data_series_lists, pd.Series):
+             data_series_lists = data_series_lists.stack(future_stack=True)
+         data_series_lists = data_series_lists.dropna()
+         data = pd.DataFrame.from_records(data_series_lists.values, index=data_series_lists.index)
+         if job_ids is not None:
+             data.columns = job_ids
+         num_col_levels = data.index.nlevels - len(index_names)
+         for _ in range(num_col_levels):
+             data = data.unstack()
+             data.columns = data.columns.swaplevel(0, -1)
+         data = data.dropna(how="all", axis="columns")

      # needs to happen before rounding, otherwise the rounding will be off
      if in_percent:
+         float_columns = data.select_dtypes(include=["float64", "float32"]).columns
+         data[float_columns] = data[float_columns] * 100

      if round_precision is not None:
          data = data.round(round_precision)

      # needs to happen before transposing
      if format == "markdown_mean_and_std":
+         if data.columns.nlevels == 1:
+             data.columns = pd.MultiIndex.from_tuples([(col,) for col in data.columns.tolist()])

+         # get mean columns
+         mean_col_names = [col for col in data.columns if "mean" in col]
+         mean_columns = data[mean_col_names].copy()
+         # remove all "mean" from col names
+         mean_columns.columns = remove_part_from_multi_index(mean_columns.columns, "mean")
+         # get std columns
+         std_col_names = [col for col in data.columns if "std" in col]
+         std_columns = data[std_col_names].copy()
+         # remove all "std" from col names
+         std_columns.columns = remove_part_from_multi_index(std_columns.columns, "std")
+         # sanity check
+         if not mean_columns.columns.equals(std_columns.columns):
+             raise ValueError("Mean and std columns do not match.")
+         mean_and_std = mean_columns.astype(str) + " ± " + std_columns.astype(str)
+         mean_and_std.columns = [
+             ("mean ± std",) + (tuple(col) if col != ((),) else ()) for col in mean_columns.columns
+         ]
+         # remove mean and std columns from data
+         # we can not use drop because the columns is a multiindex that may contain NaNs
+         other_cols = [
+             col for col in data.columns if col not in set(mean_col_names + std_col_names)
+         ]
+         data = data[other_cols]
+         # add mean and std columns to data
+         data = pd.concat([data, mean_and_std], axis=1)
+         if data.columns.nlevels == 1:
+             data.columns = data.columns.to_flat_index()
+             data.columns = [
+                 "/".join(col) if isinstance(col, tuple) else col for col in data.columns
+             ]

      if transpose:
          data = data.T

+     if sort_columns:
+         # sort columns to get a deterministic order
+         data = data.sort_index(axis=1)

+     if stringify_column_names:
+         # Convert MultiIndex columns to string representation
+         data.columns = data.columns.map(stringify)

+         if column_regex_blacklist is not None:
+             # Remove columns that match any of the regex patterns in the blacklist
+             for pattern in column_regex_blacklist:
+                 data = data.loc[:, ~data.columns.str.contains(pattern, regex=True)]

+         if column_regex_whitelist is not None:
+             # keep only columns that match any of the regex patterns in the whitelist
+             data = data.loc[
+                 :, data.columns.str.contains("|".join(column_regex_whitelist), regex=True)
+             ]

+         if replace_in_col_names is not None:
+             for old_value, new_value in replace_in_col_names:
+                 data.columns = data.columns.str.replace(old_value, new_value, regex=False)

+     else:
+         if column_regex_blacklist is not None:
+             logger.warning(
+                 "Column regex blacklist is ignored when stringify_column_names is False."
+             )
+         if column_regex_whitelist is not None:
+             logger.warning(
+                 "Column regex whitelist is ignored when stringify_column_names is False."
+             )
+         if replace_in_col_names is not None:
+             logger.warning(
+                 "Replace in column names is ignored when stringify_column_names is False."
+             )

+     # remove empty rows
+     # get rows that contain only nan or "nan ± nan"
+     empty_rows = data.apply(
+         lambda row: all(
+             pd.isna(value) or (isinstance(value, str) and value == "nan ± nan") for value in row
+         ),
+         axis=1,
+     )
+     data = data[~empty_rows]

      if reset_index:
          data = data.reset_index()
      ...
      parser.add_argument(
          "--unpack-multirun-results", action="store_true", help="Unpack multirun results"
      )
+     parser.add_argument("--unpack-multirun-results-with-job-id", action="store_true", help="Unpack multirun results with job ID")
      parser.add_argument("--transpose", action="store_true", help="Transpose the table")
+     parser.add_argument("--sort-columns", action="store_true", help="Sort the columns of the combined job returns")
      parser.add_argument(
          "--round-precision",
          type=int,
      ...
      parser.add_argument(
          "--reset-index", action="store_true", help="Reset the index of the combined job returns"
      )
+     parser.add_argument("--stringify-column-names", action="store_true", help="Stringify the column names of the combined job returns (useful for multi-index columns)")
+     parser.add_argument("--column-regex-blacklist", type=str, nargs="+", default=None, help="List of regex patterns to match column names. Columns that match any of the patterns will be removed.")
+     parser.add_argument("--column-regex-whitelist", type=str, nargs="+", default=None, help="List of regex patterns to match column names. Only columns that match any of the patterns will be kept.")
+     parser.add_argument("--replace-in-col-names", type=lambda s: s.split(":", 1), nargs="+", default=None, help='List of strings in the format "<old_value>:<new_value>" to replace substrings in column names.')
      parser.add_argument(
          "--format",
          type=str,
      ...
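The new --stringify-column-names flag is what makes the regex blacklist/whitelist and name replacements applicable, since pandas string filters need flat column names. A minimal standalone sketch (made-up values, not the script itself) of how a multi-index column collapses to the "/"-joined form that stringify() produces:

import pandas as pd

# toy frame with two-level columns, as produced by reading nested job-return JSONs
df = pd.DataFrame(
    [[0.81, 0.03], [0.79, 0.04]],
    columns=pd.MultiIndex.from_tuples([("f1", "mean"), ("f1", "std")]),
)
# join the non-empty parts of each column tuple with "/", mirroring stringify()
df.columns = ["/".join(str(part) for part in col if str(part) != "") for col in df.columns]
print(df.columns.tolist())  # ['f1/mean', 'f1/std'] -> now str.contains(...) regex filters apply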
src/analysis/common.py
CHANGED
  import json
- from typing import Dict, List, Optional
+ from typing import Dict, List, Optional, Tuple

  import pandas as pd
  ...
  def read_nested_jsons(
+     json_paths: List[Tuple[str, str]],
      default_key_values: Optional[Dict[str, str]] = None,
      column_level_names: Optional[List[str]] = None,
  ) -> pd.DataFrame:
+     dfs = [read_nested_json(json_path) for identifier_str, json_path in json_paths]
      new_index_levels = pd.MultiIndex.from_frame(
          pd.DataFrame(
              [
                  parse_identifier(identifier_str, default_key_values or {})
+                 for identifier_str, _ in json_paths
              ]
          )
      )
+     if len(set(list(new_index_levels))) == len(list(new_index_levels)):
+         dfs_concat = pd.concat(
+             dfs, keys=list(new_index_levels), names=new_index_levels.names, axis=0
+         )
+     else:
+         dfs_new = []
+         ids_unique = []
+         for identifier_str in new_index_levels:
+             if identifier_str not in ids_unique:
+                 ids_unique.append(identifier_str)
+         # first combine the dataframes with same ids along the columns
+         for identifier_str in ids_unique:
+             dfs_with_id = [df for df, idx in zip(dfs, new_index_levels) if idx == identifier_str]
+             # assert that all columns are distinct
+             if len(set([tuple(col) for df in dfs_with_id for col in df.columns])) != sum(
+                 [len(df.columns) for df in dfs_with_id]
+             ):
+                 raise ValueError(
+                     "There are duplicate columns across the dataframes with the same identifier."
+                 )
+             dfs_id_concat = pd.concat(dfs_with_id, axis=1)
+             dfs_new.append(dfs_id_concat)
+         dfs_concat = pd.concat(dfs_new, keys=ids_unique, names=new_index_levels.names, axis=0)
      dfs_concat.columns = pd.MultiIndex.from_tuples(
          [col.split("/") for col in dfs_concat.columns], names=column_level_names
      )
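read_nested_jsons() now accepts a list of (identifier, path) tuples, so the same identifier may appear more than once; frames that share an identifier are first concatenated along the columns. A toy illustration with made-up frames (not the repo's real data):

import pandas as pd

# two job-return frames that belong to the same run identifier
dfs = [pd.DataFrame({"coref/f1/mean": [0.8]}), pd.DataFrame({"coref/f1/std": [0.1]})]
identifier = "model=bert"

# duplicate identifiers: combine along the columns first, then stack along the rows
combined = pd.concat(dfs, axis=1)
result = pd.concat([combined], keys=[identifier], axis=0)
print(result.columns.tolist())  # ['coref/f1/mean', 'coref/f1/std']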
src/analysis/compare_job_returns.py
CHANGED
  def combine_job_returns_and_plot(
      ...
      if job_return_paths is not None:
          df_all = read_nested_jsons(
-             json_paths=job_return_paths,
+             json_paths=list(job_return_paths.items()),
              default_key_values=default_key_values,
              column_level_names=column_level_names,
          )
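The only change on the caller side is the input shape: the {identifier: path} mapping becomes the (identifier, path) tuples that read_nested_jsons() now expects. A one-line illustration with made-up paths:

job_return_paths = {"model=bert": "logs/run_a", "model=roberta": "logs/run_b"}
print(list(job_return_paths.items()))
# [('model=bert', 'logs/run_a'), ('model=roberta', 'logs/run_b')]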
src/analysis/format_metric_results.py
ADDED
#!/usr/bin/env python
import argparse
import json
import os
from pathlib import Path

import pandas as pd
from pie_modules.utils import flatten_dict


def str2record(s: str | None, sep_parts: str = "-", sep_k_v: str = "=") -> pd.Series:
    if s is None or s.strip() == "" or s == "None":
        return pd.Series()
    return pd.Series(dict(k_v.split(sep_k_v, 1) for k_v in s.split(sep_parts)))


def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]:
    parts = path_and_maybe_id.split(separator, 1)
    if len(parts) == 1:
        return None, parts[0]
    return parts[0], parts[1]


def load_data_from_json(path: str | Path) -> pd.DataFrame:
    with open(path, "r") as f:
        data_json = json.load(f)
    data_flat = flatten_dict(data_json)
    return pd.DataFrame(data_flat)


def main(
    path: str | Path,
    remove_col_prefix: str | None = None,
    sparse_col_prefix: str | None = None,
    tail_cols: list[str] | None = None,
    sort_cols: list[str] | None = None,
    split_col: str | None = None,
    replace_in_col_names: list[tuple[str, str]] | None = None,
    round_precision: int | None = None,
    in_percent: bool = False,
    common_prefix_separator: str | None = None,
    column_regex_blacklist: list[str] | None = None,
    column_regex_whitelist: list[str] | None = None,
    format: str = "markdown",
) -> None:

    if str(path).lower().endswith(".json"):
        result = load_data_from_json(path)
    elif str(path).lower().endswith(".txt"):
        with open(path, "r") as f:
            index_data = [separate_path_and_id(line.strip()) for line in f.readlines()]
        data_list = []
        for meta_id, meta_path in index_data:
            data = load_data_from_json(os.path.join(meta_path, "job_return_value.json"))
            if meta_id is not None:
                job_id_prefix = meta_id.replace(",", "-")
                data["job_id"] = job_id_prefix + "-" + data["job_id"].astype(str)
            data = data.set_index("job_id")
            data_list.append(data)
        result = pd.concat(data_list, axis=1).reset_index()
    else:
        raise ValueError("Unsupported file format. Please provide a .json or .txt file.")

    if remove_col_prefix is not None:
        result.columns = result.columns.str.replace(r"^" + remove_col_prefix, "", regex=True)

    if sparse_col_prefix is not None:
        # melt all columns that share the sparse prefix into <prefix>name / <prefix>value
        sparse_cols = [col for col in result.columns if col.startswith(sparse_col_prefix)]
        other_cols = [col for col in result.columns if col not in sparse_cols]

        value_col = f"{sparse_col_prefix}value"
        name_col = f"{sparse_col_prefix}name"
        result = result.melt(
            id_vars=other_cols, value_vars=sparse_cols, var_name=name_col, value_name=value_col
        ).dropna(
            subset=[value_col]
        )  # keep rows with a value
        # strip the prefix, leaving just the numeric threshold
        result[name_col] = result[name_col].str.replace(r"^" + sparse_col_prefix, "", regex=True)
        # convert the column to numeric (if possible)
        try:
            result[name_col] = pd.to_numeric(result[name_col])
        except ValueError:
            # if it fails, just keep it as a string
            pass

    if split_col is not None:
        new_frame = result[split_col].apply(str2record)
        result = pd.concat([result.drop(columns=[split_col]), new_frame], axis=1)

    if in_percent:
        float_columns = result.select_dtypes(include=["float64", "float32"]).columns
        result[float_columns] = result[float_columns] * 100

    if round_precision is not None:
        # round all columns to the given precision
        result = result.round(round_precision)

    if common_prefix_separator is not None:
        # remove common prefix from values in all string columns
        obj_columns = result.select_dtypes(include=["object"]).columns
        for obj_col in obj_columns:
            # get the common prefix
            common_prefix = os.path.commonprefix(result[obj_col].dropna().astype(str).tolist())
            # find last occurrence of the common_prefix_separator
            last_occurrence = common_prefix.rfind(common_prefix_separator)
            if last_occurrence != -1:
                # truncate the common prefix after the last occurrence of the separator
                common_prefix = common_prefix[: last_occurrence + len(common_prefix_separator)]
                # remove the common prefix (including the separator) from the column
                result[obj_col] = result[obj_col].str.replace(r"^" + common_prefix, "", regex=True)

    # sort columns to get a deterministic order
    result = result.sort_index(axis=1)

    if tail_cols is not None:
        front_cols = [c for c in result.columns if c not in tail_cols]
        result = result[front_cols + tail_cols]

    if sort_cols is not None:
        result = result.sort_values(sort_cols)
        # also move the sort columns to the front
        result = result[sort_cols + [c for c in result.columns if c not in sort_cols]]

    if column_regex_blacklist is not None:
        # remove columns that match any of the regex patterns in the blacklist
        for pattern in column_regex_blacklist:
            result = result.loc[:, ~result.columns.str.contains(pattern, regex=True)]

    if column_regex_whitelist is not None:
        # keep only columns that match any of the regex patterns in the whitelist
        result = result.loc[
            :, result.columns.str.contains("|".join(column_regex_whitelist), regex=True)
        ]

    if replace_in_col_names is not None:
        for old_value, new_value in replace_in_col_names:
            result.columns = result.columns.str.replace(old_value, new_value, regex=False)

    if format == "markdown":
        result_str = result.to_markdown(index=False)
    elif format == "csv":
        result_str = result.to_csv(index=False)
    elif format == "tsv":
        result_str = result.to_csv(index=False, sep="\t")
    elif format == "json":
        result_str = result.to_json(orient="records", lines=True)
    else:
        raise ValueError(
            f"Unsupported format: {format}. Supported formats are: markdown, csv, json."
        )

    print(result_str)


if __name__ == "__main__":
    """
    Example usage:

    python src/analysis/format_metric_results.py \
        logs/document_evaluation/multiruns/default/2025-05-21_11-59-19/job_return_value.json \
        --remove-col-prefix train/ \
        --sparse-col-prefix f1- \
        --split-col job_id \
        --tail-cols num_positives num_total \
        --sort-cols experiment model \
        --round-precision 4
    """

    parser = argparse.ArgumentParser(
        description="Process a JSON file containing metric results (from multirun) and print as Markdown table."
    )
    parser.add_argument("path", type=str, help="Path to the JSON file to process. The JSON file is expected to contain a (maybe nested) dictionary where each leaf entry is a list of values with the same length.")
    parser.add_argument("--remove-col-prefix", type=str, default=None, help="Prefix to remove from column names.")
    parser.add_argument("--sparse-col-prefix", type=str, default=None, help="Prefix of sparse columns. All sparse columns will be melted into two columns: <prefix>name and <prefix>value. The name column will be converted to numeric if possible.")
    parser.add_argument("--split-col", type=str, default=None, help="Column to split into multiple columns. The format of the column entries is expected to be: <key_1>=<value_a>-<key_2>=<value_b>-...")
    parser.add_argument("--tail-cols", type=str, nargs="+", default=None, help="Columns to move to the end.")
    parser.add_argument("--sort-cols", type=str, nargs="+", default=None, help="Columns to sort by (they will be moved to the front).")
    parser.add_argument("--replace-in-col-names", type=lambda s: s.split(":", 1), nargs="+", default=None, help='List of strings in the format "<old_value>:<new_value>" to replace substrings in column names.')
    parser.add_argument("--round-precision", type=int, default=None, help="Number of decimal places to round to.")
    parser.add_argument("--in-percent", action="store_true", default=False, help="If set, all float columns will be multiplied by 100 to convert them to percentages.")
    parser.add_argument("--common-prefix-separator", type=str, default=None, help="For all string columns, remove the common prefix up to the last occurrence of this separator.")
    parser.add_argument("--column-regex-blacklist", type=str, nargs="+", default=None, help="List of regex patterns to match column names. Columns that match any of the patterns will be removed.")
    parser.add_argument("--column-regex-whitelist", type=str, nargs="+", default=None, help="List of regex patterns to match column names. Only columns that match any of the patterns will be kept.")
    parser.add_argument("--format", type=str, default="markdown", choices=["markdown", "csv", "tsv", "json"], help="Format to print the result in. Supported formats are: markdown, csv, json.")

    kwargs = vars(parser.parse_args())
    main(**kwargs)
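What --split-col does to a packed job_id column, sketched by calling str2record() on a made-up job id (the import assumes the repository root is on PYTHONPATH so that src.analysis is importable):

from src.analysis.format_metric_results import str2record

record = str2record("experiment=sciarg-model=roberta")
print(record.to_dict())  # {'experiment': 'sciarg', 'model': 'roberta'}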
src/analysis/get_json_field_as_string.py
ADDED
import json


def main(
    paths: list[str],
    field: list[str],
    format: str = "plain",
) -> None:
    result = []
    for path in paths:
        with open(path, "r") as f:
            data = json.load(f)
        value = data
        for key in field:
            value = value.get(key)
        if not isinstance(value, list):
            value = [value]
        result.extend(value)
    if format == "plain":
        print(",".join(map(str, result)))
    elif format == "python":
        result_str = str(result)
        print(result_str.replace(" ", ""))
    else:
        raise ValueError(f"Unknown format: {format}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Get a field from one or more JSON files and print to stdout."
    )
    parser.add_argument(
        "paths",
        type=lambda x: x.split(","),
        help="Comma-separated list of paths to the JSON files to process.",
    )
    parser.add_argument(
        "--field",
        type=str,
        required=True,
        nargs="+",
        help="Field to extract from the JSON files. Can be a nested field by providing multiple entries.",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="plain",
        choices=["plain", "python"],
    )

    args = parser.parse_args()
    kwargs = vars(args)
    main(**kwargs)
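A small usage sketch of the new helper, with a made-up results.json containing {"job_id": ["run-0", "run-1"]} (the import again assumes the repository root is on PYTHONPATH):

from src.analysis.get_json_field_as_string import main

# with results.json = {"job_id": ["run-0", "run-1"]} this prints: run-0,run-1
main(paths=["results.json"], field=["job_id"], format="plain")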
src/analysis/show_inference_params_on_quality_and_throughput.py
ADDED
import argparse
import json
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
import plotly.express as px


def get_col_name(col: str) -> str:
    parts = [part[1:-1] for part in col[1:-1].split(", ") if part[1:-1] != ""]
    return parts[-1]


def get_idx_entry(s: str, keep_only_last_part: bool = False) -> Tuple[str, str]:
    k, v = s.split("=", 1)
    if keep_only_last_part:
        k = k.split(".")[-1]
    return k, v


def get_idx_dict(job_id: str, keep_only_last_part: bool = False) -> Dict[str, str]:
    return dict(
        get_idx_entry(part, keep_only_last_part=keep_only_last_part) for part in job_id.split("-")
    )


def unflatten_index(
    index: Iterable[str],
    keep_only_last_part: bool = False,
    dtypes: Optional[Dict[str, Any]] = None,
) -> pd.MultiIndex:
    as_df = pd.DataFrame.from_records(
        [get_idx_dict(idx, keep_only_last_part=keep_only_last_part) for idx in index]
    )
    if dtypes is not None:
        dtypes_valid = {col: dtype for col, dtype in dtypes.items() if col in as_df.columns}
        as_df = as_df.astype(dtypes_valid)
    return pd.MultiIndex.from_frame(as_df.convert_dtypes())


def col_to_str(col_entries: Iterable[str], names: Iterable[Optional[str]], sep: str) -> str:
    return sep.join(
        [
            f"{name}={col_entry}" if name is not None else col_entry
            for col_entry, name in zip(col_entries, names)
        ]
    )


def flatten_index(index: pd.MultiIndex, names: Optional[List[Optional[str]]] = None) -> pd.Index:
    names = names or index.names
    if names is None:
        raise ValueError("names must be provided if index has no names")
    return pd.Index([col_to_str(col, names=names, sep=",") for col in index])


def prepare_quality_and_throughput_dfs(
    metric_data_path: str,
    job_return_value_path: str,
    char_total: int,
    index_dtypes: Optional[Dict[str, Any]] = None,
    job_id_prefix: Optional[str] = None,
) -> Tuple[pd.DataFrame, pd.Series]:

    with open(metric_data_path) as f:
        data = json.load(f)

    df = pd.DataFrame.from_dict(data)
    df.columns = [get_col_name(col) for col in df.columns]
    f1_series = df.set_index([col for col in df.columns if col != "f1"])["f1"]
    f1_df = f1_series.apply(lambda x: pd.Series(x)).T

    with open(job_return_value_path) as f:
        job_return_value = json.load(f)

    job_ids = job_return_value["job_id"]
    if job_id_prefix is not None:
        job_ids = [
            f"{job_id_prefix},{job_id}" if job_id.strip() != "" else job_id_prefix
            for job_id in job_ids
        ]
    index = unflatten_index(
        job_ids,
        keep_only_last_part=True,
        dtypes=index_dtypes,
    )
    prediction_time_series = pd.Series(
        job_return_value["prediction_time"], index=index, name="prediction_time"
    )
    f1_df.index = prediction_time_series.index

    k_chars_per_s = char_total / (prediction_time_series * 1000)
    k_chars_per_s.name = "1k_chars_per_s"

    return f1_df, k_chars_per_s


def get_pareto_front_mask(df: pd.DataFrame, x_col: str, y_col: str) -> pd.Series:
    """Return a boolean mask (aligned with df.index) that is True for rows on the Pareto
    front, assuming both x_col and y_col are to be maximized. A point A dominates point B
    if A[x_col] >= B[x_col], A[y_col] >= B[y_col], and at least one is strictly greater;
    dominated points are not on the front."""
    # Extract the relevant columns as a NumPy array for speed.
    data = df[[x_col, y_col]].values
    n = len(data)
    is_dominated = np.zeros(n, dtype=bool)

    for i in range(n):
        # If it's already marked dominated, skip checks
        if is_dominated[i]:
            continue
        for j in range(n):
            if i == j:
                continue
            # Check if j dominates i
            if (
                data[j, 0] >= data[i, 0]
                and data[j, 1] >= data[i, 1]
                and (data[j, 0] > data[i, 0] or data[j, 1] > data[i, 1])
            ):
                is_dominated[i] = True
                break

    # Return True for points not dominated by any other
    return pd.Series(~is_dominated, index=df.index)


def main(
    job_return_value_path_test: List[str],
    job_return_value_path_val: List[str],
    metric_data_path_test: List[str],
    metric_data_path_val: List[str],
    char_total_test: int,
    char_total_val: int,
    job_id_prefixes: Optional[List[str]] = None,
    metric_filters: Optional[List[str]] = None,
    index_filters: Optional[List[str]] = None,
    index_blacklist: Optional[List[str]] = None,
    label_mapping: Optional[Dict[str, str]] = None,
    plot_method: str = "line",  # can be "scatter" or "line"
    pareto_front: bool = False,
    show_as: str = "figure",
    columns: Optional[List[str]] = None,
    color_column: Optional[str] = None,
):
    label_mapping = label_mapping or {}
    if job_id_prefixes is not None:
        if len(job_id_prefixes) != len(job_return_value_path_test):
            raise ValueError(
                f"job_id_prefixes ({len(job_id_prefixes)}) and "
                f"job_return_value_path_test ({len(job_return_value_path_test)}) "
                f"must have the same length"
            )
        # replace empty strings with None
        job_id_prefixes_with_none = [
            job_id_prefix if job_id_prefix != "" else None for job_id_prefix in job_id_prefixes
        ]
    else:
        job_id_prefixes_with_none = [None] * len(job_return_value_path_test)

    # combine input data for test and val
    char_total = {"test": char_total_test, "val": char_total_val}
    metric_data_path = {"test": metric_data_path_test, "val": metric_data_path_val}
    job_return_value_path = {"test": job_return_value_path_test, "val": job_return_value_path_val}
    # prepare dataframes
    common_kwargs = dict(
        index_dtypes={
            "max_argument_distance": int,
            "max_length": int,
            "num_beams": int,
        }
    )
    f1_df_list: Dict[str, List[pd.DataFrame]] = {"test": [], "val": []}
    k_chars_per_s_list: Dict[str, List[pd.Series]] = {"test": [], "val": []}
    for split in metric_data_path:
        if len(metric_data_path[split]) != len(job_return_value_path[split]):
            raise ValueError(
                f"metric_data_path[{split}] ({len(metric_data_path[split])}) and "
                f"job_return_value_path[{split}] ({len(job_return_value_path[split])}) "
                f"must have the same length"
            )
        for current_metric_data_path, current_job_return_value_path, job_id_prefix in zip(
            metric_data_path[split], job_return_value_path[split], job_id_prefixes_with_none
        ):
            current_f1_df, current_k_chars_per_s = prepare_quality_and_throughput_dfs(
                current_metric_data_path,
                current_job_return_value_path,
                char_total=char_total[split],
                job_id_prefix=job_id_prefix,
                **common_kwargs,
            )
            f1_df_list[split].append(current_f1_df)
            k_chars_per_s_list[split].append(current_k_chars_per_s)
    f1_df_dict = {split: pd.concat(f1_df_list[split], axis=0) for split in f1_df_list}
    k_chars_per_s_dict = {
        split: pd.concat(k_chars_per_s_list[split], axis=0) for split in k_chars_per_s_list
    }

    # combine dataframes for test and val
    f1_df = pd.concat(f1_df_dict, names=["split"] + f1_df_dict["test"].index.names)
    f1_df.columns = [col_to_str(col, names=f1_df.columns.names, sep=",") for col in f1_df.columns]
    k_chars_per_s = pd.concat(
        k_chars_per_s_dict,
        names=["split"] + k_chars_per_s_dict["test"].index.names,
    )

    # combine quality and throughput data
    df_plot = pd.concat([f1_df, k_chars_per_s], axis=1)
    df_plot = (
        df_plot.reset_index()
        .set_index(list(f1_df.index.names) + [k_chars_per_s.name])
        .unstack("split")
    )
    df_plot.columns = flatten_index(df_plot.columns, names=[None, "split"])

    # remove all columns that are not needed
    if metric_filters is not None:
        for fil in metric_filters:
            df_plot.drop(columns=[col for col in df_plot.columns if fil not in col], inplace=True)
            df_plot.columns = [col.replace(fil, "") for col in df_plot.columns]

    # flatten the columns
    df_plot.columns = [
        ",".join([part for part in col.split(",") if part != ""]) for col in df_plot.columns
    ]

    v: Any
    if index_filters is not None:
        for k_v in index_filters:
            k, v = k_v.split("=")
            if k in common_kwargs["index_dtypes"]:
                v = common_kwargs["index_dtypes"][k](v)
            df_plot = df_plot.xs(v, level=k, axis=0)

    if index_blacklist is not None:
        for k_v in index_blacklist:
            k, v = k_v.split("=")
            if k in common_kwargs["index_dtypes"]:
                v = common_kwargs["index_dtypes"][k](v)
            df_plot = df_plot.drop(v, level=k, axis=0)

    if columns is not None:
        df_plot = df_plot[columns]

    x = "1k_chars_per_s"
    y = df_plot.columns

    if pareto_front:
        for col in y:
            current_data = df_plot[col].dropna().reset_index(x).copy()
            pareto_front_mask = get_pareto_front_mask(current_data, x_col=x, y_col=col)
            current_data.loc[~pareto_front_mask, col] = np.nan
            current_data_reset = current_data.reset_index().set_index(df_plot.index.names)
            df_plot[col] = current_data_reset[col]

    # remove nan rows
    df_plot = df_plot.dropna(how="all")

    # plot
    # Create a custom color sequence (concatenating multiple palettes if needed)
    custom_colors = px.colors.qualitative.Dark24 + px.colors.qualitative.Light24

    text_cols = list(df_plot.index.names)
    text_cols.remove(x)
    df_plot_reset = df_plot.reset_index()
    if len(text_cols) > 1:
        df_plot_reset[",".join(text_cols)] = (
            df_plot_reset[text_cols].astype(str).agg(", ".join, axis=1)
        )
    text_col = ",".join(text_cols)

    if show_as == "figure":
        _plot_method = getattr(px, plot_method)
        df_plot_sorted = df_plot_reset.sort_values(by=x)
        fig = _plot_method(
            df_plot_sorted,
            x=x,
            y=y,
            text=text_col if plot_method != "scatter" else None,
            color=color_column,
            color_discrete_sequence=custom_colors,
            hover_data=text_cols,
        )
        # set connectgaps to True to connect the lines, then style the figure:
        # axis titles "Throughput (1k chars/s)" / "Quality (F1)", a title built from
        # the mapped filter labels (plus "(Pareto Front)" if enabled), legend title
        # "Evaluation Setup (...)", Computer Modern font, white background, black
        # ticks and light-gray grid lines
        fig.update_traces(connectgaps=True)
        ...
        fig.show()
    elif show_as == "markdown":
        # Print the DataFrame as a Markdown table
        print(df_plot_reset.to_markdown(index=False, floatfmt=".4f"))
    elif show_as == "json":
        # Print the DataFrame as a JSON object
        print(df_plot_reset.to_json(orient="columns", indent=4))
    else:
        raise ValueError(f"Unknown show_as value: {show_as}. Use 'figure', 'markdown' or 'json'.")


if __name__ == "__main__":

    """
    # Example usage 1 (pipeline model, data from data source: https://github.com/ArneBinder/pie-document-level/issues/388#issuecomment-2752829257):
    python src/analysis/show_inference_params_on_quality_and_throughput.py \
        --job-return-value-path-test logs/prediction/multiruns/default/2025-03-26_01-31-05/job_return_value.json \
        --job-return-value-path-val logs/prediction/multiruns/default/2025-03-26_16-49-36/job_return_value.json \
        --metric-data-path-test data/evaluation/argumentation_structure/inference_pipeline_test.json \
        --metric-data-path-val data/evaluation/argumentation_structure/inference_pipeline_validation.json \
        --metric-filters task=are discont_comp=true split=val

    # Example usage 2 (joint model, data from: https://github.com/ArneBinder/pie-document-level/issues/390#issuecomment-2759888004)
    python src/analysis/show_inference_params_on_quality_and_throughput.py \
        --job-return-value-path-test logs/prediction/multiruns/default/2025-03-28_01-34-07/job_return_value.json \
        --job-return-value-path-val logs/prediction/multiruns/default/2025-03-28_02-57-00/job_return_value.json \
        --metric-data-path-test data/evaluation/argumentation_structure/inference_joint_test.json \
        --metric-data-path-val data/evaluation/argumentation_structure/inference_joint_validation.json \
        --metric-filters task=are discont_comp=true split=val \
        --plot-method scatter
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--job-return-value-path-test", type=str, nargs="+", required=True)
    parser.add_argument("--job-return-value-path-val", type=str, nargs="+", required=True)
    parser.add_argument("--metric-data-path-test", type=str, nargs="+", required=True)
    parser.add_argument("--metric-data-path-val", type=str, nargs="+", required=True)
    parser.add_argument("--job-id-prefixes", type=str, nargs="*", default=None)
    parser.add_argument("--plot-method", type=str, default="line", choices=["scatter", "line"], help="Plot method to use (default: line)")
    parser.add_argument("--color-column", type=str, default=None, help="Column to use for colour coding (default: None)")
    parser.add_argument("--metric-filters", type=str, nargs="*", default=None, help="Filters to apply to the metric data in the format 'key=value'")
    parser.add_argument("--index-filters", type=str, nargs="*", default=None, help="Filters to apply to the index data in the format 'key=value'")
    parser.add_argument("--index-blacklist", type=str, nargs="*", default=None, help="Blacklist to apply to the index data in the format 'key=value'")
    parser.add_argument("--columns", type=str, nargs="*", default=None, help="Columns to plot (default: all)")
    parser.add_argument("--pareto-front", action="store_true", help="Whether to show only the pareto front")
    parser.add_argument("--show-as", type=str, default="figure", choices=["figure", "markdown", "json"], help="How to show the results (default: figure)")

    kwargs = vars(parser.parse_args())

    main(
        char_total_test=383154,
        char_total_val=182794,
        label_mapping={
            "max_argument_distance": "Max. Argument Distance",
            "max_length": "Max. Length",
            "num_beams": "Num. Beams",
            "task=are": "ARE",
            "discont_comp=true": "Discont. Comp.",
            "split=val": "Validation Split",
        },
        **kwargs,
    )
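The Pareto-front helper is the only non-trivial numeric piece; a toy check with made-up throughput/F1 values (the import path assumes the repository root is on PYTHONPATH and plotly is installed, since the module imports it at load time):

import pandas as pd

from src.analysis.show_inference_params_on_quality_and_throughput import get_pareto_front_mask

df = pd.DataFrame({"1k_chars_per_s": [1.0, 2.0, 1.5], "f1": [0.80, 0.70, 0.85]})
mask = get_pareto_front_mask(df, x_col="1k_chars_per_s", y_col="f1")
print(mask.tolist())  # [False, True, True]: the first row is dominated by the third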
src/datamodules/__init__.py
CHANGED
@@ -1 +1 @@
- from .
+ from .datamodule_with_sampler import PieDataModuleWithSampler
src/datamodules/datamodule_with_sampler.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+import logging
+from typing import Optional, Union
+
+from pytorch_ie import PieDataModule
+from pytorch_ie.core.taskmodule import IterableTaskEncodingDataset, TaskEncodingDataset
+from torch.utils.data import DataLoader, Sampler
+
+from .components.sampler import ImbalancedDatasetSampler
+
+logger = logging.getLogger(__name__)
+
+
+class PieDataModuleWithSampler(PieDataModule):
+
+    def __init__(
+        self,
+        train_sampler: Optional[str] = None,
+        dont_shuffle_train: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.train_sampler_name = train_sampler
+        self.dont_shuffle_train = dont_shuffle_train
+
+    def get_train_sampler(
+        self,
+        dataset: Union[TaskEncodingDataset, IterableTaskEncodingDataset],
+    ) -> Optional[Sampler]:
+        if self.train_sampler_name is None:
+            return None
+        elif self.train_sampler_name == "imbalanced_dataset":
+            # for now, this works only with targets that have a single entry
+            return ImbalancedDatasetSampler(
+                dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
+            )
+        else:
+            raise ValueError(f"unknown sampler name: {self.train_sampler_name}")
+
+    def train_dataloader(self) -> DataLoader:
+        ds = self.data_split(self.train_split)
+        sampler = self.get_train_sampler(dataset=ds)
+        # don't shuffle if we explicitly set dont_shuffle_train,
+        # for streamed datasets, or if we use a sampler
+        shuffle = not (
+            self.dont_shuffle_train
+            or isinstance(ds, IterableTaskEncodingDataset)
+            or sampler is not None
+        )
+
+        if not shuffle:
+            logger.warning("not shuffling train dataloader")
+        return DataLoader(
+            dataset=ds,
+            sampler=sampler,
+            collate_fn=self.taskmodule.collate,
+            shuffle=shuffle,
+            **self.dataloader_kwargs,
+        )
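A minimal usage sketch for the new datamodule; the taskmodule, dataset, and dataloader keyword arguments below are placeholders inherited from the parent PieDataModule, not part of this PR:

# sketch: enable the imbalanced-dataset sampler instead of plain shuffling
datamodule = PieDataModuleWithSampler(
    taskmodule=taskmodule,                # any prepared PIE taskmodule (placeholder)
    dataset=dataset,                      # a PIE DatasetDict (placeholder)
    train_sampler="imbalanced_dataset",   # or None to keep the default shuffling behaviour
)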
src/dataset/processing.py
CHANGED
@@ -1,9 +1,16 @@
+import logging
+from collections import defaultdict
+from typing import Callable, Dict, List, Optional, Type, TypeVar, Union
 
 from pie_datasets import Dataset, DatasetDict
+from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
 from pytorch_ie import Document
+from pytorch_ie.annotations import BinaryRelation, Span
+from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
 from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target
 
+logger = logging.getLogger(__name__)
+
 
 # TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
 # batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
@@ -11,7 +18,7 @@ def apply_func_to_splits(
     dataset: DatasetDict,
     function: Union[str, Callable],
     result_document_type: Type[Document],
-    **kwargs
+    **kwargs,
 ):
     resolved_func = resolve_target(function)
     resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
@@ -23,7 +30,85 @@ def apply_func_to_splits(
         batched=True,
         batch_size=len(split),
         result_document_type=resolved_document_type,
-        **kwargs
+        **kwargs,
     )
     result_dict[split_name] = converted_dataset
     return DatasetDict(result_dict)
+
+
+S = TypeVar("S", bound=Span)
+
+
+def shift_span(span: S, offset: int) -> S:
+    """Shift the start and end of a span by a given offset."""
+    return span.copy(start=span.start + offset, end=span.end + offset)
+
+
+D = TypeVar("D", bound=TextDocumentWithLabeledSpansAndBinaryRelations)
+
+
+def add_predicted_semantically_same_relations_to_document(
+    document: D,
+    doc_id2docs_with_predictions: Dict[
+        str, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+    ],
+    relation_label: str,
+    argument_label_blacklist: Optional[List[str]] = None,
+    verbose: bool = False,
+) -> D:
+
+    # create lookup for detached versions of the spans (attached span != detached span even if they are the same)
+    span2span = {span.copy(): span for span in document.labeled_spans}
+    for text_pair_doc_with_preds in doc_id2docs_with_predictions.get(document.id, []):
+        offset = text_pair_doc_with_preds.metadata["original_doc_span"]["start"]
+        offset_pair = text_pair_doc_with_preds.metadata["original_doc_span_pair"]["start"]
+        for coref_rel in text_pair_doc_with_preds.binary_coref_relations.predictions:
+            head = shift_span(coref_rel.head, offset=offset)
+            if head not in span2span:
+                if verbose:
+                    logger.warning(f"doc_id={document.id}: Head span {head} not found.")
+                continue
+            tail = shift_span(coref_rel.tail, offset=offset_pair)
+            if tail not in span2span:
+                if verbose:
+                    logger.warning(f"doc_id={document.id}: Tail span {tail} not found.")
+                continue
+            if argument_label_blacklist is not None and (
+                span2span[head].label in argument_label_blacklist
+                or span2span[tail].label in argument_label_blacklist
+            ):
+                continue
+            new_rel = BinaryRelation(
+                head=span2span[head],
+                tail=span2span[tail],
+                label=relation_label,
+                score=coref_rel.score,
+            )
+            document.binary_relations.predictions.append(new_rel)
+    return document
+
+
+def integrate_coref_predictions_from_text_pair_documents(
+    dataset: DatasetDict, data_dir: str, **kwargs
+) -> DatasetDict:
+
+    dataset_with_predictions = DatasetDict.from_json(data_dir=data_dir)
+
+    for split_name in dataset.keys():
+        ds_with_predictions = dataset_with_predictions[split_name]
+        original_doc_id2docs = defaultdict(list)
+        for doc in ds_with_predictions:
+            original_doc_id = doc.metadata["original_doc_id"]
+            if original_doc_id != doc.metadata["original_doc_id_pair"]:
+                raise ValueError(
+                    f"Original document IDs do not match: "
+                    f"{original_doc_id} != {doc.metadata['original_doc_id_pair']}. "
+                    f"Cross-document coref is not supported."
+                )
+            original_doc_id2docs[original_doc_id].append(doc)
+
+        dataset[split_name] = dataset[split_name].map(
+            function=add_predicted_semantically_same_relations_to_document,
+            fn_kwargs=dict(doc_id2docs_with_predictions=original_doc_id2docs, **kwargs),
+        )
+    return dataset
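A hedged sketch of how the new integration helper might be called; the path and relation label are illustrative only and do not come from the PR:

# sketch: merge predicted coref links back into the original dataset
dataset = integrate_coref_predictions_from_text_pair_documents(
    dataset=dataset,                      # original DatasetDict (placeholder)
    data_dir="path/to/predicted_pairs",   # JSON dump of text-pair docs with predictions (placeholder)
    relation_label="semantically_same",   # label assigned to the added relations (illustrative)
)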
src/demo/annotation_utils.py
CHANGED
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Iterable, Optional, Sequence
+from typing import Iterable, Optional, Sequence
 
 import gradio as gr
 from hydra.utils import instantiate
@@ -41,59 +41,6 @@ def get_merger() -> SpansViaRelationMerger:
     )
 
 
-def annotate_document(
-    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    argumentation_model: Pipeline,
-    handle_parts_of_same: bool = False,
-) -> Union[
-    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
-]:
-    """Annotate a document with the provided pipeline.
-
-    Args:
-        document: The document to annotate.
-        argumentation_model: The pipeline to use for annotation.
-        handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
-    """
-
-    # execute prediction pipeline
-    result: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions = argumentation_model(
-        document, inplace=True
-    )
-
-    if handle_parts_of_same:
-        merger = get_merger()
-        result = merger(result)
-
-    return result
-
-
-def annotate_documents(
-    documents: Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
-    argumentation_model: Pipeline,
-    handle_parts_of_same: bool = False,
-) -> Union[
-    Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
-    Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
-]:
-    """Annotate a sequence of documents with the provided pipeline.
-
-    Args:
-        documents: The documents to annotate.
-        argumentation_model: The pipeline to use for annotation.
-        handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
-    """
-    # execute prediction pipeline
-    result = argumentation_model(documents, inplace=True)
-
-    if handle_parts_of_same:
-        merger = get_merger()
-        result = [merger(document) for document in result]
-
-    return result
-
-
 def create_document(
     text: str, doc_id: str, split_regex: Optional[str] = None
 ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
@@ -143,14 +90,17 @@ def create_documents(
     ]
 
 
-def load_argumentation_model(config_str: str, **kwargs) -> Pipeline:
+def load_argumentation_model(config_str: str, **kwargs) -> Optional[Pipeline]:
     try:
         config = parse_config(config_str, format="yaml")
+        if config is None or config == {}:
+            gr.Warning("Empty argumentation model config provided. No model loaded.")
+            return None
 
         # for PIE AutoPipeline, we need to handle the revision separately for
         # the taskmodule and the model
         if (
-            config.get("_target_"
+            config.get("_target_", "").strip().endswith("AutoPipeline.from_pretrained")
            and "revision" in config
         ):
            revision = config.pop("revision")
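The new None-return path can be exercised with an empty config string; this is a sketch of the expected behaviour, assuming the surrounding Gradio app context:

# sketch: an empty YAML config now yields None instead of raising
model = load_argumentation_model("")   # emits a Gradio warning and returns None
if model is None:
    # downstream code (see process_texts below) then skips annotation entirely
    pass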
src/demo/backend_utils.py
CHANGED
@@ -12,10 +12,11 @@ from pie_datasets import Dataset, IterableDataset, load_dataset
 from pytorch_ie import Pipeline
 from pytorch_ie.documents import (
     TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
 )
 from tqdm import tqdm
 
-from src.demo.annotation_utils import
+from src.demo.annotation_utils import create_documents, get_merger
 from src.demo.data_utils import load_text_from_arxiv
 from src.demo.rendering_utils import (
     RENDER_WITH_DISPLACY,
@@ -54,7 +55,7 @@ def add_annotated_pie_documents(
 def process_texts(
     texts: Iterable[str],
     doc_ids: Iterable[str],
-    argumentation_model: Pipeline,
+    argumentation_model: Optional[Pipeline],
     retriever: DocumentAwareSpanRetriever,
     split_regex_escaped: Optional[str],
     handle_parts_of_same: bool = False,
@@ -68,13 +69,21 @@ def process_texts(
         doc_ids=doc_ids,
         split_regex=split_regex_escaped,
     )
-    if
+    if argumentation_model is not None:
+        if verbose:
+            gr.Info(f"Annotate {len(pie_documents)} documents...")
+        pie_documents = argumentation_model(pie_documents, inplace=True)
+    else:
+        gr.Warning(
+            "Annotation is disabled (no model was loaded). No annotations will be added to the documents."
+        )
+
+    # this needs to be done also if the documents are not annotated because
+    # it adjusts the document type
+    if handle_parts_of_same:
+        merger = get_merger()
+        pie_documents = [merger(document) for document in pie_documents]
+
     add_annotated_pie_documents(
         retriever=retriever,
         pie_documents=pie_documents,
@@ -93,12 +102,41 @@ def add_annotated_pie_documents_from_dataset(
     dataset = load_dataset(**load_dataset_kwargs)
     if not isinstance(dataset, (Dataset, IterableDataset)):
         raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
+    try:
+        dataset_converted = dataset.to_document_type(
+            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
+        )
+    except ValueError:
+        gr.Warning(
+            "The dataset does not seem to have a registered converter to create multi-spans. "
+            "Try to load as single-spans and to convert to multi-spans manually ..."
+        )
+        dataset_converted_single_span = dataset.to_document_type(
+            TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+        )
+        merger = get_merger()
+        dataset_converted = dataset_converted_single_span.map(
+            merger,
+            result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+        )
+
+    def _clear_metadata(
+        doc: TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+    ) -> TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions:
+        result = doc.copy()
+        result.metadata = dict()
+        return result
+
+    # adding documents with different metadata format to the retriever breaks it,
+    # so we clear the metadata field beforehand
+    dataset_converted_without_metadata = dataset_converted.map(
+        _clear_metadata,
+        result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+    )
+
     add_annotated_pie_documents(
         retriever=retriever,
-        pie_documents=
+        pie_documents=dataset_converted_without_metadata,
         use_predicted_annotations=False,
         verbose=verbose,
     )
src/demo/retrieve_and_dump_all_relevant.py
CHANGED
@@ -10,21 +10,24 @@ root = pyrootutils.setup_root(
 import argparse
 import logging
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 import pandas as pd
 from pie_datasets import Dataset, DatasetDict
 from pytorch_ie import Annotation
 from pytorch_ie.annotations import BinaryRelation, MultiSpan, Span
 
-from document.types import (
-    RelatedRelation,
-    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
-)
 from src.demo.retriever_utils import (
     retrieve_all_relevant_spans,
     retrieve_all_relevant_spans_for_all_documents,
+    retrieve_all_similar_spans,
+    retrieve_all_similar_spans_for_all_documents,
     retrieve_relevant_spans,
+    retrieve_similar_spans,
+)
+from src.document.types import (
+    RelatedRelation,
+    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
 )
 from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
 
@@ -131,14 +134,17 @@ def add_result_to_gold_data(
             base_annotation_mapping=base_annotation_mapping,
         )
     )
+    # only when we process relevant span retriever results, we have a ref_span_id
+    # (for similar span retriever results, we only have query_span_id)
+    if "ref_span_id" in result.columns:
+        doc_and_span_id2annotation.update(
+            get_doc_and_span_id2annotation_mapping(
+                span_ids=result["ref_span_id"],
+                doc_ids=result["doc_id"],
+                retriever=retriever,
+                base_annotation_mapping=base_annotation_mapping,
+            )
+        )
     doc_and_span_id2annotation.update(
         get_doc_and_span_id2annotation_mapping(
             span_ids=result["query_span_id"],
@@ -159,38 +165,51 @@ def add_result_to_gold_data(
             (row.query_doc_id, row.query_span_id)
         ]
         doc_id, span = doc_and_span_id2annotation[(row.doc_id, row.span_id)]
-        doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
         if doc_id != query_doc_id:
             raise ValueError("doc_id and query_doc_id must be the same")
-        if doc_id != doc_id2:
-            raise ValueError("doc_id and ref_doc_id must be the same")
         doc = doc_id2doc[doc_id]
+
+        # if we have a reference span, we need to construct the related relation
+        if hasattr(row, "ref_span_id"):
+            doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
+            if doc_id != doc_id2:
+                raise ValueError("doc_id and ref_doc_id must be the same")
+
+            # create a link relation between the query span and the reference span
+            link_rel = BinaryRelation(
+                head=query_span, tail=ref_span, label=link_relation_label, score=row.sim_score
+            )
+            doc.binary_relations.predictions.append(link_rel)
+
+            head_and_tail2relation = doc_id2head_tail2relation[doc_id]
+            related_rel_label = row.type
+            if related_rel_label.endswith(reversed_relation_suffix):
+                base_rel = head_and_tail2relation[(span, ref_span)]
+            else:
+                base_rel = head_and_tail2relation[(ref_span, span)]
+            related_rel = RelatedRelation(
+                head=query_span,
+                tail=span,
+                link_relation=link_rel,
+                relation=base_rel,
+                label=related_rel_label,
+                score=link_rel.score * base_rel.score,
+            )
+            doc.related_relations.predictions.append(related_rel)
+        # otherwise, we just ...
         else:
+            # ... create a link relation between the query span and returned span
+            link_rel = BinaryRelation(
+                head=query_span, tail=span, label=link_relation_label, score=row.sim_score
+            )
+            doc.binary_relations.predictions.append(link_rel)
 
     dataset = Dataset.from_documents(list(doc_id2doc.values()))
     dataset_dict = DatasetDict({split: dataset})
     if not os.path.exists(dataset_out_dir):
         os.makedirs(dataset_out_dir, exist_ok=True)
 
-    dataset_dict.to_json(dataset_out_dir)
+    dataset_dict.to_json(dataset_out_dir, mode="w")
 
 
 if __name__ == "__main__":
@@ -216,6 +235,13 @@ if __name__ == "__main__":
         type=str,
         required=True,
     )
+    parser.add_argument(
+        "-v",
+        "--variant",
+        choices=["relevant", "similar"],
+        default="relevant",
+        help="Variant of the retriever to use: 'relevant' for relevant spans, 'similar' for similar spans.",
+    )
     parser.add_argument(
         "--query_doc_id",
         type=str,
@@ -282,6 +308,24 @@ if __name__ == "__main__":
     logger.info(f"loading data from {args.data_path}...")
     retriever.load_from_disc(args.data_path)
 
+    methods: Dict[str, Callable]
+    if args.variant == "relevant":
+        logger.info("using *relevant* span retriever methods")
+        methods = {
+            "retrieve_all_spans": retrieve_all_relevant_spans,
+            "retrieve_spans": retrieve_relevant_spans,
+            "retrieve_all_spans_for_all_documents": retrieve_all_relevant_spans_for_all_documents,
+        }
+    elif args.variant == "similar":
+        logger.info("using *similar* span retriever methods")
+        methods = {
+            "retrieve_all_spans": retrieve_all_similar_spans,
+            "retrieve_spans": retrieve_similar_spans,
+            "retrieve_all_spans_for_all_documents": retrieve_all_similar_spans_for_all_documents,
+        }
+    else:
+        raise ValueError(f"unknown method: {args.variant}")
+
     search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
     if args.doc_id_whitelist is not None:
         search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
@@ -293,7 +337,7 @@ if __name__ == "__main__":
     all_spans_for_all_documents = None
     for doc_id_pair in args.query_target_doc_id_pairs:
         query_doc_id, target_doc_id = doc_id_pair.split(":")
-        current_result =
+        current_result = methods["retrieve_all_spans"](
             retriever=retriever,
             query_doc_id=query_doc_id,
             doc_id_whitelist=[target_doc_id],
@@ -319,16 +363,16 @@ if __name__ == "__main__":
 
     elif args.query_span_id is not None:
         logger.warning(f"retrieving results for single span: {args.query_span_id}")
-        all_spans_for_all_documents =
+        all_spans_for_all_documents = methods["retrieve_spans"](
             retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
         )
     elif args.query_doc_id is not None:
         logger.warning(f"retrieving results for single document: {args.query_doc_id}")
-        all_spans_for_all_documents =
+        all_spans_for_all_documents = methods["retrieve_all_spans"](
             retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
         )
     else:
-        all_spans_for_all_documents =
+        all_spans_for_all_documents = methods["retrieve_all_spans_for_all_documents"](
             retriever=retriever, **search_kwargs
         )
 
src/document/processing.py
CHANGED
@@ -1,20 +1,23 @@
 from __future__ import annotations
 
+import itertools
 import logging
 from collections import defaultdict
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
 
 from pie_modules.utils.span import have_overlap
 from pytorch_ie import AnnotationLayer
-from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
+from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan, MultiSpan, Span
 from pytorch_ie.core import Document
 from pytorch_ie.core.document import Annotation, _enumerate_dependencies
+from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
 
 from src.document.types import (
     RelatedRelation,
     TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
 )
 from src.utils import distance, distance_slices
+from src.utils.graph_utils import get_connected_components
 from src.utils.span_utils import get_overlap_len
 
 logger = logging.getLogger(__name__)
@@ -123,6 +126,69 @@ def remove_partitions_by_labels(
 D_text = TypeVar("D_text", bound=Document)
 
 
+def remove_annotations_by_label(
+    document: D, layer2label_blacklist: Dict[str, List[str]], verbose: bool = False
+) -> D:
+    """Remove annotations with labels in the blacklist from a document.
+
+    Args:
+        document: The document to process.
+        layer2label_blacklist: A mapping from layer names to lists of labels to remove.
+        verbose: Whether to print number of removed annotations.
+
+    Returns:
+        The processed document.
+    """
+
+    result = document.copy(with_annotations=False)
+    override_annotations: Dict[str, Dict[int, Annotation]] = defaultdict(dict)
+    removed_annotations: Dict[str, Set[int]] = defaultdict(set)
+    for layer_name, labels in layer2label_blacklist.items():
+        # process gold annotations and predictions
+        for src_layer, tgt_layer in [
+            (document[layer_name], result[layer_name]),
+            (document[layer_name].predictions, result[layer_name].predictions),
+        ]:
+            current_override_annotations = dict()
+            current_removed_annotations = set()
+            for annotation in src_layer:
+                label = getattr(annotation, "label")
+                if label is None:
+                    raise ValueError(
+                        f"Annotation {annotation} has no label. Please check the annotation type."
+                    )
+                if label not in labels:
+                    current_override_annotations[annotation._id] = annotation.copy()
+                else:
+                    current_removed_annotations.add(annotation._id)
+            tgt_layer.extend(current_override_annotations.values())
+
+            override_annotations[layer_name].update(current_override_annotations)
+            removed_annotations[layer_name].update(current_removed_annotations)
+    if verbose:
+        num_removed = {
+            layer_name: len(removed_ids) for layer_name, removed_ids in removed_annotations.items()
+        }
+        if len(num_removed) > 0:
+            num_total = {
+                layer_name: len(kept_ids) + num_removed[layer_name]
+                for layer_name, kept_ids in override_annotations.items()
+            }
+            logger.warning(
+                f"doc.id={document.id}: Removed {num_removed} (total: {num_total}) "
+                f"annotations with label blacklists {layer2label_blacklist}"
+            )
+
+    result.add_all_annotations_from_other(
+        other=document,
+        removed_annotations=removed_annotations,
+        override_annotations=override_annotations,
+        strict=False,
+        verbose=False,
+    )
+    return result
+
+
 def replace_substrings_in_text(
     document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True
 ) -> D_text:
@@ -512,3 +578,236 @@ def add_related_relations_from_binary_relations(
     )
 
     return document
+
+
+T = TypeVar("T", bound=TextDocumentWithLabeledSpansAndBinaryRelations)
+
+
+def remove_discontinuous_spans(
+    document: T,
+    parts_of_same_relation: str,
+    verbose: bool = False,
+) -> T:
+    """
+    Remove discontinuous spans from a document.
+
+    Args:
+        document: The document to process.
+        parts_of_same_relation: The name of the relation that indicates linked spans.
+        verbose: Whether to print debug information.
+
+    Returns:
+        The processed document.
+    """
+    result = document.copy()
+    spans = result.labeled_spans.clear()
+    rels = result.binary_relations.clear()
+
+    segment_spans = set()
+    segment_rels = set()
+    # collect all spans that are linked
+    for rel in rels:
+        if rel.label == parts_of_same_relation:
+            segment_spans.add(rel.head)
+            segment_spans.add(rel.tail)
+            segment_rels.add(rel)
+
+    for span in spans:
+        if span not in segment_spans:
+            result.labeled_spans.append(span)
+
+    other_rels_dropped = set()
+    for rel in rels:
+        if rel not in segment_rels:
+            if rel.head not in segment_spans and rel.tail not in segment_spans:
+                result.binary_relations.append(rel)
+            else:
+                other_rels_dropped.add(rel)
+
+    if verbose:
+        if len(segment_rels) > 0:
+            logger.warning(
+                f"doc={document.id}: Dropped {len(segment_rels)} segment rels "
+                f"and {len(other_rels_dropped)} other rels "
+                f"({round((len(document.binary_relations) - len(result.binary_relations)) * 100 / len(document.binary_relations), 1)}% "
+                f"of all relations dropped)"
+            )
+    return result
+
+
+def close_clusters_transitively(
+    document: D, relation_layer: str, link_relation_label: str, verbose: bool = False
+) -> D:
+    """
+    Close clusters transitively by adding relations between all pairs of spans in the same cluster.
+
+    Args:
+        document: The document to process.
+        relation_layer: The name of the relation layer.
+        link_relation_label: The label of the link relation.
+        verbose: Whether to print debug information.
+
+    Returns:
+        The processed document.
+    """
+    result = document.copy()
+
+    connected_components: List[List[Annotation]] = get_connected_components(
+        relations=result[relation_layer],
+        link_relation_label=link_relation_label,
+        add_singletons=False,
+    )
+    # detach from document
+    relations = result[relation_layer].clear()
+    # use set to speed up membership checks
+    relations_set = set(relations)
+    n_before = len(relations)
+    for cluster in connected_components:
+        for head, tail in itertools.combinations(sorted(cluster), 2):
+            rel = BinaryRelation(
+                head=head,
+                tail=tail,
+                label=link_relation_label,
+            )
+            rel_reversed = BinaryRelation(
+                head=tail,
+                tail=head,
+                label=link_relation_label,
+            )
+            if rel not in relations_set and rel_reversed not in relations_set:
+                # append to relations to keep the order
+                relations.append(rel)
+                relations_set.add(rel)
+
+    result[relation_layer].extend(relations)
+    if verbose:
+        num_added = len(relations) - n_before
+        if num_added > 0:
+            logger.warning(
+                f"doc.id={document.id}: added {num_added} relations to {relation_layer} layer"
+            )
+
+    return result
+
+
+def get_ancestor_layers(children: Dict[str, Set[str]], layer: str) -> Set[str]:
+    """
+    Get all ancestor layers of a given layer in the dependency graph.
+
+    Args:
+        children: A mapping from layers to their children layers.
+        layer: The layer for which to find ancestors.
+
+    Returns:
+        A set of ancestor layers.
+    """
+    ancestors = set()
+
+    def _get_ancestors(current_layer: str):
+        for parent_layer, child_layers in children.items():
+            if current_layer in child_layers:
+                ancestors.add(parent_layer)
+                _get_ancestors(parent_layer)
+
+    _get_ancestors(layer)
+    # drop the _artificial_root
+    ancestors.discard("_artificial_root")
+    return ancestors
+
+
+def remove_binary_relations_by_partition_labels(
+    document: D,
+    partition_layer: str,
+    relation_layer: str,
+    partition_label_whitelist: Optional[List[List[str]]] = None,
+    partition_label_blacklist: Optional[List[List[str]]] = None,
+    verbose: bool = False,
+) -> D:
+    """
+    Remove binary relations that are not between partitions with labels in the whitelist or
+    that are in the blacklist.
+
+    Args:
+        document: The document to process.
+        partition_layer: The name of the partition layer.
+        relation_layer: The name of the relation layer.
+        partition_label_whitelist: The list of head-tail label pairs to keep.
+        partition_label_blacklist: The list of head-tail label pairs to remove.
+        verbose: Whether to print the removed relations to console.
+
+    Returns:
+        The processed document.
+    """
+    result = document.copy()
+
+    relation_annotation_layer = result[relation_layer]
+    # get all layers that target the relation layer
+    relation_dependent_layers = get_ancestor_layers(
+        children=result._annotation_graph, layer=relation_layer
+    )
+    # clear all layers that depend on the relation layer
+    for layer_name in relation_dependent_layers:
+        dependent_layer = result[layer_name]
+        gold_anns_cleared = dependent_layer.clear()
+        pred_anns_cleared = dependent_layer.predictions.clear()
+        if len(gold_anns_cleared) > 0 or len(pred_anns_cleared) > 0:
+            if verbose:
+                logger.warning(
+                    f"doc.id={document.id}: Cleared {len(gold_anns_cleared)} gold and "
+                    f"{len(pred_anns_cleared)} predicted annotations from layer {layer_name} "
+                    f"because it depends on the relation layer {relation_layer}."
+                )
+
+    span2partition = {}
+    span_layer: AnnotationLayer
+    for span_layer in relation_annotation_layer.target_layers.values():
+        for span in list(span_layer) + list(span_layer.predictions):
+            if isinstance(span, Span):
+                span_start, span_end = span.start, span.end
+            elif isinstance(span, MultiSpan):
+                span_start, span_end = min(start for start, _ in span.slices), max(
+                    end for _, end in span.slices
+                )
+            else:
+                raise ValueError(f"Unsupported span type: {type(span)}")
+            found_partition = False
+            for partition in result[partition_layer]:
+                if partition.start <= span_start and span_end <= partition.end:
+                    span2partition[span] = partition
+                    found_partition = True
+                    break
+            if not found_partition:
+                raise ValueError(f"No partition found for span {span}")
+
+    if partition_label_whitelist is not None:
+        partition_label_whitelist_tuples = [tuple(pair) for pair in partition_label_whitelist]
+    else:
+        partition_label_whitelist_tuples = None
+    if partition_label_blacklist is not None:
+        partition_label_blacklist_tuples = [tuple(pair) for pair in partition_label_blacklist]
+    else:
+        partition_label_blacklist_tuples = None
+
+    for relation_base_layer in [relation_annotation_layer, relation_annotation_layer.predictions]:
+        # get all relations and clear the layer
+        relations = relation_base_layer.clear()
+        for relation in relations:
+            head_partition = span2partition[relation.head]
+            tail_partition = span2partition[relation.tail]
+            pair = (head_partition.label, tail_partition.label)
+            if (
+                partition_label_whitelist_tuples is None
+                or pair in partition_label_whitelist_tuples
+            ) and (
+                partition_label_blacklist_tuples is None
+                or pair not in partition_label_blacklist_tuples
+            ):
+                relation_base_layer.append(relation)
+            else:
+                if verbose:
+                    logger.info(
+                        f"Removing relation {relation} because its partitions "
+                        f"({pair}) are not in the whitelist or are in the blacklist."
+                    )
+
+    return result
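A short sketch of chaining the new document-processing helpers; the layer and label names here are illustrative placeholders, not values from the PR:

# sketch: make link clusters transitive, then keep only relations within whitelisted partition pairs
doc = close_clusters_transitively(
    doc, relation_layer="binary_relations", link_relation_label="parts_of_same", verbose=True
)
doc = remove_binary_relations_by_partition_labels(
    doc,
    partition_layer="labeled_partitions",
    relation_layer="binary_relations",
    partition_label_whitelist=[["abstract", "abstract"]],  # illustrative head/tail label pair
)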
src/evaluate.py
CHANGED
@@ -41,13 +41,13 @@ from omegaconf import DictConfig
 from pie_datasets import DatasetDict
 from pie_modules.models import *  # noqa: F403
 from pie_modules.taskmodules import *  # noqa: F403
+from pytorch_ie import PieDataModule
 from pytorch_ie.core import PyTorchIEModel, TaskModule
 from pytorch_ie.models import *  # noqa: F403
 from pytorch_ie.taskmodules import *  # noqa: F403
 from pytorch_lightning import Trainer
 
 from src import utils
-from src.datamodules import PieDataModule
 from src.models import *  # noqa: F403
 from src.taskmodules import *  # noqa: F403
 
@@ -80,8 +80,8 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
     log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
     taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")
 
-    # auto-convert the dataset if the
-    dataset =
+    # auto-convert the dataset if the taskmodule specifies a document type
+    dataset = dataset.to_document_type(taskmodule, downcast=False)
 
     # Init pytorch-ie datamodule
     log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
src/evaluate_documents.py
CHANGED
@@ -73,7 +73,7 @@ def evaluate_documents(cfg: DictConfig) -> Tuple[dict, dict]:
     metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial")
 
     # auto-convert the dataset if the metric specifies a document type
-    dataset =
+    dataset = dataset.to_document_type(metric, downcast=False)
 
     # Init lightning loggers
     loggers = utils.instantiate_dict_entries(cfg, "logger")
src/hydra_callbacks/save_job_return_value.py
CHANGED
@@ -3,7 +3,7 @@ import logging
 import os
 import pickle
 from pathlib import Path
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -174,6 +174,46 @@ def overrides_to_identifiers(overrides_per_result: List[List[str]], sep: str = "
     return identifiers
 
 
+def identifier2dict(
+    identifier: str, record_sep: str = "-", key_value_sep: str = "="
+) -> Dict[str, str]:
+    """Converts a single identifier to a dict. The identifier is expected to be separated by "-".
+    Values are allowed to contain "-" as well, but keys are not. Key and value are separated by "=".
+
+    Example:
+        >>> identifier = "a=1-b=my-stuff"
+        >>> identifier2dict(identifier)
+        {'a': '1', 'b': 'my-stuff'}
+    """
+    parts = identifier.split(record_sep)
+    result = {}
+    last_key = None
+    for part in parts:
+        if key_value_sep in part:
+            last_key, value = part.split(key_value_sep, 1)
+            result[last_key] = value
+        else:
+            if last_key is None:
+                raise ValueError(
+                    f'Invalid identifier: {identifier} (keys must not contain the record_sep="{record_sep}")'
+                )
+            result[last_key] += record_sep + part
+    return result
+
+
+def identifiers_to_multiindex(identifiers: Iterable[str], **kwargs) -> pd.MultiIndex:
+    """Converts a list of identifiers to a MultiIndex. See identifier2dict for the
+    format of the identifiers.
+
+    Example:
+        >>> identifiers = ["a=1-b=my-stuff", "a=2-b=yes", "a=3"]
+        >>> identifiers_to_multiindex(identifiers, record_sep="-", key_value_sep="=")
+        MultiIndex([(1, 'my-stuff'), (2, 'yes'), (3, nan)], names=['a', 'b'])
+    """
+    frame = pd.DataFrame([identifier2dict(identifier, **kwargs) for identifier in identifiers])
+    return pd.MultiIndex.from_frame(frame)
+
+
 class SaveJobReturnValueCallback(Callback):
     """Save the job return-value in ${output_dir}/{job_return_value_filename}.
 
@@ -200,6 +240,10 @@
     multirun_create_ids_from_overrides: bool (default: True)
         Create job identifiers from the overrides of the jobs in a multi-run. If False, the job index is used as
         identifier.
+    multirun_ids: List[str] or List[int] (default: None)
+        If provided, the job identifiers from the config are used instead of the overrides or the job index.
+    markdown_split_index: bool (default: False)
+        If True, the index of the markdown file is split into multiple columns based on the separator "-".
     markdown_round_digits: int (default: 3)
         The number of digits to round the values in the markdown file. If None, no rounding is applied.
     multirun_job_id_key: str (default: "job_id")
@@ -220,6 +264,8 @@
         integrate_multirun_result: bool = False,
         multirun_aggregator_blacklist: Optional[List[str]] = None,
         multirun_create_ids_from_overrides: bool = True,
+        multirun_ids: Optional[Union[List[str], List[int]]] = None,
+        markdown_split_index: bool = False,
         markdown_round_digits: Optional[int] = 3,
         multirun_job_id_key: str = "job_id",
         paths_file: Optional[str] = None,
@@ -234,6 +280,8 @@
         self.multirun_aggregator_blacklist = multirun_aggregator_blacklist
         self.multirun_create_ids_from_overrides = multirun_create_ids_from_overrides
         self.multirun_job_id_key = multirun_job_id_key
+        self.multirun_ids = multirun_ids
+        self.markdown_split_index = markdown_split_index
         self.markdown_round_digits = markdown_round_digits
         self.multirun_paths_file = multirun_paths_file
         self.multirun_path_id = multirun_path_id
@@ -253,10 +301,21 @@
 
     def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
         job_ids: Union[List[str], List[int]]
-        if self.
+        if self.multirun_ids is not None:
+            # use the job_ids from the config
+            if len(self.multirun_ids) != len(self.job_returns):
+                raise ValueError(
+                    f"Number of job_ids ({len(self.multirun_ids)}) does not match number of job returns ({len(self.job_returns)})"
+                )
+            # convert ListConfig to list
+            job_ids = list(self.multirun_ids)  # type: ignore
         else:
+            if self.multirun_create_ids_from_overrides:
+                job_ids = overrides_to_identifiers(
+                    [jr.overrides for jr in self.job_returns], sep="-"
+                )
+            else:
+                job_ids = list(range(len(self.job_returns)))
 
         if self.integrate_multirun_result:
             # rearrange the job return-values of all jobs from a multi-run into a dict of lists (maybe nested),
@@ -368,6 +427,10 @@
         if job_id_column in result.columns:
             result = result.set_index(job_id_column)
             result.index.name = self.multirun_job_id_key
+
+            if self.markdown_split_index:
+                result.index = identifiers_to_multiindex(result.index, record_sep="-")
+                result = result.reset_index()
         else:
             # Otherwise, we have only one value for each key. We convert the dict to a pandas Series.
             series = pd.Series(obj_py_flat)
src/langchain_modules/basic_pie_document_store.py
CHANGED
@@ -52,6 +52,7 @@ class BasicPieDocumentStore(PieDocumentStore):
         shutil.rmtree(pie_documents_path)
         os.makedirs(pie_documents_path, exist_ok=True)
         doc_ids_iter = iter(self.client.yield_keys())
+        mode = "w"
         while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
             all_doc_ids.extend(batch_doc_ids)
             docs = self.client.mget(batch_doc_ids)
@@ -63,7 +64,8 @@ class BasicPieDocumentStore(PieDocumentStore):
                 {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
             )
             pie_dataset = Dataset.from_documents(pie_docs)
-            DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path)
+            DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path, mode=mode)
+            mode = "a"  # append after the first batch
         if len(all_doc_ids) > 0:
             doc_ids_path = os.path.join(path, "doc_ids.json")
             with open(doc_ids_path, "w") as f:
src/langchain_modules/datasets_pie_document_store.py
CHANGED
@@ -118,7 +118,7 @@ class DatasetsPieDocumentStore(PieDocumentStore):
             logger.warning(f"Removing existing directory: {pie_documents_path}")
             shutil.rmtree(pie_documents_path)
         os.makedirs(pie_documents_path, exist_ok=True)
-        DatasetDict({"train": self._data}).to_json(pie_documents_path)
+        DatasetDict({"train": self._data}).to_json(pie_documents_path, mode="w")
         doc_ids_path = os.path.join(path, "doc_ids.json")
         with open(doc_ids_path, "w") as f:
             json.dump(all_doc_ids, f)
src/metrics/__init__.py
CHANGED
@@ -1,3 +1,9 @@
-from .
+from .connected_component_sizes import ConnectedComponentSizes
+from .coref import CorefHoiF1
+from .coref_sklearn import BinaryClassificationMetricsSKLearn
 from .coref_torchmetrics import CorefMetricsTorchmetrics
+from .f1_with_threshold import F1WithThresholdMetric
+from .ranking_sklearn import RankingMetricsSKLearn
 from .score_distribution import ScoreDistribution
+from .semantically_same_ranking import SemanticallySameRankingMetric
+from .tpfpfn import TPFFPFNMetric
src/metrics/connected_component_sizes.py
ADDED
@@ -0,0 +1,43 @@
import logging
from collections import Counter
from typing import Dict, List, TypeVar

from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentStatistic
from pytorch_ie.annotations import BinaryRelation

from src.utils.graph_utils import get_connected_components

logger = logging.getLogger(__name__)

A = TypeVar("A")


# TODO: remove when "counts" aggregation function is available in DocumentStatistic
def count_func(values: List[int]) -> Dict[int, int]:
    """Counts the number of occurrences of each value in the list."""
    counter = Counter(values)
    result = {k: counter[k] for k in sorted(counter)}
    return result


class ConnectedComponentSizes(DocumentStatistic):
    # TODO: use "counts" aggregation function when available in DocumentStatistic
    DEFAULT_AGGREGATION_FUNCTIONS = ["src.metrics.connected_component_sizes.count_func"]

    def __init__(self, relation_layer: str, link_relation_label: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.relation_layer = relation_layer
        self.link_relation_label = link_relation_label

    def _collect(self, document: Document) -> List[int]:
        relations: AnnotationLayer[BinaryRelation] = document[self.relation_layer]
        spans: AnnotationLayer[Annotation] = document[self.relation_layer].target_layer

        connected_components: List[List] = get_connected_components(
            elements=spans,
            relations=relations,
            link_relation_label=self.link_relation_label,
            add_singletons=True,
        )
        new_component_sizes = [len(component) for component in connected_components]
        return new_component_sizes
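`get_connected_components` lives in `src/utils/graph_utils.py` and is not shown in this hunk. Purely as an illustration of the idea (spans are nodes, link relations are edges, singleton spans are optionally added as one-element components), a networkx-based sketch could look like this; the real helper may differ in signature and details:

```python
# Illustrative sketch only: connected components over link relations via networkx.
# `relations` are assumed to have .head, .tail and .label, like pytorch-ie BinaryRelation.
import networkx as nx

def connected_components_sketch(elements, relations, link_relation_label=None, add_singletons=True):
    graph = nx.Graph()
    for rel in relations:
        if link_relation_label is None or rel.label == link_relation_label:
            graph.add_edge(rel.head, rel.tail)
    components = [list(component) for component in nx.connected_components(graph)]
    if add_singletons:
        linked = {node for component in components for node in component}
        components.extend([element] for element in elements if element not in linked)
    return components
```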
src/metrics/coref.py
ADDED
@@ -0,0 +1,223 @@
from collections import Counter
from typing import Dict, Hashable, List, Optional, Sequence, Tuple, TypeVar

import numpy as np
from pytorch_ie import Annotation, Document, DocumentMetric
from pytorch_ie.annotations import BinaryRelation

from src.utils.graph_utils import get_connected_components


class CorefHoiEvaluator(object):
    def __init__(self, metric, beta=1):
        self.p_num = 0
        self.p_den = 0
        self.r_num = 0
        self.r_den = 0
        self.metric = metric
        self.beta = beta

    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
        if self.metric == ceafe_simplified:
            pn, pd, rn, rd = self.metric(predicted, gold)
        else:
            pn, pd = self.metric(predicted, mention_to_gold)
            rn, rd = self.metric(gold, mention_to_predicted)
        self.p_num += pn
        self.p_den += pd
        self.r_num += rn
        self.r_den += rd

    def f1(self, p_num, p_den, r_num, r_den, beta=1):
        p = 0 if p_den == 0 else p_num / float(p_den)
        r = 0 if r_den == 0 else r_num / float(r_den)
        return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)

    def get_f1(self):
        return self.f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)

    def get_recall(self):
        return 0 if self.r_num == 0 else self.r_num / float(self.r_den)

    def get_precision(self):
        return 0 if self.p_num == 0 else self.p_num / float(self.p_den)

    def get_prf(self):
        return self.get_precision(), self.get_recall(), self.get_f1()

    def get_counts(self):
        return self.p_num, self.p_den, self.r_num, self.r_den


def b_cubed_simplified(clusters, mention_to_gold):
    num, dem = 0, 0
    for c in clusters:
        if len(c) == 1:
            continue

        gold_counts = Counter()
        correct = 0
        for m in c:
            if m in mention_to_gold:
                gold_counts[tuple(mention_to_gold[m])] += 1
        for c2, count in gold_counts.items():
            if len(c2) != 1:
                correct += count * count

        num += correct / float(len(c))
        dem += len(c)
    return num, dem


def muc_simplified(clusters, mention_to_gold):
    tp, p = 0, 0
    for c in clusters:
        p += len(c) - 1
        tp += len(c)
        linked = set()
        for m in c:
            if m in mention_to_gold:
                linked.add(mention_to_gold[m])
            else:
                tp -= 1
        tp -= len(linked)
    return tp, p


def phi4_simplified(c1, c2):
    return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))


def ceafe_simplified(clusters, gold_clusters):
    # lazy import to not force scipy installation
    from scipy.optimize import linear_sum_assignment as linear_assignment

    clusters = [c for c in clusters if len(c) != 1]
    scores = np.zeros((len(gold_clusters), len(clusters)))
    for i in range(len(gold_clusters)):
        for j in range(len(clusters)):
            scores[i, j] = phi4_simplified(gold_clusters[i], clusters[j])
    matching = linear_assignment(-scores)
    matching = np.transpose(np.asarray(matching))
    similarity = sum(scores[matching[:, 0], matching[:, 1]])
    return similarity, len(clusters), similarity, len(gold_clusters)


def lea_simplified(clusters, mention_to_gold):
    num, dem = 0, 0

    for c in clusters:
        if len(c) == 1:
            continue

        common_links = 0
        all_links = len(c) * (len(c) - 1) / 2.0
        for i, m in enumerate(c):
            if m in mention_to_gold:
                for m2 in c[i + 1 :]:
                    if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]:
                        common_links += 1

        num += len(c) * common_links / float(all_links)
        dem += len(c)

    return num, dem


H = TypeVar("H", bound=Hashable)


class CorefHoiF1(DocumentMetric):
    """
    Coreference evaluation based on official coref-hoi evaluation script, i.e.,
    https://github.com/lxucs/coref-hoi/blob/5ddfc3b64a5519c3555b5a57e47ab2f03c104a60/metrics.py.

    The metric expects documents with a relation layer that contains binary relations
    between mentions from the same coreference cluster. Works with relations targeting
    mentions from multiple layers (e.g., cross-textual relations).

    Args:
        relation_layer: The name of the relation layer that contains the link relations.
        include_singletons: If True (default), singletons will be included in the evaluation.
        link_relation_label: If provided, only the relations with this label will be used
            to create the clusters.
        link_relation_relation_score_threshold: If provided, only the relations with a score
            greater than or equal to this threshold will be used to create the clusters.
    """

    def __init__(
        self,
        relation_layer: str,
        include_singletons: bool = True,
        link_relation_label: Optional[str] = None,
        link_relation_relation_score_threshold: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.relation_layer = relation_layer
        self.link_relation_label = link_relation_label
        self.include_singletons = include_singletons
        self.link_relation_relation_score_threshold = link_relation_relation_score_threshold

    def reset(self) -> None:
        self.evaluators = [
            CorefHoiEvaluator(m) for m in (muc_simplified, b_cubed_simplified, ceafe_simplified)
        ]

    def prepare_clusters_with_mapping(
        self, mentions: Sequence[Annotation], relations: Sequence[BinaryRelation]
    ) -> Tuple[List[List[Annotation]], Dict[Annotation, Tuple[Annotation]]]:

        # get connected components based on binary relations
        connected_components = get_connected_components(
            elements=mentions,
            relations=relations,
            link_relation_label=self.link_relation_label,
            link_relation_relation_score_threshold=self.link_relation_relation_score_threshold,
            add_singletons=self.include_singletons,
        )

        # store all clustered mentions in a list and
        # create a map from each mention to its cluster
        # (i.e. to the list of spans that includes all other mentions from the same cluster)
        clusters = []
        mention_to_cluster = dict()
        for cluster in connected_components:
            clusters.append(cluster)
            for mention in cluster:
                mention_to_cluster[mention] = tuple(cluster)

        return clusters, mention_to_cluster

    def _update(self, doc: Document) -> None:
        relation_layer = doc[self.relation_layer]
        gold_mentions = []
        predicted_mentions = []
        for mention_layer in relation_layer.target_layers.values():
            gold_mentions.extend(mention_layer)
            predicted_mentions.extend(mention_layer.predictions)

        # prepare the clusters and mention-to-cluster mapping needed for evaluation
        predicted_clusters, mention_to_predicted = self.prepare_clusters_with_mapping(
            mentions=predicted_mentions, relations=relation_layer.predictions
        )
        gold_clusters, mention_to_gold = self.prepare_clusters_with_mapping(
            mentions=gold_mentions, relations=relation_layer
        )

        for e in self.evaluators:
            e.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)

    def get_f1(self) -> float:
        return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators)

    def get_recall(self) -> float:
        return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators)

    def get_precision(self) -> float:
        return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators)

    def get_prf(self) -> Tuple[float, float, float]:
        return self.get_precision(), self.get_recall(), self.get_f1()

    def _compute(self) -> float:
        return self.get_f1()
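The simplified cluster scorers above operate on plain collections, so they can be sanity-checked in isolation. A small worked example for `muc_simplified` (mentions are strings here; in the metric they are pytorch-ie span annotations):

```python
# Toy illustration of the simplified MUC counts defined above.
from src.metrics.coref import muc_simplified

predicted_clusters = [["a", "b", "c"], ["d"]]
# each mention maps to the gold cluster (as a tuple) it belongs to
mention_to_gold = {"a": ("a", "b"), "b": ("a", "b"), "c": ("c",)}

# numerator/denominator of MUC precision for the predicted clustering:
# cluster {a,b,c} spans two gold clusters, cluster {d} has no gold mention
print(muc_simplified(predicted_clusters, mention_to_gold))  # -> (1, 2)
```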
src/metrics/coref_sklearn.py
CHANGED
@@ -1,15 +1,13 @@
 import logging
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, overload

 import numpy as np
-import torch
 from pandas import MultiIndex
-from pie_modules.
-from pytorch_ie import DocumentMetric
+from pie_modules.utils import flatten_dict
+from pytorch_ie import Document, DocumentMetric
 from pytorch_ie.core.metric import T
 from pytorch_ie.utils.hydra import resolve_target
-from torchmetrics import Metric, MetricCollection

 from src.hydra_callbacks.save_job_return_value import to_py_obj

@@ -24,6 +22,14 @@ def get_num_positives(targets: List[int], preds: List[float], positive_idx: int
     return len([v for v in targets if v == positive_idx])


+@overload
+def discretize(values: List[float], threshold: float) -> List[float]: ...
+
+
+@overload
+def discretize(values: List[float], threshold: List[float]) -> Dict[Any, List[float]]: ...
+
+
 def discretize(
     values: List[float], threshold: Union[float, List[float], dict]
 ) -> Union[List[float], Dict[Any, List[float]]]:
@@ -40,20 +46,97 @@ def discretize(
         raise TypeError(f"threshold has unknown type: {threshold}")


+def get_metric_func(name: str) -> Callable:
+    if name.endswith("_curve"):
+        from sklearn.metrics import auc
+
+        base_func = resolve_target(name)
+
+        def wrapper(targets: List[int], preds: List[float], **kwargs):
+            x, y, thresholds = base_func(targets, preds, **kwargs)
+            return auc(y, x)
+
+        return wrapper
+    else:
+        return resolve_target(name)
+
+
+def bootstrap(
+    metric_fn: Callable[[List[int], Union[List[int], List[float]]], float],
+    targets: List[int],
+    predictions: Union[List[int], List[float]],
+    n: int = 1_000,
+    random_state: int | None = None,
+    alpha: float = 0.95,
+) -> Dict[str, float]:
+    """
+    Returns mean and a two-sided (1-alpha) bootstrap CI for any
+    pair-wise classification or ranking metric.
+
+    Parameters
+    ----------
+    metric_fn     Metric function taking (targets, prediction) lists.
+    targets       Ground-truth 0/1 labels.
+    prediction    Scores or hard predictions (same length as `targets`).
+    n             Number of bootstrap replicates (after skipping degenerate ones).
+    random_state  Seed for reproducibility.
+    alpha         Confidence level (default 0.95 → 95 % CI).
+
+    Notes
+    -----
+    * A replicate that contains only one class is discarded
+      because many sklearn metrics are undefined in that case.
+    * If all replicates are discarded an exception is raised.
+    """
+    y = np.asarray(targets)
+    yhat = np.asarray(predictions)
+    if y.shape[0] != yhat.shape[0]:
+        raise ValueError("`targets` and `prediction` must have the same length")
+
+    rng = np.random.default_rng(random_state)
+    idx = np.arange(y.shape[0])
+    vals_list: list[float] = []
+
+    while len(vals_list) < n:
+        sample_idx = rng.choice(idx, size=idx.shape[0], replace=True)
+        y_samp, yhat_samp = y[sample_idx], yhat[sample_idx]
+
+        # skip all-positive or all-negative bootstrap samples
+        if y_samp.min() == y_samp.max():
+            continue
+
+        vals_list.append(metric_fn(y_samp.tolist(), yhat_samp.tolist()))
+
+    if not vals_list:
+        raise RuntimeError("No valid bootstrap replicate contained both classes.")
+
+    vals = np.asarray(vals_list, dtype=float)
+    lower = np.percentile(vals, (1 - alpha) / 2 * 100)
+    upper = np.percentile(vals, (1 + alpha) / 2 * 100)
+
+    return {"mean": float(vals.mean()), "low": float(lower), "high": float(upper)}
+
+
+class BinaryClassificationMetricsSKLearn(DocumentMetric):

     def __init__(
         self,
         metrics: Dict[str, str],
+        layer: str,
+        label: Optional[str] = None,
         thresholds: Optional[Dict[str, float]] = None,
         default_target_idx: int = 0,
         default_prediction_score: float = 0.0,
         show_as_markdown: bool = False,
         markdown_precision: int = 4,
+        bootstrap: Optional[list[str]] = None,
+        bootstrap_n: int = 1_000,
+        bootstrap_random_state: int | None = None,
+        bootstrap_alpha: float = 0.95,
+        create_plots: bool = True,
+        plots: Optional[Dict[str, str]] = None,
     ):
-        self.metrics = {name:
+        self.metrics = {name: get_metric_func(metric) for name, metric in metrics.items()}
         self.thresholds = thresholds or {}
         thresholds_not_in_metrics = {
             name: t for name, t in self.thresholds.items() if name not in self.metrics
@@ -62,11 +145,25 @@ class CorefMetricsSKLearn(DocumentMetric):
             logger.warning(
                 f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
             )
+        self.annotation_layer_name = layer
+        self.annotation_label = label
         self.default_target_idx = default_target_idx
         self.default_prediction_score = default_prediction_score
         self.show_as_markdown = show_as_markdown
         self.markdown_precision = markdown_precision
+        if create_plots:
+            self.plots = {
+                name: resolve_target(plot_func) for name, plot_func in (plots or {}).items()
+            }
+        else:
+            self.plots = {}
+
+        self.bootstrap = set(bootstrap or [])
+        self.bootstrap_kwargs = {
+            "n": bootstrap_n,
+            "random_state": bootstrap_random_state,
+            "alpha": bootstrap_alpha,
+        }

         super().__init__()

@@ -74,50 +171,55 @@ class CorefMetricsSKLearn(DocumentMetric):
         self._preds: List[float] = []
         self._targets: List[int] = []

-    def _update(self, document:
+    def _update(self, document: Document) -> None:
+        annotation_layer = document[self.annotation_layer_name]
+        target2idx = {
+            ann: int(ann.score)
+            for ann in annotation_layer
+            if self.annotation_label is None or ann.label == self.annotation_label
         }
+        prediction2score = {
+            ann: ann.score
+            for ann in annotation_layer.predictions
+            if self.annotation_label is None or ann.label == self.annotation_label
         }
-        all_args = set(
+        all_args = set(target2idx) | set(prediction2score)
         all_targets: List[int] = []
         all_predictions: List[float] = []
         for args in all_args:
-            target_idx =
-            prediction_score =
+            target_idx = target2idx.get(args, self.default_target_idx)
+            prediction_score = prediction2score.get(args, self.default_prediction_score)
             all_targets.append(target_idx)
             all_predictions.append(prediction_score)
-        # target_indices = torch.tensor(all_targets)
-        # self.metrics.update(preds=prediction_scores, target=target_indices)
+
         self._preds.extend(all_predictions)
         self._targets.extend(all_targets)

-    def
-        raise NotImplementedError()
+    def create_plots(self):

         from matplotlib import pyplot as plt

         # Get the number of metrics
+        num_plots = len(self.plots)

         # Calculate rows and columns for subplots (aim for a square-like layout)
-        ncols = math.ceil(math.sqrt(
-        nrows = math.ceil(
+        ncols = math.ceil(math.sqrt(num_plots))
+        nrows = math.ceil(num_plots / ncols)

         # Create the subplots
         fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))

         # Flatten the ax_list if necessary (in case of multiple rows/columns)
+        if num_plots > 1:
+            ax_list = ax_list.flatten().tolist()  # Ensure it's a list, and flatten it if necessary
+        else:
+            ax_list = [ax_list]

+        # Create each plot
+        for ax, (name, plot_func) in zip(ax_list, self.plots.items()):
+            # Set the title for each subplot
+            ax.set_title(name)
+            plot_func(y_true=self._targets, y_pred=self._preds, ax=ax)

         # Adjust layout to avoid overlapping plots
         plt.tight_layout()
@@ -125,23 +227,35 @@ class CorefMetricsSKLearn(DocumentMetric):

     def _compute(self) -> T:

-        if self.
-            self.
+        if len(self.plots) > 0:
+            self.create_plots()

         result = {}
         for name, metric in self.metrics.items():

             if name in self.thresholds:
+                preds_dict = discretize(values=self._preds, threshold=self.thresholds[name])
+                if isinstance(preds_dict, dict):
+                    metric_results = {
+                        t: metric(self._targets, t_preds) for t, t_preds in preds_dict.items()
+                    }
+                    # just get the max
+                    max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
+                    result[f"{name}_threshold"] = max_t
+                    preds = discretize(values=self._preds, threshold=max_t)
+                else:
+                    preds = preds_dict
             else:
                 preds = self._preds
+
+            if name in self.bootstrap:
+                # bootstrap the metric
+                result[name] = bootstrap(
+                    metric_fn=metric,
+                    targets=self._targets,
+                    predictions=preds,
+                    **self.bootstrap_kwargs,  # type: ignore
+                )
             else:
                 result[name] = metric(self._targets, preds)

@@ -149,7 +263,8 @@ class CorefMetricsSKLearn(DocumentMetric):
         if self.show_as_markdown:
             import pandas as pd

+            result_flat = flatten_dict(result)
+            series = pd.Series(result_flat)
             if isinstance(series.index, MultiIndex):
                 if len(series.index.levels) > 1:
                     # in fact, this is not a series anymore
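The new `bootstrap` helper is independent of the document machinery and works with any sklearn-style metric; `get_metric_func` additionally wraps metrics whose name ends in `_curve` so that the area under the returned curve is reported. An illustrative call with toy data and a small replicate count:

```python
# Example of the bootstrap helper defined above; the data here is illustrative only.
from sklearn.metrics import f1_score

from src.metrics.coref_sklearn import bootstrap

targets = [1, 0, 1, 1, 0, 0, 1, 0]
predictions = [1, 0, 1, 0, 0, 1, 1, 0]

ci = bootstrap(metric_fn=f1_score, targets=targets, predictions=predictions, n=200, random_state=0)
print(ci)  # {'mean': ..., 'low': ..., 'high': ...}
```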
src/metrics/f1_with_bootstrapping.py
ADDED
@@ -0,0 +1,103 @@
from collections import defaultdict
from functools import partial
from typing import Callable, Collection, Dict, Hashable, List, Optional, Set, Tuple

from pie_modules.metrics import F1Metric
from pytorch_ie import Annotation, Document


def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    return getattr(ann, label_field) in labels


def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    return getattr(ann, label_field) == label


class F1WithBootstrappingMetric(F1Metric):
    def __init__(self, *args, bootstrap_n: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.bootstrap_n = bootstrap_n

    def reset(self) -> None:
        self.tp: Dict[str, Set[Annotation]] = defaultdict(set)
        self.fp: Dict[str, Set[Annotation]] = defaultdict(set)
        self.fn: Dict[str, Set[Annotation]] = defaultdict(set)

    def calculate_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[Set[Annotation], Set[Annotation], Set[Annotation]]:
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }
        return (
            predicted_annotations & gold_annotations,
            predicted_annotations - gold_annotations,
            gold_annotations - predicted_annotations,
        )

    def add_tp_fp_fn(
        self, tp: Set[Annotation], fp: Set[Annotation], fn: Set[Annotation], label: str
    ) -> None:
        self.tp[label].update(tp)
        self.fp[label].update(fp)
        self.fn[label].update(fn)

    def _update(self, document: Document) -> None:
        new_values = self.calculate_tp_fp_fn(
            document=document,
            annotation_filter=(
                partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
                if self.per_label and not self.infer_labels
                else None
            ),
            annotation_processor=self.annotation_processor,
        )
        self.add_tp_fp_fn(*new_values, label="MICRO")
        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)
        if self.per_label:
            for label in self.labels:
                new_values = self.calculate_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_tp_fp_fn(*new_values, label=label)

    def _compute(self) -> Dict[str, Dict[str, float]]:
        res = dict()
        if self.per_label:
            res["MACRO"] = {"f1": 0.0, "p": 0.0, "r": 0.0}
        for label in self.tp.keys():
            tp, fp, fn = (
                len(self.tp[label]),
                len(self.fp[label]),
                len(self.fn[label]),
            )
            if tp == 0:
                p, r, f1 = 0.0, 0.0, 0.0
            else:
                p = tp / (tp + fp)
                r = tp / (tp + fn)
                f1 = 2 * p * r / (p + r)
            res[label] = {"f1": f1, "p": p, "r": r, "s": tp + fn}
            if self.per_label and label in self.labels:
                res["MACRO"]["f1"] += f1 / len(self.labels)
                res["MACRO"]["p"] += p / len(self.labels)
                res["MACRO"]["r"] += r / len(self.labels)
        if self.show_as_markdown:
            logger.info(f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}")
        return res
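As shown, `_compute` in the added file references `logger` and `pd`, which its import block does not define. Presumably the module is also meant to contain something along these lines (an assumption, not part of the diff):

```python
# Imports the added file appears to need for the show_as_markdown branch of _compute.
import logging

import pandas as pd

logger = logging.getLogger(__name__)
```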
src/metrics/f1_with_threshold.py
ADDED
@@ -0,0 +1,33 @@
from typing import Callable, Hashable, Optional, Tuple

from pie_modules.metrics import F1Metric
from pytorch_ie import Annotation, Document


class F1WithThresholdMetric(F1Metric):
    def __init__(self, *args, threshold: float = 0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold

    def calculate_counts(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> Tuple[int, int, int]:
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann) and getattr(ann, "score", 0.0) >= self.threshold
        }
        gold_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer]
            if annotation_filter(ann) and getattr(ann, "score", 0.0) >= self.threshold
        }
        tp = len([ann for ann in predicted_annotations & gold_annotations])
        fn = len([ann for ann in gold_annotations - predicted_annotations])
        fp = len([ann for ann in predicted_annotations - gold_annotations])
        return tp, fp, fn
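A hypothetical usage sketch: the subclass only adds `threshold`, everything else is forwarded to pie_modules' `F1Metric`, whose constructor arguments (e.g. `layer`) are assumptions here rather than confirmed by the diff:

```python
# Hypothetical sketch: count only annotations whose score is at least 0.5.
from src.metrics.f1_with_threshold import F1WithThresholdMetric

metric = F1WithThresholdMetric(layer="binary_relations", threshold=0.5)  # layer kwarg assumed
# for document in documents:   # DocumentMetric call interface assumed
#     metric(document)
# print(metric.compute())
```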
src/metrics/ranking_sklearn.py
ADDED
@@ -0,0 +1,193 @@
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Union

from pandas import MultiIndex
from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentMetric
from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target

from src.hydra_callbacks.save_job_return_value import to_py_obj

logger = logging.getLogger(__name__)


class RankingMetricsSKLearn(DocumentMetric):
    """Ranking metrics for documents with binary relations.

    This metric computes the ranking metrics for retrieval tasks, where
    relation heads are the queries and the relation tails are the candidates.
    The metric is computed for each head and the results are averaged. It is meant to
    be used with Scikit-learn metrics such as `sklearn.metrics.ndcg_score` (Normalized
    Discounted Cumulative Gain), `sklearn.metrics.label_ranking_average_precision_score`
    (LRAP), etc., see
    https://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics.

    Args:
        metrics (Dict[str, Union[str, Callable]]): A dictionary of metric names and their
            corresponding functions. The function can be a string (name of the function, e.g.,
            sklearn.metrics.ndcg_score) or a callable.
        layer (str): The name of the annotation layer containing the binary relations, e.g.,
            "binary_relations" when applied to TextDocumentsWithLabeledSpansAndBinaryRelations.
        use_manual_average (Optional[List[str]]): A list of metric names to use for manual
            averaging. If provided, the metric scores will be calculated for each
            head and then averaged. Otherwise, all true and predicted scores will be
            passed to the metric function at once.
        exclude_singletons (Optional[List[str]]): A list of metric names to exclude singletons
            from the computation, i.e., entries (heads) where the number of candidates is 1.
        label (Optional[str]): If provided, only the relations with this label will be used
            to compute the metrics. This is useful for filtering out relations that are not
            relevant for the task at hand (e.g., when having multiple relation types in the
            same layer).
        score_threshold (float): If provided, only the relations with a score greater than or
            equal to this threshold will be used to compute the metrics.
        default_score (float): The default score to use for missing relations, either in the
            target or prediction. Default is 0.0.
        use_all_spans (bool): Whether to consider all spans in the document as queries and
            candidates or only the spans that are present in the target and prediction.
        span_label_blacklist (Optional[List[str]]): If provided, ignore the relations with
            heads/tails that are in this list. When using use_all_spans=True, this also
            restricts the candidates to those that are not in the blacklist.
        show_as_markdown (bool): Whether to show the results as markdown. Default is False.
        markdown_precision (int): The precision for displaying the results in markdown.
            Default is 4.
    """

    def __init__(
        self,
        metrics: Dict[str, Union[str, Callable]],
        layer: str,
        use_manual_average: Optional[List[str]] = None,
        exclude_singletons: Optional[List[str]] = None,
        label: Optional[str] = None,
        score_threshold: float = 0.0,
        default_score: float = 0.0,
        use_all_spans: bool = False,
        span_label_blacklist: Optional[List[str]] = None,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        self.metrics = {
            name: resolve_target(metric) if isinstance(metric, str) else metric
            for name, metric in metrics.items()
        }
        self.use_manual_average = set(use_manual_average or [])
        self.exclude_singletons = set(exclude_singletons or [])
        self.annotation_layer_name = layer
        self.annotation_label = label
        self.score_threshold = score_threshold
        self.default_score = default_score
        self.use_all_spans = use_all_spans
        self.span_label_blacklist = span_label_blacklist
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot

        super().__init__()

    def reset(self) -> None:
        self._preds: List[List[float]] = []
        self._targets: List[List[float]] = []

    def get_head2tail2score(
        self, relations: Sequence[BinaryRelation]
    ) -> Dict[Annotation, Dict[Annotation, float]]:
        result: Dict[Annotation, Dict[Annotation, float]] = defaultdict(dict)
        for rel in relations:
            if (
                (self.annotation_label is None or rel.label == self.annotation_label)
                and (rel.score >= self.score_threshold)
                and (
                    self.span_label_blacklist is None
                    or (
                        rel.head.label not in self.span_label_blacklist
                        and rel.tail.label not in self.span_label_blacklist
                    )
                )
            ):
                result[rel.head][rel.tail] = rel.score

        return result

    def _update(self, document: Document) -> None:
        annotation_layer: AnnotationLayer[BinaryRelation] = document[self.annotation_layer_name]

        target_head2tail2score = self.get_head2tail2score(annotation_layer)
        prediction_head2tail2score = self.get_head2tail2score(annotation_layer.predictions)
        all_spans = set()
        # get spans from all layers targeted by the annotation (relation) layer
        for span_layer in annotation_layer.target_layers.values():
            all_spans.update(span_layer)

        if self.span_label_blacklist is not None:
            all_spans = {span for span in all_spans if span.label not in self.span_label_blacklist}

        if self.use_all_spans:
            all_heads = all_spans
        else:
            all_heads = set(target_head2tail2score) | set(prediction_head2tail2score)

        all_targets: List[List[float]] = []
        all_predictions: List[List[float]] = []
        for head in all_heads:
            target_tail2score = target_head2tail2score.get(head, {})
            prediction_tail2score = prediction_head2tail2score.get(head, {})
            if self.use_all_spans:
                # use all spans as tails
                tails = set(span for span in all_spans if span != head)
            else:
                # use only the tails that are in the target or prediction
                tails = set(target_tail2score) | set(prediction_tail2score)
            target_scores = [target_tail2score.get(t, self.default_score) for t in tails]
            prediction_scores = [prediction_tail2score.get(t, self.default_score) for t in tails]
            all_targets.append(target_scores)
            all_predictions.append(prediction_scores)

        self._targets.extend(all_targets)
        self._preds.extend(all_predictions)

    def do_plot(self):
        raise NotImplementedError()

    def _compute(self) -> T:

        if self.plot:
            self.do_plot()

        result = {}
        for name, metric in self.metrics.items():
            targets, preds = self._targets, self._preds
            if name in self.exclude_singletons:
                targets = [t for t in targets if len(t) > 1]
                preds = [p for p in preds if len(p) > 1]
                num_singletons = len(self._targets) - len(targets)
                logger.warning(
                    f"Excluding {num_singletons} singletons (out of {len(self._targets)} "
                    f"entries) from {name} metric calculation."
                )

            if name in self.use_manual_average:
                scores = [
                    metric(y_true=[tgts], y_score=[prds]) for tgts, prds in zip(targets, preds)
                ]
                result[name] = sum(scores) / len(scores) if len(scores) > 0 else 0.0
            else:
                result[name] = metric(y_true=targets, y_score=preds)

        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result
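A configuration sketch matching the docstring: the metric names point at the sklearn multilabel-ranking functions it mentions. The layer name and the choice of metrics to average manually are assumptions for illustration, not values from the PR:

```python
# Illustrative configuration; in the project these values would normally come from a Hydra config.
from src.metrics.ranking_sklearn import RankingMetricsSKLearn

metric = RankingMetricsSKLearn(
    metrics={
        "ndcg": "sklearn.metrics.ndcg_score",
        "lrap": "sklearn.metrics.label_ranking_average_precision_score",
    },
    layer="binary_relations",        # assumed layer name
    use_manual_average=["lrap"],     # compute per head, then average
    show_as_markdown=True,
)
```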
src/metrics/score_distribution.py
CHANGED
@@ -1,9 +1,12 @@
+import logging
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple

 import pandas as pd
 from pytorch_ie import Document, DocumentMetric

+logger = logging.getLogger()
+

 class ScoreDistribution(DocumentMetric):
     """Computes the distribution of prediction scores for annotations in a layer. The scores are
@@ -36,7 +39,8 @@ class ScoreDistribution(DocumentMetric):
         plotly_use_create_distplot: bool = True,
         plotly_barmode: Optional[str] = None,
         plotly_marginal: Optional[str] = "violin",
+        plotly_font: Optional[Dict[str, Any]] = None,
+        plotly_font_size: Optional[int] = None,
         plotly_font_family: Optional[str] = None,
         plotly_background_color: Optional[str] = None,
     ):
@@ -52,7 +56,12 @@ class ScoreDistribution(DocumentMetric):
         self.plotly_use_create_distplot = plotly_use_create_distplot
         self.plotly_barmode = plotly_barmode
         self.plotly_marginal = plotly_marginal
-        self.
+        self.plotly_font = plotly_font or {}
+        if plotly_font_size is not None:
+            logger.warning(
+                "Parameter 'plotly_font_size' is deprecated. Use 'plotly_font' with 'size' key instead."
+            )
+            self.plotly_font["size"] = plotly_font_size
         self.plotly_font_family = plotly_font_family
         self.plotly_background_color = plotly_background_color
         self.scores: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
@@ -231,7 +240,7 @@ class ScoreDistribution(DocumentMetric):
             width=800,
             title_text=description,
             title_x=0.5,
-            font=
+            font=self.plotly_font,
             legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
         )
         if self.plotly_barmode is not None:
@@ -290,7 +299,7 @@ class ScoreDistribution(DocumentMetric):
             width=800,
             title_text=f"Mean Binned Scores for {self.mapped_layer}",
             title_x=0.5,
-            font=
+            font=self.plotly_font,
         )
         fig.update_layout(
             legend=dict(
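The new `plotly_font` dictionary is handed directly to Plotly's layout `font` setting, replacing the deprecated `plotly_font_size` integer. For reference, the equivalent direct Plotly call (toy data, illustrative values):

```python
# Standalone Plotly example of the font dict that plotly_font is forwarded to.
import plotly.express as px

fig = px.histogram(x=[0.1, 0.2, 0.8, 0.9])
fig.update_layout(font=dict(size=18, family="Serif"), title_text="Score distribution", title_x=0.5)
```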
src/metrics/semantically_same_ranking.py
ADDED
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import warnings
|
3 |
+
from collections import defaultdict
|
4 |
+
from functools import partial
|
5 |
+
from typing import Callable, Iterable, List, Optional, Set, Tuple
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
from pytorch_ie import DocumentMetric
|
10 |
+
from pytorch_ie.annotations import BinaryRelation
|
11 |
+
from sklearn.metrics import average_precision_score, ndcg_score
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
NEG_INF = -1e9 # smaller than any real score
|
16 |
+
|
17 |
+
# metrics
|
18 |
+
|
19 |
+
|
20 |
+
def true_mrr(y_true: np.ndarray, y_score: np.ndarray, k: int | None = None) -> float:
|
21 |
+
"""
|
22 |
+
Macro MRR over *all* queries.
|
23 |
+
• Reciprocal rank is 0 when a query has no relevant item.
|
24 |
+
• If k is given, restrict the search to the top-k list.
|
25 |
+
"""
|
26 |
+
if y_true.size == 0:
|
27 |
+
return np.nan
|
28 |
+
|
29 |
+
rr = []
|
30 |
+
for t, s in zip(y_true, y_score):
|
31 |
+
if t.sum() == 0:
|
32 |
+
rr.append(0.0)
|
33 |
+
continue
|
34 |
+
|
35 |
+
order = np.argsort(-s)
|
36 |
+
if k is not None:
|
37 |
+
order = order[:k]
|
38 |
+
|
39 |
+
# first position where t == 1, +1 for 1-based rank
|
40 |
+
first_hit = np.flatnonzero(t[order] > 0)
|
41 |
+
rank = first_hit[0] + 1 if first_hit.size else np.inf
|
42 |
+
rr.append(0.0 if np.isinf(rank) else 1.0 / rank)
|
43 |
+
|
44 |
+
return np.mean(rr)
|
45 |
+
|
46 |
+
|
47 |
+
def macro_ndcg(y_true: np.ndarray, y_score: np.ndarray, k: int | None = None) -> float:
|
48 |
+
"""
|
49 |
+
Macro NDCG@k over all queries.
|
50 |
+
|
51 |
+
ndcg_score returns 0 when a query has no positives, so no masking is required.
|
52 |
+
"""
|
53 |
+
if y_true.size == 0:
|
54 |
+
return np.nan
|
55 |
+
return ndcg_score(y_true, y_score, k=k)
|
56 |
+
|
57 |
+
|
58 |
+
def macro_map(y_true: np.ndarray, y_score: np.ndarray) -> float:
|
59 |
+
"""
|
60 |
+
Macro MAP: mean of Average-Precision per query.
|
61 |
+
Queries without positives contribute AP = 0.
|
62 |
+
"""
|
63 |
+
if y_true.size == 0:
|
64 |
+
return np.nan
|
65 |
+
|
66 |
+
ap = []
|
67 |
+
for t, s in zip(y_true, y_score):
|
68 |
+
if t.sum() == 0:
|
69 |
+
ap.append(0.0)
|
70 |
+
else:
|
71 |
+
ap.append(average_precision_score(t, s))
|
72 |
+
return np.mean(ap)
|
73 |
+
|
74 |
+
|
75 |
+
def ap_micro(y_true: np.ndarray, y_score: np.ndarray) -> float:
|
76 |
+
"""
|
77 |
+
Micro AP over the entire pool (unchanged).
|
78 |
+
"""
|
79 |
+
with warnings.catch_warnings():
|
80 |
+
warnings.filterwarnings("ignore", message="No positive class found in y_true")
|
81 |
+
return average_precision_score(y_true.ravel(), y_score.ravel())
|
82 |
+
|
83 |
+
|
84 |
+
# ---------------------------
|
85 |
+
# Recall@k
|
86 |
+
# ---------------------------
|
87 |
+
|
88 |
+
|
89 |
+
def recall_at_k_micro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
|
90 |
+
"""
|
91 |
+
Micro Recall@k (a.k.a. instance-level recall)
|
92 |
+
|
93 |
+
– Each *positive instance* counts once, regardless of which query it belongs to.
|
94 |
+
– Denominator = total #positives across the whole pool.
|
95 |
+
"""
|
96 |
+
total_pos = y_true.sum()
|
97 |
+
if total_pos == 0:
|
98 |
+
return np.nan
|
99 |
+
|
100 |
+
topk = np.argsort(-y_score, axis=1)[:, :k] # indices of top-k per query
|
101 |
+
rows = np.arange(topk.shape[0])[:, None]
|
102 |
+
|
103 |
+
hits = (y_true[rows, topk] > 0).sum() # total #hits (instances)
|
104 |
+
return hits / total_pos
|
105 |
+
|
106 |
+
|
107 |
+
def recall_at_k_macro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
|
108 |
+
"""
|
109 |
+
Macro Recall@k (query-level recall)
|
110 |
+
|
111 |
+
– First compute recall per *query* (#hits / #positives in that query).
|
112 |
+
– Then average across all queries that actually contain ≥1 positive.
|
113 |
+
"""
|
114 |
+
mask = y_true.sum(axis=1) > 0 # keep only valid queries
|
115 |
+
if not mask.any():
|
116 |
+
return np.nan
|
117 |
+
|
118 |
+
Yt, Ys = y_true[mask], y_score[mask]
|
119 |
+
topk = np.argsort(-Ys, axis=1)[:, :k]
|
120 |
+
rows = np.arange(Yt.shape[0])[:, None]
|
121 |
+
|
122 |
+
hits_per_q = (Yt[rows, topk] > 0).sum(axis=1) # shape: (n_queries,)
|
123 |
+
pos_per_q = Yt.sum(axis=1)
|
124 |
+
|
125 |
+
return np.mean(hits_per_q / pos_per_q) # average of query recalls
|
126 |
+
|
127 |
+
|
128 |
+
# ---------------------------
|
129 |
+
# Precision@k
|
130 |
+
# ---------------------------
|
131 |
+
|
132 |
+
|
133 |
+
def precision_at_k_micro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
|
134 |
+
"""
|
135 |
+
Micro Precision@k (pool-level precision)
|
136 |
+
|
137 |
+
– Numerator = total #hits across all queries.
|
138 |
+
– Denominator = total #predictions considered (n_queries · k).
|
139 |
+
"""
|
140 |
+
if y_true.size == 0:
|
141 |
+
return np.nan
|
142 |
+
|
143 |
+
topk = np.argsort(-y_score, axis=1)[:, :k]
|
144 |
+
rows = np.arange(topk.shape[0])[:, None]
|
145 |
+
|
146 |
+
hits = (y_true[rows, topk] > 0).sum()
|
147 |
+
total_pred = y_true.shape[0] * k
|
148 |
+
return hits / total_pred
|
149 |
+
|
150 |
+
|
151 |
+
def precision_at_k_macro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
|
152 |
+
"""
|
153 |
+
Macro Precision@k (query-level precision)
|
154 |
+
|
155 |
+
– Compute precision = (#hits / k) for each query, **including those with zero positives**,
|
156 |
+
then average.
|
157 |
+
"""
|
158 |
+
if y_true.size == 0:
|
159 |
+
return np.nan
|
160 |
+
|
161 |
+
topk = np.argsort(-y_score, axis=1)[:, :k]
|
162 |
+
rows = np.arange(topk.shape[0])[:, None]
|
163 |
+
|
164 |
+
rel = y_true[rows, topk] > 0 # shape: (n_queries, k)
|
165 |
+
precision_per_q = rel.mean(axis=1) # mean over k positions
|
166 |
+
return precision_per_q.mean()
|
167 |
+
|
168 |
+
|
169 |
+
# helper methods
|
170 |
+
|
171 |
+
|
172 |
+
def bootstrap(
|
173 |
+
metric_fn: Callable[[np.ndarray, np.ndarray], float],
|
174 |
+
y_true: np.ndarray,
|
175 |
+
y_score: np.ndarray,
|
176 |
+
n: int = 1000,
|
177 |
+
rng=None,
|
178 |
+
) -> dict[str, float]:
|
179 |
+
rng = np.random.default_rng(rng)
|
180 |
+
idx = np.arange(len(y_true))
|
181 |
+
vals: list[float] = []
|
182 |
+
|
183 |
+
while len(vals) < n:
|
184 |
+
sample = rng.choice(idx, size=len(idx), replace=True)
|
185 |
+
t = y_true[sample]
|
186 |
+
s = y_score[sample]
|
187 |
+
if t.sum() == 0: # no positive at all → resample
|
188 |
+
continue
|
189 |
+
vals.append(metric_fn(t, s))
|
190 |
+
|
191 |
+
result = np.asarray(vals)
|
192 |
+
# get 95% confidence interval
|
193 |
+
lo, hi = np.percentile(result, [2.5, 97.5])
|
194 |
+
return {"mean": result.mean(), "low": lo, "high": hi}
|
195 |
+
|
196 |
+
|
197 |
+
def evaluate_with_ranx(
|
198 |
+
pred_rels: set[BinaryRelation],
|
199 |
+
target_rels: set[BinaryRelation],
|
200 |
+
metrics: list[str],
|
201 |
+
include_queries_without_gold: bool = True,
|
202 |
+
) -> dict[str, float]:
|
203 |
+
|
204 |
+
# lazy import to not require ranx via requirements.txt
|
205 |
+
import ranx
|
206 |
+
|
207 |
+
all_rels = set(pred_rels) | set(target_rels)
|
208 |
+
all_heads = {rel.head for rel in all_rels}
|
209 |
+
head2id = {head: f"q_{idx}" for idx, head in enumerate(sorted(all_heads))}
|
210 |
+
tail_and_label2id = {(ann.tail, ann.label): f"d_{idx}" for idx, ann in enumerate(all_rels)}
|
211 |
+
|
212 |
+
qrels_dict: dict[str, dict[str, int]] = defaultdict(dict) # {query_id: {doc_id: 1}}
|
213 |
+
run_dict: dict[str, dict[str, float]] = defaultdict(dict) # {query_id: {doc_id: score}}
|
214 |
+
|
215 |
+
for target_rel in target_rels:
|
216 |
+
query_id = head2id[target_rel.head]
|
217 |
+
doc_id = tail_and_label2id[(target_rel.tail, target_rel.label)]
|
218 |
+
if target_rel.score != 1.0:
|
219 |
+
raise ValueError(
|
220 |
+
f"target score must be 1.0, but got {target_rel.score} for {target_rel}"
|
221 |
+
)
|
222 |
+
qrels_dict[query_id][doc_id] = 1
|
223 |
+
|
224 |
+
for pred_rel in pred_rels:
|
225 |
+
query_id = head2id[pred_rel.head]
|
226 |
+
doc_id = tail_and_label2id[(pred_rel.tail, pred_rel.label)]
|
227 |
+
run_dict[query_id][doc_id] = pred_rel.score
|
228 |
+
|
229 |
+
if include_queries_without_gold:
|
230 |
+
# add missing query ids to rund_dict and qrels_dict
|
231 |
+
for query_id in set(head2id.values()) - set(qrels_dict):
|
232 |
+
qrels_dict[query_id] = {}
|
233 |
+
|
234 |
+
# evaluate
|
235 |
+
qrels = ranx.Qrels(qrels_dict)
|
236 |
+
run = ranx.Run(run_dict)
|
237 |
+
results = ranx.evaluate(qrels, run, metrics, make_comparable=True)
|
238 |
+
return results
|
239 |
+
|
240 |
+
|
def deduplicate_relations(
    relations: Iterable[BinaryRelation], caption: str
) -> Set[BinaryRelation]:
    pred2scores = defaultdict(set)
    for ann in relations:
        pred2scores[ann].add(round(ann.score, 4))
    # warning for duplicates
    preds_with_duplicates = [ann for ann, scores in pred2scores.items() if len(scores) > 1]
    if len(preds_with_duplicates) > 0:
        logger.warning(
            f"there are {len(preds_with_duplicates)} {caption} with duplicates: "
            f"{preds_with_duplicates}. We will take the max score for each annotation."
        )

    # take the max score for each annotation
    result = {ann.copy(score=max(scores)) for ann, scores in pred2scores.items()}
    return result

def construct_y_true_and_score(
    preds: Iterable[BinaryRelation], targets: Iterable[BinaryRelation]
) -> Tuple[np.ndarray, np.ndarray]:

    # helper constructs
    all_anns = set(preds) | set(targets)
    head2relations = defaultdict(list)
    for ann in all_anns:
        head2relations[ann.head].append(ann)
    target2score = {rel: rel.score for rel in targets}
    pred2score = {rel: rel.score for rel in preds}

    max_len = max(len(relations) for relations in head2relations.values())
    target_rows, pred_rows = [], []
    for query in head2relations:
        relations = head2relations[query]
        # use a very small score for missing predictions. Or should we use 0.0 as before? or NEG_INF?
        missing_pred_score = NEG_INF  # np.random.uniform(0.0, 0.001) #0.0 #
        missing_target_score = 0
        query_scores = [
            (target2score.get(ann, missing_target_score), pred2score.get(ann, missing_pred_score))
            for ann in relations
        ]

        # sort by descending order of prediction score
        query_scores_sorted = np.array(sorted(query_scores, key=lambda x: x[1], reverse=True))

        # pad with the missing-score constants so every row has the same length
        pad_width = max_len - len(query_scores)
        query_target = np.pad(
            query_scores_sorted[:, 0], (0, pad_width), constant_values=missing_target_score
        )
        query_pred = np.pad(
            query_scores_sorted[:, 1], (0, pad_width), constant_values=missing_pred_score
        )

        target_rows.append(query_target)
        pred_rows.append(query_pred)

    y_true = np.vstack(target_rows)  # shape (n_queries, max_len)
    y_score = np.vstack(pred_rows)

    return y_true, y_score

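A small worked example of the padding behaviour (hypothetical values; NEG_INF is the sentinel defined earlier in this module):

# Illustrative only: two queries with 2 and 1 candidate(s), padded to width 2.
# y_true  -> [[1, 0], [1, 0]]              gold relevance, padded with 0
# y_score -> [[0.9, 0.4], [0.7, NEG_INF]]  prediction scores, padded with NEG_INF
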
class SemanticallySameRankingMetric(DocumentMetric):

    def __init__(
        self,
        layer: str,
        label: Optional[str] = None,
        add_reversed: bool = False,
        require_positive_gold: bool = False,
        bootstrap_n: Optional[int] = None,
        k_values: Optional[List[int]] = None,
        return_coverage: bool = True,
        show_as_markdown: bool = False,
        use_ranx: bool = False,
        add_stats_to_result: bool = False,
    ) -> None:
        super().__init__()
        self.layer = layer
        self.label = label
        self.add_reversed = add_reversed
        self.require_positive_gold = require_positive_gold
        self.bootstrap_n = bootstrap_n
        self.k_values = k_values if k_values is not None else [1, 5, 10]
        self.return_coverage = return_coverage
        self.show_as_markdown = show_as_markdown
        self.use_ranx = use_ranx
        self.add_stats_to_result = add_stats_to_result

        self.metrics = {
            "macro_ndcg": macro_ndcg,
            "macro_mrr": true_mrr,
            "macro_map": macro_map,
            "micro_ap": ap_micro,
        }
        for name, func in [
            ("macro_ndcg", macro_ndcg),
            ("micro_recall", recall_at_k_micro),
            ("micro_precision", precision_at_k_micro),
            ("macro_recall", recall_at_k_macro),
            ("macro_precision", precision_at_k_macro),
        ]:
            for k in self.k_values:
                self.metrics[f"{name}@{k}"] = partial(func, k=k)  # type: ignore

        self.ranx_metrics = ["map", "mrr", "ndcg"]
        for name in ["recall", "precision", "ndcg"]:
            for k in self.k_values:
                self.ranx_metrics.append(f"{name}@{k}")

    def reset(self) -> None:
        """Reset the metric to its initial state."""
        self._preds: List[BinaryRelation] = []
        self._targets: List[BinaryRelation] = []

    def _update(self, document):
        layer = document[self.layer]
        ann: BinaryRelation
        for ann in layer:
            if self.label is None or ann.label == self.label:
                if ann.score > 0.0:
                    self._targets.append(ann.copy())
                    if self.add_reversed:
                        self._targets.append(ann.copy(head=ann.tail, tail=ann.head))

        for ann in layer.predictions:
            if self.label is None or ann.label == self.label:
                if ann.score > 0.0:
                    self._preds.append(ann.copy())
                    if self.add_reversed:
                        self._preds.append(ann.copy(head=ann.tail, tail=ann.head))

    def _compute(self):
        # take the max score for each annotation
        preds_deduplicated = deduplicate_relations(self._preds, "predictions")
        targets_deduplicated = deduplicate_relations(self._targets, "targets")

        stats = {
            "gold": len(targets_deduplicated),
            "preds": len(preds_deduplicated),
            "queries": len(
                set(ann.head for ann in targets_deduplicated)
                | set(ann.head for ann in preds_deduplicated)
            ),
        }

        if self.use_ranx:
            if self.bootstrap_n is not None:
                raise ValueError(
                    "Ranx does not support bootstrapping. Please set bootstrap_n=None."
                )

            scores = evaluate_with_ranx(
                preds_deduplicated,
                targets_deduplicated,
                metrics=self.ranx_metrics,
                include_queries_without_gold=not self.require_positive_gold,
            )
            if self.add_stats_to_result:
                scores.update(stats)
            # logger.info(f"results via ranx:\n{pd.Series(ranx_result).sort_index().round(3).to_markdown()}")
            df = pd.DataFrame.from_records([scores], index=["score"])
        else:
            y_true, y_score = construct_y_true_and_score(
                preds=preds_deduplicated, targets=targets_deduplicated
            )

            # original definition ─ share of queries with ≥1 positive
            coverage = (y_true.sum(axis=1) > 0).mean()

            # keep only queries that actually have at least one gold positive
            if self.require_positive_gold:
                mask = y_true.sum(axis=1) > 0  # shape: (n_queries,)
                y_true = y_true[mask]
                y_score = y_score[mask]

            if self.bootstrap_n is not None:
                scores = {
                    name: bootstrap(fn, y_true, y_score, n=self.bootstrap_n)
                    for name, fn in self.metrics.items()
                }
                if self.add_stats_to_result:
                    scores["stats"] = stats
                df = pd.DataFrame(scores)
            else:
                scores = {name: fn(y_true, y_score) for name, fn in self.metrics.items()}
                if self.add_stats_to_result:
                    scores.update(stats)
                df = pd.DataFrame.from_records([scores], index=["score"])

            if self.return_coverage:
                scores["coverage"] = coverage

        if self.show_as_markdown:
            if not self.add_stats_to_result:
                logger.info(
                    f'\nstatistics ({self.layer}):\n{pd.Series(stats, name="value").to_markdown()}'
                )
            logger.info(f"\n{self.layer}:\n{df.round(4).T.to_markdown()}")

        return scores
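A hedged usage sketch for the new metric, assuming the usual pytorch-ie DocumentMetric interface (call to update, compute() to aggregate); the layer name and the documents variable are placeholders:

# Illustrative only, not part of the PR.
metric = SemanticallySameRankingMetric(
    layer="binary_coref_relations",  # placeholder layer name
    k_values=[1, 5, 10],
    show_as_markdown=True,
)
for doc in predicted_documents:  # placeholder: documents carrying gold and predicted relations
    metric(doc)
scores = metric.compute()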
src/metrics/tpfpfn.py
ADDED
@@ -0,0 +1,193 @@
import logging
from collections import defaultdict
from functools import partial
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Hashable,
    List,
    Optional,
    Tuple,
    TypeAlias,
    Union,
)

from pytorch_ie.core import Annotation, Document, DocumentMetric
from pytorch_ie.utils.hydra import resolve_target

from src.document.types import RelatedRelation

logger = logging.getLogger(__name__)


def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
    return getattr(ann, label_field) in labels


def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
    return getattr(ann, label_field) == label


InstanceType: TypeAlias = Tuple[Document, Annotation]
InstancesType: TypeAlias = Tuple[List[InstanceType], List[InstanceType], List[InstanceType]]


class TPFFPFNMetric(DocumentMetric):
    """Computes the lists of True Positive, False Positive, and False Negative
    annotations for a given layer. If labels are provided, it also computes
    the counts for each label separately.

    Works only with `RelatedRelation` annotations for now.

    Args:
        layer: The layer to compute the metrics for.
        labels: If provided, calculate metrics for each label.
        label_field: The field to use for the label. Defaults to "label".
    """

    def __init__(
        self,
        layer: str,
        labels: Optional[Union[Collection[str], str]] = None,
        label_field: str = "label",
        annotation_processor: Optional[Union[Callable[[Annotation], Hashable], str]] = None,
    ):
        super().__init__()
        self.layer = layer
        self.label_field = label_field
        self.annotation_processor: Optional[Callable[[Annotation], Hashable]]
        if isinstance(annotation_processor, str):
            self.annotation_processor = resolve_target(annotation_processor)
        else:
            self.annotation_processor = annotation_processor

        self.per_label = labels is not None
        self.infer_labels = False
        if self.per_label:
            if isinstance(labels, str):
                if labels != "INFERRED":
                    raise ValueError(
                        "labels can only be 'INFERRED' if per_label is True and labels is a string"
                    )
                self.labels = []
                self.infer_labels = True
            elif isinstance(labels, Collection):
                if not all(isinstance(label, str) for label in labels):
                    raise ValueError("labels must be a collection of strings")
                if "MICRO" in labels or "MACRO" in labels:
                    raise ValueError(
                        "labels cannot contain 'MICRO' or 'MACRO' because they are used to capture aggregated metrics"
                    )
                if len(labels) == 0:
                    raise ValueError("labels cannot be empty")
                self.labels = list(labels)
            else:
                raise ValueError("labels must be a string or a collection of strings")

    def reset(self):
        self.tp_fp_fn = defaultdict(lambda: (list(), list(), list()))

    def get_tp_fp_fn(
        self,
        document: Document,
        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
        annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
    ) -> InstancesType:
        annotation_processor = annotation_processor or (lambda ann: ann)
        annotation_filter = annotation_filter or (lambda ann: True)
        predicted_annotations = {
            annotation_processor(ann)
            for ann in document[self.layer].predictions
            if annotation_filter(ann)
        }
        gold_annotations = {
            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
        }
        tp = [(document, ann) for ann in predicted_annotations & gold_annotations]
        fn = [(document, ann) for ann in gold_annotations - predicted_annotations]
        fp = [(document, ann) for ann in predicted_annotations - gold_annotations]
        return tp, fp, fn

    def add_annotations(self, annotations: InstancesType, label: str):
        self.tp_fp_fn[label] = (
            self.tp_fp_fn[label][0] + annotations[0],
            self.tp_fp_fn[label][1] + annotations[1],
            self.tp_fp_fn[label][2] + annotations[2],
        )

    def _update(self, document: Document):
        new_tp_fp_fn = self.get_tp_fp_fn(
            document=document,
            annotation_filter=(
                partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
                if self.per_label and not self.infer_labels
                else None
            ),
            annotation_processor=self.annotation_processor,
        )
        self.add_annotations(new_tp_fp_fn, label="MICRO")
        if self.infer_labels:
            layer = document[self.layer]
            # collect labels from gold data and predictions
            for ann in list(layer) + list(layer.predictions):
                label = getattr(ann, self.label_field)
                if label not in self.labels:
                    self.labels.append(label)
        if self.per_label:
            for label in self.labels:
                new_tp_fp_fn = self.get_tp_fp_fn(
                    document=document,
                    annotation_filter=partial(
                        has_this_label, label_field=self.label_field, label=label
                    ),
                    annotation_processor=self.annotation_processor,
                )
                self.add_annotations(new_tp_fp_fn, label=label)

    def format_texts(self, texts: List[str]) -> str:
        return "<SEP>".join(texts)

    def format_annotation(self, ann: Annotation) -> Dict[str, Any]:
        if isinstance(ann, RelatedRelation):
            head_resolved = ann.head.resolve()
            tail_resolved = ann.tail.resolve()
            ref_resolved = ann.reference_span.resolve()
            return {
                "related_label": ann.label,
                "related_score": round(ann.score, 3),
                "query_label": head_resolved[0],
                "query_texts": self.format_texts(head_resolved[1]),
                "query_score": round(ann.head.score, 3),
                "ref_label": ref_resolved[0],
                "ref_texts": self.format_texts(ref_resolved[1]),
                "ref_score": round(ann.reference_span.score, 3),
                "rec_label": tail_resolved[0],
                "rec_texts": self.format_texts(tail_resolved[1]),
                "rec_score": round(ann.tail.score, 3),
            }
        else:
            raise NotImplementedError
            # return ann.resolve()

    def format_instance(self, instance: InstanceType) -> Dict[str, Any]:
        document, annotation = instance
        result = self.format_annotation(annotation)
        if getattr(document, "id", None) is not None:
            result["document_id"] = document.id
        return result

    def _compute(self) -> Dict[str, Dict[str, list]]:
        res = dict()
        for k, instances in self.tp_fp_fn.items():
            res[k] = {
                "tp": [self.format_instance(instance) for instance in instances[0]],
                "fp": [self.format_instance(instance) for instance in instances[1]],
                "fn": [self.format_instance(instance) for instance in instances[2]],
            }

        # if self.show_as_markdown:
        #     logger.info(f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}")
        return res
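A hedged usage sketch for the error-listing metric, again assuming the standard DocumentMetric call/compute interface; the layer name and the documents variable are placeholders:

# Illustrative only, not part of the PR.
metric = TPFFPFNMetric(layer="related_relations", labels="INFERRED")
for doc in predicted_documents:
    metric(doc)
tp_fp_fn = metric.compute()
# e.g. tp_fp_fn["MICRO"]["fp"] lists the formatted false-positive records per document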
src/models/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 from .sequence_classification import SimpleSequenceClassificationModelWithInputTypeIds
 from .sequence_classification_with_pooler import (
+    SequencePairSimilarityModelDummy,
     SequencePairSimilarityModelWithMaxCosineSim,
+    SequencePairSimilarityModelWithMaxCosineSimAndAdapter,
     SequencePairSimilarityModelWithPoolerAndAdapter,
 )
src/models/sequence_classification_with_pooler.py
CHANGED
@@ -1,12 +1,11 @@
 import abc
 import logging
-from typing import
+from typing import Callable, List, Optional
 
 import torch
 import torch.nn.functional as F
 from adapters import AutoAdapterModel
 from pie_modules.models import SequencePairSimilarityModelWithPooler
-from pie_modules.models.components.pooler import MENTION_POOLING
 from pie_modules.models.sequence_classification_with_pooler import (
     InputType,
     OutputType,
@@ -20,31 +19,11 @@ from torch import FloatTensor, Tensor
 from transformers import AutoConfig, PreTrainedModel
 from transformers.modeling_outputs import SequenceClassifierOutput
 
-from src.models.components.pooler import SpanMeanPooler
-
 logger = logging.getLogger(__name__)
 
 
-class SequenceClassificationModelWithPoolerBase2(
-    SequenceClassificationModelWithPoolerBase, abc.ABC
-):
-    def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
-        aggregate = self.pooler_config.get("aggregate", "max")
-        if self.pooler_config["type"] == MENTION_POOLING and aggregate != "max":
-            if aggregate == "mean":
-                pooler_config = dict(self.pooler_config)
-                pooler_config.pop("type")
-                pooler_config.pop("aggregate")
-                pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config)
-                return pooler, pooler.output_dim
-            else:
-                raise ValueError(f"Unknown aggregation method: {aggregate}")
-        else:
-            return super().setup_pooler(input_dim)
-
-
 class SequenceClassificationModelWithPoolerAndAdapterBase(
-
+    SequenceClassificationModelWithPoolerBase, abc.ABC
 ):
     def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
         self.adapter_name_or_path = adapter_name_or_path
@@ -66,13 +45,6 @@ class SequenceClassificationModelWithPoolerAndAdapterBase(
         return model
 
 
-@PyTorchIEModel.register()
-class SequencePairSimilarityModelWithPooler2(
-    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2
-):
-    pass
-
-
 @PyTorchIEModel.register()
 class SequencePairSimilarityModelWithPoolerAndAdapter(
     SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
@@ -164,3 +136,66 @@ class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
     SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
 ):
     pass
+
+
+@PyTorchIEModel.register()
+class SequencePairSimilarityModelDummy(SequencePairSimilarityModelWithPooler):
+
+    def __init__(
+        self,
+        method: str = "random",
+        random_seed: Optional[int] = None,
+        **kwargs,
+    ):
+        self.method = method
+        self.random_seed = random_seed
+        super().__init__(**kwargs)
+
+    def setup_classifier(
+        self, pooler_output_dim: int
+    ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
+        if self.method == "random":
+            generator = torch.Generator(device=self.device)
+            if self.random_seed is not None:
+                generator = generator.manual_seed(self.random_seed)
+
+            def binary_classify_random(
+                inputs: torch.FloatTensor,
+                inputs_pair: torch.FloatTensor,
+            ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
+                """Randomly classifies pairs of inputs as similar or not similar."""
+                # Generate random logits in the range of [0, 1]
+                logits = torch.rand(inputs.size(0), device=self.device, generator=generator)
+                return logits
+
+            return binary_classify_random
+        elif self.method == "zero":
+
+            def binary_classify_zero(
+                inputs: torch.FloatTensor,
+                inputs_pair: torch.FloatTensor,
+            ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
+                """Classifies pairs of inputs as not similar (logit = 0)."""
+                # Return a tensor of zeros with the same batch size
+                logits = torch.zeros(inputs.size(0), device=self.device)
+                return logits
+
+            return binary_classify_zero
+        else:
+            raise ValueError(
+                f"Unknown method: {self.method}. Supported methods are 'random' and 'zero'."
+            )
+
+    def setup_loss_fct(self) -> Callable:
+        def loss_fct(logits: FloatTensor, labels: FloatTensor) -> FloatTensor:
+            raise NotImplementedError(
+                "Dummy model does not support loss function, as it is not used for training."
+            )
+
+        return loss_fct
+
+    def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor:
+        # Just return a tensor of zeros in the shape of the batch size
+        # so that the classifier can construct dummy logits in the correct shape.
+        bs = pooler_inputs["start_indices"].size(0)
+        return torch.zeros(bs, device=self.device)
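A hedged instantiation sketch for the dummy baseline; arguments other than method and random_seed are forwarded to the parent SequencePairSimilarityModelWithPooler, and the model name is a placeholder:

# Illustrative only, not part of the PR.
model = SequencePairSimilarityModelDummy(
    model_name_or_path="bert-base-uncased",  # placeholder, consumed by the parent class
    method="random",
    random_seed=42,
)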
src/predict.py
CHANGED
@@ -113,8 +113,8 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
         .to(dtype=pipeline.model.dtype)
     )
 
-    # auto-convert the dataset if the
-    dataset = pipeline.taskmodule
+    # auto-convert the dataset if the taskmodule specifies a document type
+    dataset = dataset.to_document_type(pipeline.taskmodule, downcast=False)
 
     # Init the serializer
     serializer: Optional[DocumentSerializer] = None
src/serializer/__init__.py
CHANGED
@@ -1 +1,4 @@
 from .json import JsonSerializer
+
+# backward compatibility
+JsonSerializer2 = JsonSerializer
src/serializer/interface.py
CHANGED
@@ -12,5 +12,4 @@ class DocumentSerializer(ABC):
     """
 
     @abstractmethod
-    def __call__(self, documents: Iterable[Document]) -> Any:
-        pass
+    def __call__(self, documents: Iterable[Document], append: bool = False, **kwargs) -> Any: ...
src/serializer/json.py
CHANGED
@@ -1,11 +1,7 @@
-import
-import os
-from typing import Dict, Iterable, List, Optional, Sequence, Type, TypeVar
+from typing import Dict, Iterable, Optional, Sequence, Type, TypeVar
 
 from pie_datasets import Dataset, DatasetDict, IterableDataset
-from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
 from pytorch_ie.core import Document
-from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type
 
 from src.serializer.interface import DocumentSerializer
 from src.utils.logging_utils import get_pylogger
@@ -28,125 +24,13 @@ class JsonSerializer(DocumentSerializer):
     def __init__(self, **kwargs):
         self.default_kwargs = kwargs
 
-    @classmethod
-    def write(
-        cls,
-        documents: Iterable[Document],
-        path: str,
-        file_name: str = "documents.jsonl",
-        metadata_file_name: str = METADATA_FILE_NAME,
-        split: Optional[str] = None,
-        **kwargs,
-    ) -> Dict[str, str]:
-        realpath = os.path.realpath(path)
-        log.info(f'serialize documents to "{realpath}" ...')
-        os.makedirs(realpath, exist_ok=True)
-
-        if not isinstance(documents, Sequence):
-            documents = list(documents)
-
-        # dump metadata including the document_type
-        if len(documents) == 0:
-            raise Exception("cannot serialize empty list of documents")
-        document_type = type(documents[0])
-        metadata = {"document_type": serialize_document_type(document_type)}
-        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
-        if os.path.exists(full_metadata_file_name):
-            # load previous metadata
-            with open(full_metadata_file_name) as f:
-                previous_metadata = json.load(f)
-            if previous_metadata != metadata:
-                raise ValueError(
-                    f"metadata file {full_metadata_file_name} already exists, "
-                    "but the content does not match the current metadata"
-                    "\nprevious metadata: {previous_metadata}"
-                    "\ncurrent metadata: {metadata}"
-                )
-        else:
-            with open(full_metadata_file_name, "w") as f:
-                json.dump(metadata, f, indent=2)
-
-        if split is not None:
-            realpath = os.path.join(realpath, split)
-            os.makedirs(realpath, exist_ok=True)
-        full_file_name = os.path.join(realpath, file_name)
-        if as_json_lines(file_name):
-            # if the file already exists, append to it
-            mode = "a" if os.path.exists(full_file_name) else "w"
-            with open(full_file_name, mode) as f:
-                for doc in documents:
-                    f.write(json.dumps(doc.asdict(), **kwargs) + "\n")
-        else:
-            docs_list = [doc.asdict() for doc in documents]
-            if os.path.exists(full_file_name):
-                # load previous documents
-                with open(full_file_name) as f:
-                    previous_doc_list = json.load(f)
-                docs_list = previous_doc_list + docs_list
-            with open(full_file_name, "w") as f:
-                json.dump(docs_list, fp=f, **kwargs)
-        return {"path": realpath, "file_name": file_name, "metadata_file_name": metadata_file_name}
-
-    @classmethod
-    def read(
-        cls,
-        path: str,
-        document_type: Optional[Type[D]] = None,
-        file_name: str = "documents.jsonl",
-        metadata_file_name: str = METADATA_FILE_NAME,
-        split: Optional[str] = None,
-    ) -> List[D]:
-        realpath = os.path.realpath(path)
-        log.info(f'load documents from "{realpath}" ...')
-
-        # try to load metadata including the document_type
-        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
-        if os.path.exists(full_metadata_file_name):
-            with open(full_metadata_file_name) as f:
-                metadata = json.load(f)
-            document_type = resolve_optional_document_type(metadata.get("document_type"))
-
-        if document_type is None:
-            raise Exception("document_type is required to load serialized documents")
-
-        if split is not None:
-            realpath = os.path.join(realpath, split)
-        full_file_name = os.path.join(realpath, file_name)
-        documents = []
-        if as_json_lines(str(file_name)):
-            with open(full_file_name) as f:
-                for line in f:
-                    json_dict = json.loads(line)
-                    documents.append(document_type.fromdict(json_dict))
-        else:
-            with open(full_file_name) as f:
-                json_list = json.load(f)
-                for json_dict in json_list:
-                    documents.append(document_type.fromdict(json_dict))
-        return documents
-
-    def read_with_defaults(self, **kwargs) -> List[D]:
-        all_kwargs = {**self.default_kwargs, **kwargs}
-        return self.read(**all_kwargs)
-
-    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
-        all_kwargs = {**self.default_kwargs, **kwargs}
-        return self.write(**all_kwargs)
-
-    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
-        return self.write_with_defaults(documents=documents, **kwargs)
-
-
-class JsonSerializer2(DocumentSerializer):
-    def __init__(self, **kwargs):
-        self.default_kwargs = kwargs
-
     @classmethod
     def write(
         cls,
         documents: Iterable[Document],
         path: str,
         split: str = "train",
+        append: bool = False,
     ) -> Dict[str, str]:
         if not isinstance(documents, (Dataset, IterableDataset)):
             if not isinstance(documents, Sequence):
@@ -154,7 +38,7 @@ class JsonSerializer2(DocumentSerializer):
         else:
             documents = Dataset.from_documents(documents)
         dataset_dict = DatasetDict({split: documents})
-        dataset_dict.to_json(path=path)
+        dataset_dict.to_json(path=path, mode="a" if append else "w")
         return {"path": path, "split": split}
 
     @classmethod
@@ -181,5 +65,7 @@ class JsonSerializer2(DocumentSerializer):
         all_kwargs = {**self.default_kwargs, **kwargs}
         return self.write(**all_kwargs)
 
-    def __call__(
-
+    def __call__(
+        self, documents: Iterable[Document], append: bool = False, **kwargs
+    ) -> Dict[str, str]:
+        return self.write_with_defaults(documents=documents, append=append, **kwargs)
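A hedged usage sketch for the append flag the slimmed-down JsonSerializer now exposes (the path and the two document batches are placeholders):

# Illustrative only, not part of the PR.
serializer = JsonSerializer(path="predictions/out", split="test")
serializer(first_batch)                 # first call writes the dataset
serializer(second_batch, append=True)   # later calls append instead of overwriting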
src/start_demo.py
CHANGED
@@ -331,8 +331,9 @@ def main(cfg: DictConfig) -> None:
         visible=pdf_fulltext_extractor is not None,
     )
 
-    enable_acl_venue_loading =
-
+    enable_acl_venue_loading = (
+        pdf_fulltext_extractor is not None
+        and cfg.get("acl_anthology_data_dir") is not None
     )
     acl_anthology_venues = gr.Textbox(
         label="ACL Anthology Venues",
src/train.py
CHANGED
@@ -45,7 +45,7 @@ from pie_modules.models import SimpleGenerativeModel
 from pie_modules.models.interface import RequiresTaskmoduleConfig
 from pie_modules.taskmodules import *  # noqa: F403
 from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
-from pytorch_ie import Pipeline
+from pytorch_ie import PieDataModule, Pipeline
 from pytorch_ie.core import PyTorchIEModel, TaskModule
 from pytorch_ie.models import *  # noqa: F403
 from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
@@ -55,7 +55,6 @@ from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import Logger
 
 from src import utils
-from src.datamodules import PieDataModule
 from src.models import *  # noqa: F403
 from src.serializer.interface import DocumentSerializer
 from src.taskmodules import *  # noqa: F403
@@ -135,7 +134,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     )
 
     # auto-convert the dataset if the taskmodule specifies a document type
-    dataset =
+    dataset = dataset.to_document_type(taskmodule, downcast=False)
 
     # Init pytorch-ie datamodule
     log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
src/utils/graph_utils.py
ADDED
@@ -0,0 +1,47 @@
from typing import Hashable, List, Optional, Sequence, TypeVar

from pytorch_ie.annotations import BinaryRelation

H = TypeVar("H", bound=Hashable)


def get_connected_components(
    relations: Sequence[BinaryRelation],
    elements: Optional[Sequence[H]] = None,
    link_relation_label: Optional[str] = None,
    link_relation_relation_score_threshold: Optional[float] = None,
    add_singletons: bool = False,
) -> List[List[H]]:
    try:
        import networkx as nx
    except ImportError:
        raise ImportError(
            "NetworkX must be installed to use the SpansViaRelationMerger. "
            "You can install NetworkX with `pip install networkx`."
        )

    # convert list of relations to a graph to easily calculate connected components to merge
    g = nx.Graph()
    link_relations = []
    other_relations = []
    elem2edge_relation = {}
    for rel in relations:
        if (link_relation_label is None or rel.label == link_relation_label) and (
            link_relation_relation_score_threshold is None
            or rel.score >= link_relation_relation_score_threshold
        ):
            link_relations.append(rel)
            g.add_edge(rel.head, rel.tail)
            elem2edge_relation[rel.head] = rel
            elem2edge_relation[rel.tail] = rel
        else:
            other_relations.append(rel)

    if add_singletons:
        if elements is None:
            raise ValueError("elements must be provided if add_singletons is True")
        # add singletons to the graph
        for elem in elements:
            if elem not in elem2edge_relation:
                g.add_node(elem)
    return list(nx.connected_components(g))
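A hedged usage sketch for get_connected_components; the spans, relation label and threshold below are made up for illustration:

# Illustrative only, not part of the PR.
from pytorch_ie.annotations import BinaryRelation, LabeledSpan

a = LabeledSpan(start=0, end=5, label="claim")
b = LabeledSpan(start=10, end=15, label="claim")
c = LabeledSpan(start=20, end=25, label="claim")

components = get_connected_components(
    relations=[BinaryRelation(head=a, tail=b, label="semantically_same", score=0.9)],
    elements=[a, b, c],
    link_relation_label="semantically_same",
    link_relation_relation_score_threshold=0.5,
    add_singletons=True,
)
# components == [{a, b}, {c}]  (networkx yields sets of nodes)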
src/utils/inference_utils.py
CHANGED
@@ -50,6 +50,8 @@ def predict_and_serialize(
         batch_iter = [dataset]
     else:
         batch_iter = document_batch_iter(dataset=dataset, batch_size=document_batch_size)
+
+    append = False
     for docs_batch in batch_iter:
         if pipeline is not None:
             t_start = timeit.default_timer()
@@ -60,13 +62,14 @@
         if serializer is not None:
             # the serializer should not return the serialized documents, but write them to disk
             # and instead return some metadata such as the path to the serialized documents
-            serializer_result = serializer(docs_batch)
+            serializer_result = serializer(docs_batch, append=append)
             if "serializer" in result and result["serializer"] != serializer_result:
                 log.warning(
                     f"serializer result changed from {result['serializer']} to {serializer_result}"
                     " during prediction. Only the last result is returned."
                 )
             result["serializer"] = serializer_result
+            append = True
 
     if prediction_time is not None:
         result["prediction_time"] = prediction_time
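The batched-serialization pattern introduced above, shown in isolation (a minimal sketch; serializer and batch_iter stand for the objects used in predict_and_serialize):

# Illustrative only: the first batch overwrites any previous output, every later batch appends.
append = False
for docs_batch in batch_iter:
    serializer(docs_batch, append=append)
    append = True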
src/utils/pdf_utils/process_pdf.py
CHANGED
@@ -138,7 +138,7 @@ def process_pdf_file(
     os.makedirs(output_dir, exist_ok=True)
 
     # get paper id as the name of the file
-    paper_id =
+    paper_id = os.path.splitext(os.path.basename(input_file))[0]
     tei_file = os.path.join(temp_dir, f"{paper_id}.tei.xml")
     output_file = os.path.join(output_dir, f"{paper_id}.json")