ArneBinder committed on
Commit d868d2e · verified · 1 Parent(s): 5fbc03a

update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529

Files changed (39)
  1. src/analysis/combine_job_returns.py +207 -21
  2. src/analysis/common.py +27 -6
  3. src/analysis/compare_job_returns.py +1 -1
  4. src/analysis/format_metric_results.py +269 -0
  5. src/analysis/get_json_field_as_string.py +55 -0
  6. src/analysis/show_inference_params_on_quality_and_throughput.py +485 -0
  7. src/datamodules/__init__.py +1 -1
  8. src/datamodules/datamodule_with_sampler.py +59 -0
  9. src/dataset/processing.py +88 -3
  10. src/demo/annotation_utils.py +6 -56
  11. src/demo/backend_utils.py +50 -12
  12. src/demo/retrieve_and_dump_all_relevant.py +82 -38
  13. src/document/processing.py +300 -1
  14. src/evaluate.py +3 -3
  15. src/evaluate_documents.py +1 -1
  16. src/hydra_callbacks/save_job_return_value.py +67 -4
  17. src/langchain_modules/basic_pie_document_store.py +3 -1
  18. src/langchain_modules/datasets_pie_document_store.py +1 -1
  19. src/metrics/__init__.py +7 -1
  20. src/metrics/connected_component_sizes.py +43 -0
  21. src/metrics/coref.py +223 -0
  22. src/metrics/coref_sklearn.py +158 -43
  23. src/metrics/f1_with_bootstrapping.py +103 -0
  24. src/metrics/f1_with_threshold.py +33 -0
  25. src/metrics/ranking_sklearn.py +193 -0
  26. src/metrics/score_distribution.py +13 -4
  27. src/metrics/semantically_same_ranking.py +448 -0
  28. src/metrics/tpfpfn.py +193 -0
  29. src/models/__init__.py +2 -1
  30. src/models/sequence_classification_with_pooler.py +65 -30
  31. src/predict.py +2 -2
  32. src/serializer/__init__.py +4 -1
  33. src/serializer/interface.py +1 -2
  34. src/serializer/json.py +7 -121
  35. src/start_demo.py +3 -2
  36. src/train.py +2 -3
  37. src/utils/graph_utils.py +47 -0
  38. src/utils/inference_utils.py +4 -1
  39. src/utils/pdf_utils/process_pdf.py +1 -1
src/analysis/combine_job_returns.py CHANGED
@@ -1,4 +1,5 @@
 import pyrootutils
 
 root = pyrootutils.setup_root(
     search_from=__file__,
@@ -6,14 +7,18 @@ root = pyrootutils.setup_root(
     pythonpath=True,
     dotenv=False,
 )
-
 import argparse
 import os
 
 import pandas as pd
 
 from src.analysis.common import read_nested_jsons
 
 
 def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]:
     parts = path_and_maybe_id.split(separator, 1)
@@ -22,7 +27,7 @@ def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[
     return parts[0], parts[1]
 
 
-def get_file_paths(paths_file: str, file_name: str, use_aggregated: bool) -> dict[str, str]:
     with open(paths_file, "r") as f:
         paths_maybe_with_ids = f.readlines()
     ids, paths = zip(*[separate_path_and_id(path.strip()) for path in paths_maybe_with_ids])
@@ -31,10 +36,40 @@ def get_file_paths(paths_file: str, file_name: str, use_aggregated: bool) -> dic
     file_base_name, ext = os.path.splitext(file_name)
     file_name = f"{file_base_name}.aggregated{ext}"
     file_paths = [os.path.join(path, file_name) for path in paths]
-    return {
-        id if id is not None else f"idx={idx}": path
         for idx, (id, path) in enumerate(zip(ids, file_paths))
-    }
 
 
 def main(
@@ -46,17 +81,36 @@ def main(
     format: str,
     transpose: bool = False,
     unpack_multirun_results: bool = False,
     in_percent: bool = False,
     reset_index: bool = False,
 ):
     file_paths = get_file_paths(
         paths_file=paths_file, file_name=file_name, use_aggregated=use_aggregated
     )
     data = read_nested_jsons(json_paths=file_paths)
 
     if columns is not None:
-        columns_multi_index = [tuple(col.split("/")) for col in columns]
         try:
             data_series = [data[col] for col in columns_multi_index]
         except KeyError as e:
             print(
@@ -88,35 +142,129 @@ def main(
     for level in sorted(unique_column_levels, reverse=True):
         data.columns = data.columns.droplevel(level)
 
-    if unpack_multirun_results:
         index_names = list(data.index.names)
-        data_series_lists = data.unstack()
-        data = pd.DataFrame.from_records(
-            data_series_lists.values, index=data_series_lists.index
-        ).stack()
-        for _, index_name in enumerate(index_names):
-            data = data.unstack(index_name)
-        data = data.T
 
     # needs to happen before rounding, otherwise the rounding will be off
     if in_percent:
-        data = data * 100
 
     if round_precision is not None:
         data = data.round(round_precision)
 
     # needs to happen before transposing
     if format == "markdown_mean_and_std":
-        if "mean" not in data.columns or "std" not in data.columns:
-            raise ValueError("Columns 'mean' and 'std' are required for this format.")
-        # create a single column with mean and std in the format: mean ± std
-        data = pd.DataFrame(
-            data["mean"].astype(str) + " ± " + data["std"].astype(str), columns=["mean ± std"]
-        )
 
     if transpose:
         data = data.T
 
     if reset_index:
         data = data.reset_index()
 
@@ -148,7 +296,17 @@ if __name__ == "__main__":
     parser.add_argument(
         "--unpack-multirun-results", action="store_true", help="Unpack multirun results"
     )
     parser.add_argument("--transpose", action="store_true", help="Transpose the table")
     parser.add_argument(
         "--round-precision",
         type=int,
@@ -160,6 +318,34 @@
     parser.add_argument(
         "--reset-index", action="store_true", help="Reset the index of the combined job returns"
     )
     parser.add_argument(
         "--format",
         type=str,
 
1
  import pyrootutils
2
+ from pandas import MultiIndex
3
 
4
  root = pyrootutils.setup_root(
5
  search_from=__file__,
 
7
  pythonpath=True,
8
  dotenv=False,
9
  )
 
10
  import argparse
11
+ import logging
12
  import os
13
+ from typing import Iterable, List, Optional, Tuple
14
 
15
+ import numpy as np
16
  import pandas as pd
17
 
18
  from src.analysis.common import read_nested_jsons
19
 
20
+ logger = logging.getLogger(__name__)
21
+
22
 
23
  def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]:
24
  parts = path_and_maybe_id.split(separator, 1)
 
27
  return parts[0], parts[1]
28
 
29
 
30
+ def get_file_paths(paths_file: str, file_name: str, use_aggregated: bool) -> List[Tuple[str, str]]:
31
  with open(paths_file, "r") as f:
32
  paths_maybe_with_ids = f.readlines()
33
  ids, paths = zip(*[separate_path_and_id(path.strip()) for path in paths_maybe_with_ids])
 
36
  file_base_name, ext = os.path.splitext(file_name)
37
  file_name = f"{file_base_name}.aggregated{ext}"
38
  file_paths = [os.path.join(path, file_name) for path in paths]
39
+ return [
40
+ (id if id is not None else f"idx={idx}", path)
41
  for idx, (id, path) in enumerate(zip(ids, file_paths))
42
+ ]
43
+
44
+
45
+ def get_job_id_col(index: pd.MultiIndex) -> Optional[Tuple]:
46
+ for idx in index:
47
+ if "job_id" in idx:
48
+ return idx
49
+ return None
50
+
51
+
52
+ def stringify(value: str | int | float | None | tuple | list) -> str:
53
+ if isinstance(value, str):
54
+ return value
55
+ if value is None:
56
+ return ""
57
+ if isinstance(value, float) and np.isnan(value):
58
+ return ""
59
+ if isinstance(value, (int, float)):
60
+ return str(value)
61
+ if isinstance(value, Iterable):
62
+ entries = [stringify(v) for v in value]
63
+ return "/".join(v for v in entries if v)
64
+ return value
65
+
66
+
67
+ def remove_part_from_multi_index(index: pd.MultiIndex, part: str) -> pd.MultiIndex:
68
+ new_index = []
69
+ for idx in index:
70
+ new_idx = tuple([i for i in idx if i != part])
71
+ new_index.append(new_idx)
72
+ return MultiIndex.from_tuples(new_index)
73
 
74
 
75
  def main(
 
81
  format: str,
82
  transpose: bool = False,
83
  unpack_multirun_results: bool = False,
84
+ unpack_multirun_results_with_job_id: bool = False,
85
  in_percent: bool = False,
86
  reset_index: bool = False,
87
+ sort_columns: bool = False,
88
+ stringify_column_names: bool = False,
89
+ column_regex_blacklist: Optional[List[str]] = None,
90
+ column_regex_whitelist: Optional[List[str]] = None,
91
+ replace_in_col_names: Optional[List[Tuple[str, str]]] = None,
92
  ):
93
  file_paths = get_file_paths(
94
  paths_file=paths_file, file_name=file_name, use_aggregated=use_aggregated
95
  )
96
  data = read_nested_jsons(json_paths=file_paths)
97
 
98
+ job_id_col = get_job_id_col(data.columns)
99
+
100
  if columns is not None:
101
+ columns_multi_index = [
102
+ tuple([part or np.nan for part in col.split("/")]) for col in columns
103
+ ]
104
+ if unpack_multirun_results_with_job_id:
105
+ if job_id_col is None:
106
+ raise ValueError("Job ID column not found in the data.")
107
+ if job_id_col not in columns_multi_index:
108
+ columns_multi_index.append(job_id_col)
109
  try:
110
+ available_cols = data.columns.tolist()
111
+ for col in columns_multi_index:
112
+ if col not in available_cols:
113
+ raise KeyError(f"Column {col} not found in the data.")
114
  data_series = [data[col] for col in columns_multi_index]
115
  except KeyError as e:
116
  print(
 
142
  for level in sorted(unique_column_levels, reverse=True):
143
  data.columns = data.columns.droplevel(level)
144
 
145
+ if unpack_multirun_results or unpack_multirun_results_with_job_id:
146
  index_names = list(data.index.names)
147
+ data_series_lists = data.copy()
148
+ job_ids = None
149
+ if job_id_col in data_series_lists.columns:
150
+ job_ids_series = data_series_lists.pop(job_id_col)
151
+ job_ids_frame = pd.DataFrame(pd.DataFrame.from_records(job_ids_series.values))
152
+ job_ids_frame.index = job_ids_series.index
153
+ # check that all rows are identical
154
+ if job_ids_frame.nunique().max():
155
+ job_ids = job_ids_frame.iloc[0]
156
+ else:
157
+ logger.warning(
158
+ "Job IDs are not identical across all rows. Cannot unpack "
159
+ "multirun results with job ids as columns."
160
+ )
161
+
162
+ while not isinstance(data_series_lists, pd.Series):
163
+ data_series_lists = data_series_lists.stack(future_stack=True)
164
+ data_series_lists = data_series_lists.dropna()
165
+ data = pd.DataFrame.from_records(data_series_lists.values, index=data_series_lists.index)
166
+ if job_ids is not None:
167
+ data.columns = job_ids
168
+ num_col_levels = data.index.nlevels - len(index_names)
169
+ for _ in range(num_col_levels):
170
+ data = data.unstack()
171
+ data.columns = data.columns.swaplevel(0, -1)
172
+ data = data.dropna(how="all", axis="columns")
173
 
174
  # needs to happen before rounding, otherwise the rounding will be off
175
  if in_percent:
176
+ float_columns = data.select_dtypes(include=["float64", "float32"]).columns
177
+ data[float_columns] = data[float_columns] * 100
178
 
179
  if round_precision is not None:
180
  data = data.round(round_precision)
181
 
182
  # needs to happen before transposing
183
  if format == "markdown_mean_and_std":
184
+ if data.columns.nlevels == 1:
185
+ data.columns = pd.MultiIndex.from_tuples([(col,) for col in data.columns.tolist()])
186
+
187
+ # get mean columns
188
+ mean_col_names = [col for col in data.columns if "mean" in col]
189
+ mean_columns = data[mean_col_names].copy()
190
+ # remove all "mean" from col names
191
+ mean_columns.columns = remove_part_from_multi_index(mean_columns.columns, "mean")
192
+ # get std columns
193
+ std_col_names = [col for col in data.columns if "std" in col]
194
+ std_columns = data[std_col_names].copy()
195
+ # remove all "std" from col names
196
+ std_columns.columns = remove_part_from_multi_index(std_columns.columns, "std")
197
+ # sanity check
198
+ if not mean_columns.columns.equals(std_columns.columns):
199
+ raise ValueError("Mean and std columns do not match.")
200
+ mean_and_std = mean_columns.astype(str) + " ± " + std_columns.astype(str)
201
+ mean_and_std.columns = [
202
+ ("mean ± std",) + (tuple(col) if col != ((),) else ()) for col in mean_columns.columns
203
+ ]
204
+ # remove mean and std columns from data
205
+ # we can not use drop because the columns is a multiindex that may contain NaNs
206
+ other_cols = [
207
+ col for col in data.columns if col not in set(mean_col_names + std_col_names)
208
+ ]
209
+ data = data[other_cols]
210
+ # add mean and std columns to data
211
+ data = pd.concat([data, mean_and_std], axis=1)
212
+ if data.columns.nlevels == 1:
213
+ data.columns = data.columns.to_flat_index()
214
+ data.columns = [
215
+ "/".join(col) if isinstance(col, tuple) else col for col in data.columns
216
+ ]
217
 
218
  if transpose:
219
  data = data.T
220
 
221
+ if sort_columns:
222
+ # sort columns to get a deterministic order
223
+ data = data.sort_index(axis=1)
224
+
225
+ if stringify_column_names:
226
+ # Convert MultiIndex columns to string representation
227
+ data.columns = data.columns.map(stringify)
228
+
229
+ if column_regex_blacklist is not None:
230
+ # Remove columns that match any of the regex patterns in the blacklist
231
+ for pattern in column_regex_blacklist:
232
+ data = data.loc[:, ~data.columns.str.contains(pattern, regex=True)]
233
+
234
+ if column_regex_whitelist is not None:
235
+ # keep only columns that match any of the regex patterns in the whitelist
236
+ data = data.loc[
237
+ :, data.columns.str.contains("|".join(column_regex_whitelist), regex=True)
238
+ ]
239
+
240
+ if replace_in_col_names is not None:
241
+ for old_value, new_value in replace_in_col_names:
242
+ data.columns = data.columns.str.replace(old_value, new_value, regex=False)
243
+
244
+ else:
245
+ if column_regex_blacklist is not None:
246
+ logger.warning(
247
+ "Column regex blacklist is ignored when stringify_column_names is False."
248
+ )
249
+ if column_regex_whitelist is not None:
250
+ logger.warning(
251
+ "Column regex whitelist is ignored when stringify_column_names is False."
252
+ )
253
+ if replace_in_col_names is not None:
254
+ logger.warning(
255
+ "Replace in column names is ignored when stringify_column_names is False."
256
+ )
257
+
258
+ # remove empty rows
259
+ # get rows that contain only nan or "nan ± nan"
260
+ empty_rows = data.apply(
261
+ lambda row: all(
262
+ pd.isna(value) or (isinstance(value, str) and value == "nan ± nan") for value in row
263
+ ),
264
+ axis=1,
265
+ )
266
+ data = data[~empty_rows]
267
+
268
  if reset_index:
269
  data = data.reset_index()
270
 
 
296
  parser.add_argument(
297
  "--unpack-multirun-results", action="store_true", help="Unpack multirun results"
298
  )
299
+ parser.add_argument(
300
+ "--unpack-multirun-results-with-job-id",
301
+ action="store_true",
302
+ help="Unpack multirun results with job ID",
303
+ )
304
  parser.add_argument("--transpose", action="store_true", help="Transpose the table")
305
+ parser.add_argument(
306
+ "--sort-columns",
307
+ action="store_true",
308
+ help="Sort the columns of the combined job returns",
309
+ )
310
  parser.add_argument(
311
  "--round-precision",
312
  type=int,
 
318
  parser.add_argument(
319
  "--reset-index", action="store_true", help="Reset the index of the combined job returns"
320
  )
321
+ parser.add_argument(
322
+ "--stringify-column-names",
323
+ action="store_true",
324
+ help="Stringify the column names of the combined job returns (useful for multi-index columns)",
325
+ )
326
+ parser.add_argument(
327
+ "--column-regex-blacklist",
328
+ type=str,
329
+ nargs="+",
330
+ default=None,
331
+ help="List of regex patterns to match column names. "
332
+ "Columns that match any of the patterns will be removed.",
333
+ )
334
+ parser.add_argument(
335
+ "--column-regex-whitelist",
336
+ type=str,
337
+ nargs="+",
338
+ default=None,
339
+ help="List of regex patterns to match column names. "
340
+ "Only columns that match any of the patterns will be kept.",
341
+ )
342
+ parser.add_argument(
343
+ "--replace-in-col-names",
344
+ type=lambda s: s.split(":", 1),
345
+ nargs="+",
346
+ default=None,
347
+ help='List of strings in the format "<old_value>:<new_value>" to replace substrings in column names.',
348
+ )
349
  parser.add_argument(
350
  "--format",
351
  type=str,
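Usage note (not part of the diff, a hedged sketch): the post-processing flags added above can be combined in a single call. Only options visible in this diff are shown; the input arguments (paths file, file name, column selection) are omitted, and the regex pattern and replacement value are hypothetical placeholders:

python src/analysis/combine_job_returns.py <input arguments> \
    --unpack-multirun-results-with-job-id \
    --stringify-column-names \
    --sort-columns \
    --column-regex-whitelist "f1" \
    --replace-in-col-names "metric/:" \
    --round-precision 2 \
    --format markdown_mean_and_std

Note that the regex blacklist/whitelist and column-name replacement only take effect together with --stringify-column-names; otherwise the script logs a warning and ignores them.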
src/analysis/common.py CHANGED
@@ -1,5 +1,5 @@
 import json
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import pandas as pd
 
@@ -26,21 +26,42 @@ def read_nested_json(path: str) -> pd.DataFrame:
 
 
 def read_nested_jsons(
-    json_paths: Dict[str, str],
+    json_paths: List[Tuple[str, str]],
     default_key_values: Optional[Dict[str, str]] = None,
     column_level_names: Optional[List[str]] = None,
 ) -> pd.DataFrame:
-    identifier_strings = json_paths.keys()
-    dfs = [read_nested_json(json_paths[identifier_str]) for identifier_str in identifier_strings]
+    dfs = [read_nested_json(json_path) for identifier_str, json_path in json_paths]
     new_index_levels = pd.MultiIndex.from_frame(
         pd.DataFrame(
             [
                 parse_identifier(identifier_str, default_key_values or {})
-                for identifier_str in identifier_strings
+                for identifier_str, _ in json_paths
             ]
         )
     )
-    dfs_concat = pd.concat(dfs, keys=list(new_index_levels), names=new_index_levels.names, axis=0)
+    if len(set(list(new_index_levels))) == len(list(new_index_levels)):
+        dfs_concat = pd.concat(
+            dfs, keys=list(new_index_levels), names=new_index_levels.names, axis=0
+        )
+    else:
+        dfs_new = []
+        ids_unique = []
+        for identifier_str in new_index_levels:
+            if identifier_str not in ids_unique:
+                ids_unique.append(identifier_str)
+        # first combine the dataframes with same ids along the columns
+        for identifier_str in ids_unique:
+            dfs_with_id = [df for df, idx in zip(dfs, new_index_levels) if idx == identifier_str]
+            # assert that all columns are distinct
+            if len(set([tuple(col) for df in dfs_with_id for col in df.columns])) != sum(
+                [len(df.columns) for df in dfs_with_id]
+            ):
+                raise ValueError(
+                    "There are duplicate columns across the dataframes with the same identifier."
+                )
+            dfs_id_concat = pd.concat(dfs_with_id, axis=1)
+            dfs_new.append(dfs_id_concat)
+        dfs_concat = pd.concat(dfs_new, keys=ids_unique, names=new_index_levels.names, axis=0)
     dfs_concat.columns = pd.MultiIndex.from_tuples(
         [col.split("/") for col in dfs_concat.columns], names=column_level_names
     )
src/analysis/compare_job_returns.py CHANGED
@@ -173,7 +173,7 @@ def combine_job_returns_and_plot(
 
     if job_return_paths is not None:
         df_all = read_nested_jsons(
-            json_paths=job_return_paths,
+            json_paths=list(job_return_paths.items()),
             default_key_values=default_key_values,
             column_level_names=column_level_names,
         )
src/analysis/format_metric_results.py ADDED
@@ -0,0 +1,269 @@
1
+ #!/usr/bin/env python
2
+ import argparse
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import pandas as pd
8
+ from pie_modules.utils import flatten_dict
9
+
10
+
11
+ def str2record(s: str | None, sep_parts: str = "-", sep_k_v: str = "=") -> pd.Series:
12
+ if s is None or s.strip() == "" or s == "None":
13
+ return pd.Series()
14
+ return pd.Series(dict(k_v.split(sep_k_v, 1) for k_v in s.split(sep_parts)))
15
+
16
+
17
+ def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]:
18
+ parts = path_and_maybe_id.split(separator, 1)
19
+ if len(parts) == 1:
20
+ return None, parts[0]
21
+ return parts[0], parts[1]
22
+
23
+
24
+ def load_data_from_json(path: str | Path) -> pd.DataFrame:
25
+ with open(path, "r") as f:
26
+ data_json = json.load(f)
27
+ data_flat = flatten_dict(data_json)
28
+ return pd.DataFrame(data_flat)
29
+
30
+
31
+ def main(
32
+ path: str | Path,
33
+ remove_col_prefix: str | None = None,
34
+ sparse_col_prefix: str | None = None,
35
+ tail_cols: list[str] | None = None,
36
+ sort_cols: list[str] | None = None,
37
+ split_col: str | None = None,
38
+ replace_in_col_names: list[tuple[str, str]] | None = None,
39
+ round_precision: int | None = None,
40
+ in_percent: bool = False,
41
+ common_prefix_separator: str | None = None,
42
+ column_regex_blacklist: list[str] | None = None,
43
+ column_regex_whitelist: list[str] | None = None,
44
+ format: str = "markdown",
45
+ ) -> None:
46
+
47
+ if str(path).lower().endswith(".json"):
48
+ result = load_data_from_json(path)
49
+ elif str(path).lower().endswith(".txt"):
50
+ with open(path, "r") as f:
51
+ index_data = [separate_path_and_id(line.strip()) for line in f.readlines()]
52
+ data_list = []
53
+ for meta_id, meta_path in index_data:
54
+ data = load_data_from_json(os.path.join(meta_path, "job_return_value.json"))
55
+ if meta_id is not None:
56
+ job_id_prefix = meta_id.replace(",", "-")
57
+ data["job_id"] = job_id_prefix + "-" + data["job_id"].astype(str)
58
+ data = data.set_index("job_id")
59
+ data_list.append(data)
60
+ result = pd.concat(data_list, axis=1).reset_index()
61
+ else:
62
+ raise ValueError("Unsupported file format. Please provide a .json or .txt file.")
63
+
64
+ if remove_col_prefix is not None:
65
+ result.columns = result.columns.str.replace(r"^" + remove_col_prefix, "", regex=True)
66
+
67
+ if sparse_col_prefix is not None:
68
+ # get all columns that contain just one not-nan value
69
+ # number_of_non_nan_values = len(df) - df.isna().sum()
70
+ # df_sparse = df.loc[:, number_of_non_nan_values == 1]
71
+ sparse_cols = [col for col in result.columns if col.startswith(sparse_col_prefix)]
72
+ other_cols = [col for col in result.columns if col not in sparse_cols]
73
+
74
+ value_col = f"{sparse_col_prefix}value"
75
+ name_col = f"{sparse_col_prefix}name"
76
+ result = result.melt(
77
+ id_vars=other_cols, value_vars=sparse_cols, var_name=name_col, value_name=value_col
78
+ ).dropna(
79
+ subset=[value_col]
80
+ ) # keep rows with a value
81
+ # strip the "f1-" prefix, leaving just the numeric threshold
82
+ result[name_col] = result[name_col].str.replace(r"^" + sparse_col_prefix, "", regex=True)
83
+ # convert the column to numeric (if possible)
84
+ try:
85
+ result[name_col] = pd.to_numeric(result[name_col])
86
+ except ValueError:
87
+ # if it fails, just keep it as a string
88
+ pass
89
+
90
+ if split_col is not None:
91
+ new_frame = result[split_col].apply(str2record)
92
+ result = pd.concat([result.drop(columns=[split_col]), new_frame], axis=1)
93
+
94
+ if in_percent:
95
+ float_columns = result.select_dtypes(include=["float64", "float32"]).columns
96
+ result[float_columns] = result[float_columns] * 100
97
+
98
+ if round_precision is not None:
99
+ # round all columns to the given precision
100
+ result = result.round(round_precision)
101
+
102
+ if common_prefix_separator is not None:
103
+ # remove common prefix from values in all string columns
104
+ obj_columns = result.select_dtypes(include=["object"]).columns
105
+ for obj_col in obj_columns:
106
+ # get the common prefix
107
+ common_prefix = os.path.commonprefix(result[obj_col].dropna().astype(str).tolist())
108
+ # find last occurrence of the common_prefix_separator
109
+ last_occurrence = common_prefix.rfind(common_prefix_separator)
110
+ if last_occurrence != -1:
111
+ # truncate the common prefix after the last occurrence of the separator
112
+ common_prefix = common_prefix[: last_occurrence + len(common_prefix_separator)]
113
+ # remove the common prefix (including the separator) from the column
114
+ result[obj_col] = result[obj_col].str.replace(r"^" + common_prefix, "", regex=True)
115
+
116
+ # sort columns to get a deterministic order
117
+ result = result.sort_index(axis=1)
118
+
119
+ if tail_cols is not None:
120
+ front_cols = [c for c in result.columns if c not in tail_cols]
121
+ result = result[front_cols + tail_cols]
122
+
123
+ if sort_cols is not None:
124
+ result = result.sort_values(sort_cols)
125
+ # also move the sort columns to the front
126
+ result = result[sort_cols + [c for c in result.columns if c not in sort_cols]]
127
+
128
+ if column_regex_blacklist is not None:
129
+ # remove columns that match any of the regex patterns in the blacklist
130
+ for pattern in column_regex_blacklist:
131
+ result = result.loc[:, ~result.columns.str.contains(pattern, regex=True)]
132
+
133
+ if column_regex_whitelist is not None:
134
+ # keep only columns that match any of the regex patterns in the whitelist
135
+ result = result.loc[
136
+ :, result.columns.str.contains("|".join(column_regex_whitelist), regex=True)
137
+ ]
138
+
139
+ if replace_in_col_names is not None:
140
+ for old_value, new_value in replace_in_col_names:
141
+ result.columns = result.columns.str.replace(old_value, new_value, regex=False)
142
+
143
+ if format == "markdown":
144
+ result_str = result.to_markdown(index=False)
145
+ elif format == "csv":
146
+ result_str = result.to_csv(index=False)
147
+ elif format == "tsv":
148
+ result_str = result.to_csv(index=False, sep="\t")
149
+ elif format == "json":
150
+ result_str = result.to_json(orient="records", lines=True)
151
+ else:
152
+ raise ValueError(
153
+ f"Unsupported format: {format}. Supported formats are: markdown, csv, json."
154
+ )
155
+
156
+ print(result_str)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ """
161
+ Example usage:
162
+
163
+ python src/analysis/format_metric_results.py \
164
+ logs/document_evaluation/multiruns/default/2025-05-21_11-59-19/job_return_value.json \
165
+ --remove-col-prefix train/ \
166
+ --sparse-col-prefix f1- \
167
+ --split-col job_id \
168
+ --tail-cols num_positives num_total \
169
+ --sort-cols experiment model \
170
+ --round-precision 4
171
+ """
172
+
173
+ parser = argparse.ArgumentParser(
174
+ description="Process a JSON file containing metric results (from multirun) and print as Markdown table."
175
+ )
176
+ parser.add_argument(
177
+ "path",
178
+ type=str,
179
+ help="Path to the JSON file to process. The JSON file is expected to contain "
180
+ "a (maybe nested) dictionary where each leave entry is a list of values with "
181
+ "the same length.",
182
+ )
183
+ parser.add_argument(
184
+ "--remove-col-prefix",
185
+ type=str,
186
+ default=None,
187
+ help="Prefix to remove from column names.",
188
+ )
189
+ parser.add_argument(
190
+ "--sparse-col-prefix",
191
+ type=str,
192
+ default=None,
193
+ help="Prefix of sparse columns. All sparse columns will be melted into "
194
+ "two columns: <prefix>name and <prefix>value. The name column will "
195
+ "be converted to numeric if possible.",
196
+ )
197
+
198
+ parser.add_argument(
199
+ "--split-col",
200
+ type=str,
201
+ default=None,
202
+ help="Column to split into multiple columns. The format of the "
203
+ "column entries is expected to be: <key_1>=<value_a>-<key_2>=<value_b>-...",
204
+ )
205
+ parser.add_argument(
206
+ "--tail-cols",
207
+ type=str,
208
+ nargs="+",
209
+ default=None,
210
+ help="Columns to move to the end.",
211
+ )
212
+ parser.add_argument(
213
+ "--sort-cols",
214
+ type=str,
215
+ nargs="+",
216
+ default=None,
217
+ help="Columns to sort by (they will be moved to the front).",
218
+ )
219
+ parser.add_argument(
220
+ "--replace-in-col-names",
221
+ type=lambda s: s.split(":", 1),
222
+ nargs="+",
223
+ default=None,
224
+ help='List of strings in the format "<old_value>:<new_value>" to replace substrings in column names.',
225
+ )
226
+ parser.add_argument(
227
+ "--round-precision",
228
+ type=int,
229
+ default=None,
230
+ help="Number of decimal places to round to.",
231
+ )
232
+ parser.add_argument(
233
+ "--in-percent",
234
+ action="store_true",
235
+ default=False,
236
+ help="If set, all float columns will be multiplied by 100 to convert them to percentages.",
237
+ )
238
+ parser.add_argument(
239
+ "--common-prefix-separator",
240
+ type=str,
241
+ default=None,
242
+ help="For all string columns, remove the common prefix up to the last occurrence of this separator.",
243
+ )
244
+ parser.add_argument(
245
+ "--column-regex-blacklist",
246
+ type=str,
247
+ nargs="+",
248
+ default=None,
249
+ help="List of regex patterns to match column names. "
250
+ "Columns that match any of the patterns will be removed.",
251
+ )
252
+ parser.add_argument(
253
+ "--column-regex-whitelist",
254
+ type=str,
255
+ nargs="+",
256
+ default=None,
257
+ help="List of regex patterns to match column names. "
258
+ "Only columns that match any of the patterns will be kept.",
259
+ )
260
+ parser.add_argument(
261
+ "--format",
262
+ type=str,
263
+ default="markdown",
264
+ choices=["markdown", "csv", "tsv", "json"],
265
+ help="Format to print the result in. Supported formats are: markdown, csv, json.",
266
+ )
267
+
268
+ kwargs = vars(parser.parse_args())
269
+ main(**kwargs)
src/analysis/get_json_field_as_string.py ADDED
@@ -0,0 +1,55 @@
+import json
+
+
+def main(
+    paths: list[str],
+    field: list[str],
+    format: str = "plain",
+) -> None:
+    result = []
+    for path in paths:
+        with open(path, "r") as f:
+            data = json.load(f)
+        value = data
+        for key in field:
+            value = value.get(key)
+        if not isinstance(value, list):
+            value = [value]
+        result.extend(value)
+    if format == "plain":
+        print(",".join(map(str, result)))
+    elif format == "python":
+        result_str = str(result)
+        print(result_str.replace(" ", ""))
+    else:
+        raise ValueError(f"Unknown format: {format}")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Get a field from one or more JSON files and print to stdout."
+    )
+    parser.add_argument(
+        "paths",
+        type=lambda x: x.split(","),
+        help="Comma-separated list of paths to the JSON files to process.",
+    )
+    parser.add_argument(
+        "--field",
+        type=str,
+        required=True,
+        nargs="+",
+        help="Field to extract from the JSON files. Can be a nested field by providing multiple entries.",
+    )
+    parser.add_argument(
+        "--format",
+        type=str,
+        default="plain",
+        choices=["plain", "python"],
+    )
+
+    args = parser.parse_args()
+    kwargs = vars(args)
+    main(**kwargs)
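For reference, a minimal invocation sketch of this new helper (not part of the diff); the file paths and field names are hypothetical placeholders. The positional argument is a comma-separated list of JSON files, and --field takes one entry per nesting level:

python src/analysis/get_json_field_as_string.py \
    logs/eval_a/job_return_value.json,logs/eval_b/job_return_value.json \
    --field test f1 \
    --format plain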
src/analysis/show_inference_params_on_quality_and_throughput.py ADDED
@@ -0,0 +1,485 @@
1
+ import argparse
2
+ import json
3
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import plotly.express as px
8
+
9
+
10
+ def get_col_name(col: str) -> str:
11
+ parts = [part[1:-1] for part in col[1:-1].split(", ") if part[1:-1] != ""]
12
+ return parts[-1]
13
+
14
+
15
+ def get_idx_entry(s: str, keep_only_last_part: bool = False) -> Tuple[str, str]:
16
+ k, v = s.split("=", 1)
17
+ if keep_only_last_part:
18
+ k = k.split(".")[-1]
19
+ return k, v
20
+
21
+
22
+ def get_idx_dict(job_id: str, keep_only_last_part: bool = False) -> Dict[str, str]:
23
+ return dict(
24
+ get_idx_entry(part, keep_only_last_part=keep_only_last_part) for part in job_id.split("-")
25
+ )
26
+
27
+
28
+ def unflatten_index(
29
+ index: Iterable[str],
30
+ keep_only_last_part: bool = False,
31
+ dtypes: Optional[Dict[str, Any]] = None,
32
+ ) -> pd.MultiIndex:
33
+ as_df = pd.DataFrame.from_records(
34
+ [get_idx_dict(idx, keep_only_last_part=keep_only_last_part) for idx in index]
35
+ )
36
+ if dtypes is not None:
37
+ dtypes_valid = {col: dtype for col, dtype in dtypes.items() if col in as_df.columns}
38
+ as_df = as_df.astype(dtypes_valid)
39
+ return pd.MultiIndex.from_frame(as_df.convert_dtypes())
40
+
41
+
42
+ def col_to_str(col_entries: Iterable[str], names: Iterable[Optional[str]], sep: str) -> str:
43
+ return sep.join(
44
+ [
45
+ f"{name}={col_entry}" if name is not None else col_entry
46
+ for col_entry, name in zip(col_entries, names)
47
+ ]
48
+ )
49
+
50
+
51
+ def flatten_index(index: pd.MultiIndex, names: Optional[List[Optional[str]]] = None) -> pd.Index:
52
+ names = names or index.names
53
+ if names is None:
54
+ raise ValueError("names must be provided if index has no names")
55
+ return pd.Index([col_to_str(col, names=names, sep=",") for col in index])
56
+
57
+
58
+ def prepare_quality_and_throughput_dfs(
59
+ metric_data_path: str,
60
+ job_return_value_path: str,
61
+ char_total: int,
62
+ index_dtypes: Optional[Dict[str, Any]] = None,
63
+ job_id_prefix: Optional[str] = None,
64
+ ) -> Tuple[pd.DataFrame, pd.Series]:
65
+
66
+ with open(metric_data_path) as f:
67
+ data = json.load(f)
68
+
69
+ # save result from above command in "data" (use only last ouf the output line!)
70
+ df = pd.DataFrame.from_dict(data)
71
+ df.columns = [get_col_name(col) for col in df.columns]
72
+ f1_series = df.set_index([col for col in df.columns if col != "f1"])["f1"]
73
+ f1_df = f1_series.apply(lambda x: pd.Series(x)).T
74
+
75
+ with open(job_return_value_path) as f:
76
+ job_return_value = json.load(f)
77
+
78
+ job_ids = job_return_value["job_id"]
79
+ if job_id_prefix is not None:
80
+ job_ids = [
81
+ f"{job_id_prefix},{job_id}" if job_id.strip() != "" else job_id_prefix
82
+ for job_id in job_ids
83
+ ]
84
+ index = unflatten_index(
85
+ job_ids,
86
+ keep_only_last_part=True,
87
+ dtypes=index_dtypes,
88
+ )
89
+ prediction_time_series = pd.Series(
90
+ job_return_value["prediction_time"], index=index, name="prediction_time"
91
+ )
92
+ f1_df.index = prediction_time_series.index
93
+
94
+ k_chars_per_s = char_total / (prediction_time_series * 1000)
95
+ k_chars_per_s.name = "1k_chars_per_s"
96
+
97
+ return f1_df, k_chars_per_s
98
+
99
+
100
+ def get_pareto_front_mask(df: pd.DataFrame, x_col: str, y_col: str) -> pd.Series:
101
+ """
102
+ Return a boolean mask indicating which rows belong to the Pareto front.
103
+ In this version, we assume you want to maximize both x_col and y_col.
104
+
105
+ A point A is said to dominate point B if:
106
+ A[x_col] >= B[x_col] AND
107
+ A[y_col] >= B[y_col] AND
108
+ at least one is strictly greater.
109
+ Then B is not on the Pareto front.
110
+
111
+ Parameters
112
+ ----------
113
+ df : pd.DataFrame
114
+ DataFrame containing the data points.
115
+ x_col : str
116
+ Name of the column to treat as the first objective (maximize).
117
+ y_col : str
118
+ Name of the column to treat as the second objective (maximize).
119
+
120
+ Returns
121
+ -------
122
+ pd.Series
123
+ A boolean Series (aligned with df.index) where True means
124
+ the row is on the Pareto front.
125
+ """
126
+ # Extract the relevant columns as a NumPy array for speed.
127
+ data = df[[x_col, y_col]].values
128
+ n = len(data)
129
+ is_dominated = np.zeros(n, dtype=bool)
130
+
131
+ for i in range(n):
132
+ # If it's already marked dominated, skip checks
133
+ if is_dominated[i]:
134
+ continue
135
+
136
+ for j in range(n):
137
+ if i == j:
138
+ continue
139
+ # Check if j dominates i
140
+ if (
141
+ data[j, 0] >= data[i, 0]
142
+ and data[j, 1] >= data[i, 1]
143
+ and (data[j, 0] > data[i, 0] or data[j, 1] > data[i, 1])
144
+ ):
145
+ is_dominated[i] = True
146
+ break
147
+
148
+ # Return True for points not dominated by any other
149
+ return pd.Series(~is_dominated, index=df.index)
150
+
151
+
152
+ def main(
153
+ job_return_value_path_test: List[str],
154
+ job_return_value_path_val: List[str],
155
+ metric_data_path_test: List[str],
156
+ metric_data_path_val: List[str],
157
+ char_total_test: int,
158
+ char_total_val: int,
159
+ job_id_prefixes: Optional[List[str]] = None,
160
+ metric_filters: Optional[List[str]] = None,
161
+ index_filters: Optional[List[str]] = None,
162
+ index_blacklist: Optional[List[str]] = None,
163
+ label_mapping: Optional[Dict[str, str]] = None,
164
+ plot_method: str = "line", # can be "scatter" or "line"
165
+ pareto_front: bool = False,
166
+ show_as: str = "figure",
167
+ columns: Optional[List[str]] = None,
168
+ color_column: Optional[str] = None,
169
+ ):
170
+ label_mapping = label_mapping or {}
171
+ if job_id_prefixes is not None:
172
+ if len(job_id_prefixes) != len(job_return_value_path_test):
173
+ raise ValueError(
174
+ f"job_id_prefixes ({len(job_id_prefixes)}) and "
175
+ f"job_return_value_path_test ({len(job_return_value_path_test)}) "
176
+ f"must have the same length"
177
+ )
178
+ # replace empty strings with None
179
+ job_id_prefixes_with_none = [
180
+ job_id_prefix if job_id_prefix != "" else None for job_id_prefix in job_id_prefixes
181
+ ]
182
+ else:
183
+ job_id_prefixes_with_none = [None] * len(job_return_value_path_test)
184
+
185
+ # combine input data for test and val
186
+ char_total = {"test": char_total_test, "val": char_total_val}
187
+ metric_data_path = {"test": metric_data_path_test, "val": metric_data_path_val}
188
+ job_return_value_path = {"test": job_return_value_path_test, "val": job_return_value_path_val}
189
+ # prepare dataframes
190
+ common_kwargs = dict(
191
+ index_dtypes={
192
+ "max_argument_distance": int,
193
+ "max_length": int,
194
+ "num_beams": int,
195
+ }
196
+ )
197
+ f1_df_list: Dict[str, List[pd.DataFrame]] = {"test": [], "val": []}
198
+ k_chars_per_s_list: Dict[str, List[pd.Series]] = {"test": [], "val": []}
199
+ for split in metric_data_path:
200
+ if len(metric_data_path[split]) != len(job_return_value_path[split]):
201
+ raise ValueError(
202
+ f"metric_data_path[{split}] ({len(metric_data_path[split])}) and "
203
+ f"job_return_value_path[{split}] ({len(job_return_value_path[split])}) "
204
+ f"must have the same length"
205
+ )
206
+ for current_metric_data_path, current_job_return_value_path, job_id_prefix in zip(
207
+ metric_data_path[split], job_return_value_path[split], job_id_prefixes_with_none
208
+ ):
209
+ current_f1_df, current_k_chars_per_s = prepare_quality_and_throughput_dfs(
210
+ current_metric_data_path,
211
+ current_job_return_value_path,
212
+ char_total=char_total[split],
213
+ job_id_prefix=job_id_prefix,
214
+ **common_kwargs,
215
+ )
216
+ f1_df_list[split].append(current_f1_df)
217
+ k_chars_per_s_list[split].append(current_k_chars_per_s)
218
+ f1_df_dict = {split: pd.concat(f1_df_list[split], axis=0) for split in f1_df_list}
219
+ k_chars_per_s_dict = {
220
+ split: pd.concat(k_chars_per_s_list[split], axis=0) for split in k_chars_per_s_list
221
+ }
222
+
223
+ # combine dataframes for test and val
224
+ f1_df = pd.concat(f1_df_dict, names=["split"] + f1_df_dict["test"].index.names)
225
+ f1_df.columns = [col_to_str(col, names=f1_df.columns.names, sep=",") for col in f1_df.columns]
226
+ k_chars_per_s = pd.concat(
227
+ k_chars_per_s_dict,
228
+ names=["split"] + k_chars_per_s_dict["test"].index.names,
229
+ )
230
+
231
+ # combine quality and throughput data
232
+ df_plot = pd.concat([f1_df, k_chars_per_s], axis=1)
233
+ df_plot = (
234
+ df_plot.reset_index()
235
+ .set_index(list(f1_df.index.names) + [k_chars_per_s.name])
236
+ .unstack("split")
237
+ )
238
+ df_plot.columns = flatten_index(df_plot.columns, names=[None, "split"])
239
+
240
+ # remove all columns that are not needed
241
+ if metric_filters is not None:
242
+ for fil in metric_filters:
243
+ df_plot.drop(columns=[col for col in df_plot.columns if fil not in col], inplace=True)
244
+ df_plot.columns = [col.replace(fil, "") for col in df_plot.columns]
245
+
246
+ # flatten the columns
247
+ df_plot.columns = [
248
+ ",".join([part for part in col.split(",") if part != ""]) for col in df_plot.columns
249
+ ]
250
+
251
+ v: Any
252
+ if index_filters is not None:
253
+ for k_v in index_filters:
254
+ k, v = k_v.split("=")
255
+ if k in common_kwargs["index_dtypes"]:
256
+ v = common_kwargs["index_dtypes"][k](v)
257
+ df_plot = df_plot.xs(v, level=k, axis=0)
258
+
259
+ if index_blacklist is not None:
260
+ for k_v in index_blacklist:
261
+ k, v = k_v.split("=")
262
+ if k in common_kwargs["index_dtypes"]:
263
+ v = common_kwargs["index_dtypes"][k](v)
264
+ df_plot = df_plot.drop(v, level=k, axis=0)
265
+
266
+ if columns is not None:
267
+ df_plot = df_plot[columns]
268
+
269
+ x = "1k_chars_per_s"
270
+ y = df_plot.columns
271
+
272
+ if pareto_front:
273
+ for col in y:
274
+ current_data = df_plot[col].dropna().reset_index(x).copy()
275
+ pareto_front_mask = get_pareto_front_mask(current_data, x_col=x, y_col=col)
276
+ current_data.loc[~pareto_front_mask, col] = np.nan
277
+ current_data_reset = current_data.reset_index().set_index(df_plot.index.names)
278
+ df_plot[col] = current_data_reset[col]
279
+
280
+ # remove nan rows
281
+ df_plot = df_plot.dropna(how="all")
282
+
283
+ # plot
284
+ # Create a custom color sequence (concatenating multiple palettes if needed)
285
+ custom_colors = px.colors.qualitative.Dark24 + px.colors.qualitative.Light24
286
+
287
+ text_cols = list(df_plot.index.names)
288
+ text_cols.remove(x)
289
+ df_plot_reset = df_plot.reset_index()
290
+ if len(text_cols) > 1:
291
+ df_plot_reset[",".join(text_cols)] = (
292
+ df_plot_reset[text_cols].astype(str).agg(", ".join, axis=1)
293
+ )
294
+ text_col = ",".join(text_cols)
295
+
296
+ if show_as == "figure":
297
+ _plot_method = getattr(px, plot_method)
298
+ df_plot_sorted = df_plot_reset.sort_values(by=x)
299
+ fig = _plot_method(
300
+ df_plot_sorted,
301
+ x=x,
302
+ y=y,
303
+ text=text_col if plot_method != "scatter" else None,
304
+ color=color_column,
305
+ color_discrete_sequence=custom_colors,
306
+ hover_data=text_cols,
307
+ )
308
+
309
+ # set connectgaps to True to connect the lines
310
+ fig.update_traces(connectgaps=True)
311
+
312
+ legend_title = "Evaluation Setup"
313
+ if metric_filters:
314
+ whitelist_filters_mapped = [label_mapping.get(fil, fil) for fil in metric_filters]
315
+ legend_title += f" ({', '.join(whitelist_filters_mapped)})"
316
+
317
+ text_cols_mapped = [label_mapping.get(col, col) for col in text_cols]
318
+ title = f"Impact of {', '.join(text_cols_mapped)} on Prediction Quality and Throughput"
319
+ if index_filters:
320
+ index_filters_mapped = [label_mapping.get(fil, fil) for fil in index_filters]
321
+ title += f" ({', '.join(index_filters_mapped)})"
322
+ if pareto_front:
323
+ title += " (Pareto Front)"
324
+
325
+ fig.update_layout(
326
+ xaxis_title="Throughput (1k chars/s)",
327
+ yaxis_title="Quality (F1)",
328
+ title=title,
329
+ # center the title
330
+ title_x=0.2,
331
+ # black title
332
+ title_font=dict(color="black"),
333
+ # change legend title
334
+ legend_title=legend_title,
335
+ font_family="Computer Modern",
336
+ # white background
337
+ plot_bgcolor="white",
338
+ paper_bgcolor="white",
339
+ )
340
+ update_axes_kwargs = dict(
341
+ tickfont=dict(color="black"),
342
+ title_font=dict(color="black"),
343
+ ticks="inside", # ensure tick markers are drawn
344
+ tickcolor="black",
345
+ tickwidth=1,
346
+ ticklen=10,
347
+ linecolor="black",
348
+ # show grid
349
+ gridcolor="lightgray",
350
+ )
351
+ fig.update_yaxes(**update_axes_kwargs)
352
+ fig.update_xaxes(**update_axes_kwargs)
353
+
354
+ fig.show()
355
+ elif show_as == "markdown":
356
+ # Print the DataFrame as a Markdown table
357
+ print(df_plot_reset.to_markdown(index=False, floatfmt=".4f"))
358
+ elif show_as == "json":
359
+ # Print the DataFrame as a JSON object
360
+ print(df_plot_reset.to_json(orient="columns", indent=4))
361
+ else:
362
+ raise ValueError(f"Unknown show_as value: {show_as}. Use 'figure', 'markdown' or 'json'.")
363
+
364
+
365
+ if __name__ == "__main__":
366
+
367
+ """
368
+ # Example usage 1 (pipeline model, data from data source: https://github.com/ArneBinder/pie-document-level/issues/388#issuecomment-2752829257):
369
+ python src/analysis/show_inference_params_on_quality_and_throughput.py \
370
+ --job-return-value-path-test logs/prediction/multiruns/default/2025-03-26_01-31-05/job_return_value.json \
371
+ --job-return-value-path-val logs/prediction/multiruns/default/2025-03-26_16-49-36/job_return_value.json \
372
+ --metric-data-path-test data/evaluation/argumentation_structure/inference_pipeline_test.json \
373
+ --metric-data-path-val data/evaluation/argumentation_structure/inference_pipeline_validation.json \
374
+ --metric-filters task=are discont_comp=true split=val
375
+
376
+ # Example usage 2 (joint model, data from: https://github.com/ArneBinder/pie-document-level/issues/390#issuecomment-2759888004)
377
+ python src/analysis/show_inference_params_on_quality_and_throughput.py \
378
+ --job-return-value-path-test logs/prediction/multiruns/default/2025-03-28_01-34-07/job_return_value.json \
379
+ --job-return-value-path-val logs/prediction/multiruns/default/2025-03-28_02-57-00/job_return_value.json \
380
+ --metric-data-path-test data/evaluation/argumentation_structure/inference_joint_test.json \
381
+ --metric-data-path-val data/evaluation/argumentation_structure/inference_joint_validation.json \
382
+ --metric-filters task=are discont_comp=true split=val \
383
+ --plot-method scatter
384
+ """
385
+
386
+ parser = argparse.ArgumentParser()
387
+ parser.add_argument(
388
+ "--job-return-value-path-test",
389
+ type=str,
390
+ nargs="+",
391
+ required=True,
392
+ )
393
+ parser.add_argument(
394
+ "--job-return-value-path-val",
395
+ type=str,
396
+ nargs="+",
397
+ required=True,
398
+ )
399
+ parser.add_argument(
400
+ "--metric-data-path-test",
401
+ type=str,
402
+ nargs="+",
403
+ required=True,
404
+ )
405
+ parser.add_argument(
406
+ "--metric-data-path-val",
407
+ type=str,
408
+ nargs="+",
409
+ required=True,
410
+ )
411
+ parser.add_argument(
412
+ "--job-id-prefixes",
413
+ type=str,
414
+ nargs="*",
415
+ default=None,
416
+ )
417
+ parser.add_argument(
418
+ "--plot-method",
419
+ type=str,
420
+ default="line",
421
+ choices=["scatter", "line"],
422
+ help="Plot method to use (default: line)",
423
+ )
424
+ parser.add_argument(
425
+ "--color-column",
426
+ type=str,
427
+ default=None,
428
+ help="Column to use for colour coding (default: None)",
429
+ )
430
+ parser.add_argument(
431
+ "--metric-filters",
432
+ type=str,
433
+ nargs="*",
434
+ default=None,
435
+ help="Filters to apply to the metric data in the format 'key=value'",
436
+ )
437
+ parser.add_argument(
438
+ "--index-filters",
439
+ type=str,
440
+ nargs="*",
441
+ default=None,
442
+ help="Filters to apply to the index data in the format 'key=value'",
443
+ )
444
+ parser.add_argument(
445
+ "--index-blacklist",
446
+ type=str,
447
+ nargs="*",
448
+ default=None,
449
+ help="Blacklist to apply to the index data in the format 'key=value'",
450
+ )
451
+ parser.add_argument(
452
+ "--columns",
453
+ type=str,
454
+ nargs="*",
455
+ default=None,
456
+ help="Columns to plot (default: all)",
457
+ )
458
+ parser.add_argument(
459
+ "--pareto-front",
460
+ action="store_true",
461
+ help="Whether to show only the pareto front",
462
+ )
463
+ parser.add_argument(
464
+ "--show-as",
465
+ type=str,
466
+ default="figure",
467
+ choices=["figure", "markdown", "json"],
468
+ help="How to show the results (default: figure)",
469
+ )
470
+
471
+ kwargs = vars(parser.parse_args())
472
+
473
+ main(
474
+ char_total_test=383154,
475
+ char_total_val=182794,
476
+ label_mapping={
477
+ "max_argument_distance": "Max. Argument Distance",
478
+ "max_length": "Max. Length",
479
+ "num_beams": "Num. Beams",
480
+ "task=are": "ARE",
481
+ "discont_comp=true": "Discont. Comp.",
482
+ "split=val": "Validation Split",
483
+ },
484
+ **kwargs,
485
+ )
src/datamodules/__init__.py CHANGED
@@ -1 +1 @@
-from .datamodule import PieDataModule
+from .datamodule_with_sampler import PieDataModuleWithSampler
src/datamodules/datamodule_with_sampler.py ADDED
@@ -0,0 +1,59 @@
+import logging
+from typing import Optional, Union
+
+from pytorch_ie import PieDataModule
+from pytorch_ie.core.taskmodule import IterableTaskEncodingDataset, TaskEncodingDataset
+from torch.utils.data import DataLoader, Sampler
+
+from .components.sampler import ImbalancedDatasetSampler
+
+logger = logging.getLogger(__name__)
+
+
+class PieDataModuleWithSampler(PieDataModule):
+
+    def __init__(
+        self,
+        train_sampler: Optional[str] = None,
+        dont_shuffle_train: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.train_sampler_name = train_sampler
+        self.dont_shuffle_train = dont_shuffle_train
+
+    def get_train_sampler(
+        self,
+        dataset: Union[TaskEncodingDataset, IterableTaskEncodingDataset],
+    ) -> Optional[Sampler]:
+        if self.train_sampler_name is None:
+            return None
+        elif self.train_sampler_name == "imbalanced_dataset":
+            # for now, this works only with targets that have a single entry
+            return ImbalancedDatasetSampler(
+                dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
+            )
+        else:
+            raise ValueError(f"unknown sampler name: {self.train_sampler_name}")
+
+    def train_dataloader(self) -> DataLoader:
+        ds = self.data_split(self.train_split)
+        sampler = self.get_train_sampler(dataset=ds)
+        # don't shuffle if we explicitly set dont_shuffle_train,
+        # if the dataset is streamed, or if we use a sampler
+        shuffle = not (
+            self.dont_shuffle_train
+            or isinstance(ds, IterableTaskEncodingDataset)
+            or sampler is not None
+        )
+
+        if not shuffle:
+            logger.warning("not shuffling train dataloader")
+        return DataLoader(
+            dataset=ds,
+            sampler=sampler,
+            collate_fn=self.taskmodule.collate,
+            shuffle=shuffle,
+            **self.dataloader_kwargs,
+        )
src/dataset/processing.py CHANGED
@@ -1,9 +1,16 @@
-from typing import Callable, Type, Union
 
 from pie_datasets import Dataset, DatasetDict
 from pytorch_ie import Document
 from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target
 
 
 # TODO: simply use use DatasetDict.map() with set_batch_size_to_split_size=True and
 # batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
@@ -11,7 +18,7 @@ def apply_func_to_splits(
     dataset: DatasetDict,
     function: Union[str, Callable],
     result_document_type: Type[Document],
-    **kwargs
 ):
     resolved_func = resolve_target(function)
     resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
@@ -23,7 +30,85 @@ def apply_func_to_splits(
            batched=True,
            batch_size=len(split),
            result_document_type=resolved_document_type,
-           **kwargs
        )
        result_dict[split_name] = converted_dataset
    return DatasetDict(result_dict)
1
+ import logging
2
+ from collections import defaultdict
3
+ from typing import Callable, Dict, List, Optional, Type, TypeVar, Union
4
 
5
  from pie_datasets import Dataset, DatasetDict
6
+ from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
7
  from pytorch_ie import Document
8
+ from pytorch_ie.annotations import BinaryRelation, Span
9
+ from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
10
  from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target
11
 
12
+ logger = logging.getLogger(__name__)
13
+
14
 
15
  # TODO: simply use use DatasetDict.map() with set_batch_size_to_split_size=True and
16
  # batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
 
18
  dataset: DatasetDict,
19
  function: Union[str, Callable],
20
  result_document_type: Type[Document],
21
+ **kwargs,
22
  ):
23
  resolved_func = resolve_target(function)
24
  resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
 
30
  batched=True,
31
  batch_size=len(split),
32
  result_document_type=resolved_document_type,
33
+ **kwargs,
34
  )
35
  result_dict[split_name] = converted_dataset
36
  return DatasetDict(result_dict)
37
+
38
+
39
+ S = TypeVar("S", bound=Span)
40
+
41
+
42
+ def shift_span(span: S, offset: int) -> S:
43
+ """Shift the start and end of a span by a given offset."""
44
+ return span.copy(start=span.start + offset, end=span.end + offset)
45
+
46
+
47
+ D = TypeVar("D", bound=TextDocumentWithLabeledSpansAndBinaryRelations)
48
+
49
+
50
+ def add_predicted_semantically_same_relations_to_document(
51
+ document: D,
52
+ doc_id2docs_with_predictions: Dict[
53
+ str, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
54
+ ],
55
+ relation_label: str,
56
+ argument_label_blacklist: Optional[List[str]] = None,
57
+ verbose: bool = False,
58
+ ) -> D:
59
+
60
+ # create lookup for detached versions of the spans (attached span != detached span even if they are the same)
61
+ span2span = {span.copy(): span for span in document.labeled_spans}
62
+ for text_pair_doc_with_preds in doc_id2docs_with_predictions.get(document.id, []):
63
+ offset = text_pair_doc_with_preds.metadata["original_doc_span"]["start"]
64
+ offset_pair = text_pair_doc_with_preds.metadata["original_doc_span_pair"]["start"]
65
+ for coref_rel in text_pair_doc_with_preds.binary_coref_relations.predictions:
66
+ head = shift_span(coref_rel.head, offset=offset)
67
+ if head not in span2span:
68
+ if verbose:
69
+ logger.warning(f"doc_id={document.id}: Head span {head} not found.")
70
+ continue
71
+ tail = shift_span(coref_rel.tail, offset=offset_pair)
72
+ if tail not in span2span:
73
+ if verbose:
74
+ logger.warning(f"doc_id={document.id}: Tail span {tail} not found.")
75
+ continue
76
+ if argument_label_blacklist is not None and (
77
+ span2span[head].label in argument_label_blacklist
78
+ or span2span[tail].label in argument_label_blacklist
79
+ ):
80
+ continue
81
+ new_rel = BinaryRelation(
82
+ head=span2span[head],
83
+ tail=span2span[tail],
84
+ label=relation_label,
85
+ score=coref_rel.score,
86
+ )
87
+ document.binary_relations.predictions.append(new_rel)
88
+ return document
89
+
90
+
91
+ def integrate_coref_predictions_from_text_pair_documents(
92
+ dataset: DatasetDict, data_dir: str, **kwargs
93
+ ) -> DatasetDict:
94
+
95
+ dataset_with_predictions = DatasetDict.from_json(data_dir=data_dir)
96
+
97
+ for split_name in dataset.keys():
98
+ ds_with_predictions = dataset_with_predictions[split_name]
99
+ original_doc_id2docs = defaultdict(list)
100
+ for doc in ds_with_predictions:
101
+ original_doc_id = doc.metadata["original_doc_id"]
102
+ if original_doc_id != doc.metadata["original_doc_id_pair"]:
103
+ raise ValueError(
104
+ f"Original document IDs do not match: "
105
+ f"{original_doc_id} != {doc.metadata['original_doc_id_pair']}. "
106
+ f"Cross-document coref is not supported."
107
+ )
108
+ original_doc_id2docs[original_doc_id].append(doc)
109
+
110
+ dataset[split_name] = dataset[split_name].map(
111
+ function=add_predicted_semantically_same_relations_to_document,
112
+ fn_kwargs=dict(doc_id2docs_with_predictions=original_doc_id2docs, **kwargs),
113
+ )
114
+ return dataset
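A hedged usage sketch of the helper above (dataset paths, the relation label, and the blacklist are placeholders, not taken from this commit):

    from pie_datasets import DatasetDict
    from src.dataset.processing import integrate_coref_predictions_from_text_pair_documents

    # load the original dataset and merge in coref predictions that were produced
    # on text-pair documents, as new "semantically_same" relations
    dataset = DatasetDict.from_json(data_dir="data/datasets/my_dataset")  # placeholder path
    dataset = integrate_coref_predictions_from_text_pair_documents(
        dataset=dataset,
        data_dir="data/predictions/coref_text_pairs",  # placeholder path to the predicted pair documents
        relation_label="semantically_same",            # assumed relation label
        argument_label_blacklist=None,
        verbose=True,
    )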
src/demo/annotation_utils.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  import logging
3
- from typing import Iterable, Optional, Sequence, Union
4
 
5
  import gradio as gr
6
  from hydra.utils import instantiate
@@ -41,59 +41,6 @@ def get_merger() -> SpansViaRelationMerger:
41
  )
42
 
43
 
44
- def annotate_document(
45
- document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
46
- argumentation_model: Pipeline,
47
- handle_parts_of_same: bool = False,
48
- ) -> Union[
49
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
50
- TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
51
- ]:
52
- """Annotate a document with the provided pipeline.
53
-
54
- Args:
55
- document: The document to annotate.
56
- argumentation_model: The pipeline to use for annotation.
57
- handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
58
- """
59
-
60
- # execute prediction pipeline
61
- result: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions = argumentation_model(
62
- document, inplace=True
63
- )
64
-
65
- if handle_parts_of_same:
66
- merger = get_merger()
67
- result = merger(result)
68
-
69
- return result
70
-
71
-
72
- def annotate_documents(
73
- documents: Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
74
- argumentation_model: Pipeline,
75
- handle_parts_of_same: bool = False,
76
- ) -> Union[
77
- Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
78
- Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
79
- ]:
80
- """Annotate a sequence of documents with the provided pipeline.
81
-
82
- Args:
83
- documents: The documents to annotate.
84
- argumentation_model: The pipeline to use for annotation.
85
- handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
86
- """
87
- # execute prediction pipeline
88
- result = argumentation_model(documents, inplace=True)
89
-
90
- if handle_parts_of_same:
91
- merger = get_merger()
92
- result = [merger(document) for document in result]
93
-
94
- return result
95
-
96
-
97
  def create_document(
98
  text: str, doc_id: str, split_regex: Optional[str] = None
99
  ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
@@ -143,14 +90,17 @@ def create_documents(
143
  ]
144
 
145
 
146
- def load_argumentation_model(config_str: str, **kwargs) -> Pipeline:
147
  try:
148
  config = parse_config(config_str, format="yaml")
 
 
 
149
 
150
  # for PIE AutoPipeline, we need to handle the revision separately for
151
  # the taskmodule and the model
152
  if (
153
- config.get("_target_") == "pytorch_ie.auto.AutoPipeline.from_pretrained"
154
  and "revision" in config
155
  ):
156
  revision = config.pop("revision")
 
1
  import json
2
  import logging
3
+ from typing import Iterable, Optional, Sequence
4
 
5
  import gradio as gr
6
  from hydra.utils import instantiate
 
41
  )
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def create_document(
45
  text: str, doc_id: str, split_regex: Optional[str] = None
46
  ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
 
90
  ]
91
 
92
 
93
+ def load_argumentation_model(config_str: str, **kwargs) -> Optional[Pipeline]:
94
  try:
95
  config = parse_config(config_str, format="yaml")
96
+ if config is None or config == {}:
97
+ gr.Warning("Empty argumentation model config provided. No model loaded.")
98
+ return None
99
 
100
  # for PIE AutoPipeline, we need to handle the revision separately for
101
  # the taskmodule and the model
102
  if (
103
+ config.get("_target_", "").strip().endswith("AutoPipeline.from_pretrained")
104
  and "revision" in config
105
  ):
106
  revision = config.pop("revision")
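A minimal sketch of calling the revised loader (the model identifier and the parameter name are placeholders; only the AutoPipeline target and revision handling are taken from the code above):

    from src.demo.annotation_utils import load_argumentation_model

    config_str = """
    _target_: pytorch_ie.auto.AutoPipeline.from_pretrained
    pretrained_model_name_or_path: some-org/some-argumentation-model  # placeholder model id, assumed parameter name
    revision: main
    """
    pipeline = load_argumentation_model(config_str)  # returns None for an empty config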
src/demo/backend_utils.py CHANGED
@@ -12,10 +12,11 @@ from pie_datasets import Dataset, IterableDataset, load_dataset
12
  from pytorch_ie import Pipeline
13
  from pytorch_ie.documents import (
14
  TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
 
15
  )
16
  from tqdm import tqdm
17
 
18
- from src.demo.annotation_utils import annotate_documents, create_documents
19
  from src.demo.data_utils import load_text_from_arxiv
20
  from src.demo.rendering_utils import (
21
  RENDER_WITH_DISPLACY,
@@ -54,7 +55,7 @@ def add_annotated_pie_documents(
54
  def process_texts(
55
  texts: Iterable[str],
56
  doc_ids: Iterable[str],
57
- argumentation_model: Pipeline,
58
  retriever: DocumentAwareSpanRetriever,
59
  split_regex_escaped: Optional[str],
60
  handle_parts_of_same: bool = False,
@@ -68,13 +69,21 @@ def process_texts(
68
  doc_ids=doc_ids,
69
  split_regex=split_regex_escaped,
70
  )
71
- if verbose:
72
- gr.Info(f"Annotate {len(pie_documents)} documents...")
73
- pie_documents = annotate_documents(
74
- documents=pie_documents,
75
- argumentation_model=argumentation_model,
76
- handle_parts_of_same=handle_parts_of_same,
77
- )
 
 
 
 
 
 
 
 
78
  add_annotated_pie_documents(
79
  retriever=retriever,
80
  pie_documents=pie_documents,
@@ -93,12 +102,41 @@ def add_annotated_pie_documents_from_dataset(
93
  dataset = load_dataset(**load_dataset_kwargs)
94
  if not isinstance(dataset, (Dataset, IterableDataset)):
95
  raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
96
- dataset_converted = dataset.to_document_type(
97
- TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  )
 
99
  add_annotated_pie_documents(
100
  retriever=retriever,
101
- pie_documents=dataset_converted,
102
  use_predicted_annotations=False,
103
  verbose=verbose,
104
  )
 
12
  from pytorch_ie import Pipeline
13
  from pytorch_ie.documents import (
14
  TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
15
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
16
  )
17
  from tqdm import tqdm
18
 
19
+ from src.demo.annotation_utils import create_documents, get_merger
20
  from src.demo.data_utils import load_text_from_arxiv
21
  from src.demo.rendering_utils import (
22
  RENDER_WITH_DISPLACY,
 
55
  def process_texts(
56
  texts: Iterable[str],
57
  doc_ids: Iterable[str],
58
+ argumentation_model: Optional[Pipeline],
59
  retriever: DocumentAwareSpanRetriever,
60
  split_regex_escaped: Optional[str],
61
  handle_parts_of_same: bool = False,
 
69
  doc_ids=doc_ids,
70
  split_regex=split_regex_escaped,
71
  )
72
+ if argumentation_model is not None:
73
+ if verbose:
74
+ gr.Info(f"Annotate {len(pie_documents)} documents...")
75
+ pie_documents = argumentation_model(pie_documents, inplace=True)
76
+ else:
77
+ gr.Warning(
78
+ "Annotation is disabled (no model was loaded). No annotations will be added to the documents."
79
+ )
80
+
81
+ # this needs to be done also if the documents are not annotated because
82
+ # it adjusts the document type
83
+ if handle_parts_of_same:
84
+ merger = get_merger()
85
+ pie_documents = [merger(document) for document in pie_documents]
86
+
87
  add_annotated_pie_documents(
88
  retriever=retriever,
89
  pie_documents=pie_documents,
 
102
  dataset = load_dataset(**load_dataset_kwargs)
103
  if not isinstance(dataset, (Dataset, IterableDataset)):
104
  raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
105
+ try:
106
+ dataset_converted = dataset.to_document_type(
107
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
108
+ )
109
+ except ValueError:
110
+ gr.Warning(
111
+ "The dataset does not seem to have registered converter to create multi-spans. "
112
+ "Try to Load as single-spans and to convert to multi-spans manually ..."
113
+ )
114
+ dataset_converted_single_span = dataset.to_document_type(
115
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
116
+ )
117
+ merger = get_merger()
118
+ dataset_converted = dataset_converted_single_span.map(
119
+ merger,
120
+ result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
121
+ )
122
+
123
+ def _clear_metadata(
124
+ doc: TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
125
+ ) -> TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions:
126
+ result = doc.copy()
127
+ result.metadata = dict()
128
+ return result
129
+
130
+ # adding documents with differing metadata formats to the retriever breaks it,
131
+ # so we clear the metadata field beforehand
132
+ dataset_converted_without_metadata = dataset_converted.map(
133
+ _clear_metadata,
134
+ result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
135
  )
136
+
137
  add_annotated_pie_documents(
138
  retriever=retriever,
139
+ pie_documents=dataset_converted_without_metadata,
140
  use_predicted_annotations=False,
141
  verbose=verbose,
142
  )
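A hedged call sketch for the now-optional argumentation model (the retriever is assumed to be an already configured DocumentAwareSpanRetriever; remaining parameters keep their defaults):

    from src.demo.backend_utils import process_texts

    process_texts(
        texts=["Some document text ..."],
        doc_ids=["doc_0"],
        argumentation_model=None,   # allowed now: emits a warning and skips annotation
        retriever=retriever,
        split_regex_escaped=None,
        handle_parts_of_same=True,  # still merged, so the document type matches the retriever
    )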
src/demo/retrieve_and_dump_all_relevant.py CHANGED
@@ -10,21 +10,24 @@ root = pyrootutils.setup_root(
10
  import argparse
11
  import logging
12
  import os
13
- from typing import Dict, List, Optional, Tuple
14
 
15
  import pandas as pd
16
  from pie_datasets import Dataset, DatasetDict
17
  from pytorch_ie import Annotation
18
  from pytorch_ie.annotations import BinaryRelation, MultiSpan, Span
19
 
20
- from document.types import (
21
- RelatedRelation,
22
- TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
23
- )
24
  from src.demo.retriever_utils import (
25
  retrieve_all_relevant_spans,
26
  retrieve_all_relevant_spans_for_all_documents,
 
 
27
  retrieve_relevant_spans,
 
 
 
 
 
28
  )
29
  from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
30
 
@@ -131,14 +134,17 @@ def add_result_to_gold_data(
131
  base_annotation_mapping=base_annotation_mapping,
132
  )
133
  )
134
- doc_and_span_id2annotation.update(
135
- get_doc_and_span_id2annotation_mapping(
136
- span_ids=result["ref_span_id"],
137
- doc_ids=result["doc_id"],
138
- retriever=retriever,
139
- base_annotation_mapping=base_annotation_mapping,
 
 
 
 
140
  )
141
- )
142
  doc_and_span_id2annotation.update(
143
  get_doc_and_span_id2annotation_mapping(
144
  span_ids=result["query_span_id"],
@@ -159,38 +165,51 @@ def add_result_to_gold_data(
159
  (row.query_doc_id, row.query_span_id)
160
  ]
161
  doc_id, span = doc_and_span_id2annotation[(row.doc_id, row.span_id)]
162
- doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
163
  if doc_id != query_doc_id:
164
  raise ValueError("doc_id and query_doc_id must be the same")
165
- if doc_id != doc_id2:
166
- raise ValueError("doc_id and ref_doc_id must be the same")
167
  doc = doc_id2doc[doc_id]
168
- link_rel = BinaryRelation(
169
- head=query_span, tail=ref_span, label=link_relation_label, score=row.sim_score
170
- )
171
- doc.binary_relations.predictions.append(link_rel)
172
- head_and_tail2relation = doc_id2head_tail2relation[doc_id]
173
- related_rel_label = row.type
174
- if related_rel_label.endswith(reversed_relation_suffix):
175
- base_rel = head_and_tail2relation[(span, ref_span)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  else:
177
- base_rel = head_and_tail2relation[(ref_span, span)]
178
- related_rel = RelatedRelation(
179
- head=query_span,
180
- tail=span,
181
- link_relation=link_rel,
182
- relation=base_rel,
183
- label=related_rel_label,
184
- score=link_rel.score * base_rel.score,
185
- )
186
- doc.related_relations.predictions.append(related_rel)
187
 
188
  dataset = Dataset.from_documents(list(doc_id2doc.values()))
189
  dataset_dict = DatasetDict({split: dataset})
190
  if not os.path.exists(dataset_out_dir):
191
  os.makedirs(dataset_out_dir, exist_ok=True)
192
 
193
- dataset_dict.to_json(dataset_out_dir)
194
 
195
 
196
  if __name__ == "__main__":
@@ -216,6 +235,13 @@ if __name__ == "__main__":
216
  type=str,
217
  required=True,
218
  )
 
 
 
 
 
 
 
219
  parser.add_argument(
220
  "--query_doc_id",
221
  type=str,
@@ -282,6 +308,24 @@ if __name__ == "__main__":
282
  logger.info(f"loading data from {args.data_path}...")
283
  retriever.load_from_disc(args.data_path)
284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
286
  if args.doc_id_whitelist is not None:
287
  search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
@@ -293,7 +337,7 @@ if __name__ == "__main__":
293
  all_spans_for_all_documents = None
294
  for doc_id_pair in args.query_target_doc_id_pairs:
295
  query_doc_id, target_doc_id = doc_id_pair.split(":")
296
- current_result = retrieve_all_relevant_spans(
297
  retriever=retriever,
298
  query_doc_id=query_doc_id,
299
  doc_id_whitelist=[target_doc_id],
@@ -319,16 +363,16 @@ if __name__ == "__main__":
319
 
320
  elif args.query_span_id is not None:
321
  logger.warning(f"retrieving results for single span: {args.query_span_id}")
322
- all_spans_for_all_documents = retrieve_relevant_spans(
323
  retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
324
  )
325
  elif args.query_doc_id is not None:
326
  logger.warning(f"retrieving results for single document: {args.query_doc_id}")
327
- all_spans_for_all_documents = retrieve_all_relevant_spans(
328
  retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
329
  )
330
  else:
331
- all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
332
  retriever=retriever, **search_kwargs
333
  )
334
 
 
10
  import argparse
11
  import logging
12
  import os
13
+ from typing import Callable, Dict, List, Optional, Tuple
14
 
15
  import pandas as pd
16
  from pie_datasets import Dataset, DatasetDict
17
  from pytorch_ie import Annotation
18
  from pytorch_ie.annotations import BinaryRelation, MultiSpan, Span
19
 
 
 
 
 
20
  from src.demo.retriever_utils import (
21
  retrieve_all_relevant_spans,
22
  retrieve_all_relevant_spans_for_all_documents,
23
+ retrieve_all_similar_spans,
24
+ retrieve_all_similar_spans_for_all_documents,
25
  retrieve_relevant_spans,
26
+ retrieve_similar_spans,
27
+ )
28
+ from src.document.types import (
29
+ RelatedRelation,
30
+ TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
31
  )
32
  from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
33
 
 
134
  base_annotation_mapping=base_annotation_mapping,
135
  )
136
  )
137
+ # a ref_span_id is only present when we process relevant span retriever results
138
+ # (for similar span retriever results, we only have query_span_id)
139
+ if "ref_span_id" in result.columns:
140
+ doc_and_span_id2annotation.update(
141
+ get_doc_and_span_id2annotation_mapping(
142
+ span_ids=result["ref_span_id"],
143
+ doc_ids=result["doc_id"],
144
+ retriever=retriever,
145
+ base_annotation_mapping=base_annotation_mapping,
146
+ )
147
  )
 
148
  doc_and_span_id2annotation.update(
149
  get_doc_and_span_id2annotation_mapping(
150
  span_ids=result["query_span_id"],
 
165
  (row.query_doc_id, row.query_span_id)
166
  ]
167
  doc_id, span = doc_and_span_id2annotation[(row.doc_id, row.span_id)]
 
168
  if doc_id != query_doc_id:
169
  raise ValueError("doc_id and query_doc_id must be the same")
 
 
170
  doc = doc_id2doc[doc_id]
171
+
172
+ # if we have a reference span, we need to construct the related relation
173
+ if hasattr(row, "ref_span_id"):
174
+ doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
175
+ if doc_id != doc_id2:
176
+ raise ValueError("doc_id and ref_doc_id must be the same")
177
+
178
+ # create a link relation between the query span and the reference span
179
+ link_rel = BinaryRelation(
180
+ head=query_span, tail=ref_span, label=link_relation_label, score=row.sim_score
181
+ )
182
+ doc.binary_relations.predictions.append(link_rel)
183
+
184
+ head_and_tail2relation = doc_id2head_tail2relation[doc_id]
185
+ related_rel_label = row.type
186
+ if related_rel_label.endswith(reversed_relation_suffix):
187
+ base_rel = head_and_tail2relation[(span, ref_span)]
188
+ else:
189
+ base_rel = head_and_tail2relation[(ref_span, span)]
190
+ related_rel = RelatedRelation(
191
+ head=query_span,
192
+ tail=span,
193
+ link_relation=link_rel,
194
+ relation=base_rel,
195
+ label=related_rel_label,
196
+ score=link_rel.score * base_rel.score,
197
+ )
198
+ doc.related_relations.predictions.append(related_rel)
199
+ # otherwise, we just ...
200
  else:
201
+ # ... create a link relation between the query span and the returned span
202
+ link_rel = BinaryRelation(
203
+ head=query_span, tail=span, label=link_relation_label, score=row.sim_score
204
+ )
205
+ doc.binary_relations.predictions.append(link_rel)
 
 
 
 
 
206
 
207
  dataset = Dataset.from_documents(list(doc_id2doc.values()))
208
  dataset_dict = DatasetDict({split: dataset})
209
  if not os.path.exists(dataset_out_dir):
210
  os.makedirs(dataset_out_dir, exist_ok=True)
211
 
212
+ dataset_dict.to_json(dataset_out_dir, mode="w")
213
 
214
 
215
  if __name__ == "__main__":
 
235
  type=str,
236
  required=True,
237
  )
238
+ parser.add_argument(
239
+ "-v",
240
+ "--variant",
241
+ choices=["relevant", "similar"],
242
+ default="relevant",
243
+ help="Variant of the retriever to use: 'relevant' for relevant spans, 'similar' for similar spans.",
244
+ )
245
  parser.add_argument(
246
  "--query_doc_id",
247
  type=str,
 
308
  logger.info(f"loading data from {args.data_path}...")
309
  retriever.load_from_disc(args.data_path)
310
 
311
+ methods: Dict[str, Callable]
312
+ if args.variant == "relevant":
313
+ logger.info("using *relevant* span retriever methods")
314
+ methods = {
315
+ "retrieve_all_spans": retrieve_all_relevant_spans,
316
+ "retrieve_spans": retrieve_relevant_spans,
317
+ "retrieve_all_spans_for_all_documents": retrieve_all_relevant_spans_for_all_documents,
318
+ }
319
+ elif args.variant == "similar":
320
+ logger.info("using *similar* span retriever methods")
321
+ methods = {
322
+ "retrieve_all_spans": retrieve_all_similar_spans,
323
+ "retrieve_spans": retrieve_similar_spans,
324
+ "retrieve_all_spans_for_all_documents": retrieve_all_similar_spans_for_all_documents,
325
+ }
326
+ else:
327
+ raise ValueError(f"unknown method: {args.variant}")
328
+
329
  search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
330
  if args.doc_id_whitelist is not None:
331
  search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
 
337
  all_spans_for_all_documents = None
338
  for doc_id_pair in args.query_target_doc_id_pairs:
339
  query_doc_id, target_doc_id = doc_id_pair.split(":")
340
+ current_result = methods["retrieve_all_spans"](
341
  retriever=retriever,
342
  query_doc_id=query_doc_id,
343
  doc_id_whitelist=[target_doc_id],
 
363
 
364
  elif args.query_span_id is not None:
365
  logger.warning(f"retrieving results for single span: {args.query_span_id}")
366
+ all_spans_for_all_documents = methods["retrieve_spans"](
367
  retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
368
  )
369
  elif args.query_doc_id is not None:
370
  logger.warning(f"retrieving results for single document: {args.query_doc_id}")
371
+ all_spans_for_all_documents = methods["retrieve_all_spans"](
372
  retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
373
  )
374
  else:
375
+ all_spans_for_all_documents = methods["retrieve_all_spans_for_all_documents"](
376
  retriever=retriever, **search_kwargs
377
  )
378
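A hedged invocation sketch of the extended script (paths and ids are placeholders; flag spellings other than --variant and --query_doc_id are assumed from the argparse destinations, and further required arguments are omitted):

    python src/demo/retrieve_and_dump_all_relevant.py \
        --data_path outputs/retriever_dump \
        --variant similar \
        --query_doc_id doc_0 \
        --top_k 10 \
        --threshold 0.95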
 
src/document/processing.py CHANGED
@@ -1,20 +1,23 @@
1
  from __future__ import annotations
2
 
 
3
  import logging
4
  from collections import defaultdict
5
  from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
6
 
7
  from pie_modules.utils.span import have_overlap
8
  from pytorch_ie import AnnotationLayer
9
- from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
10
  from pytorch_ie.core import Document
11
  from pytorch_ie.core.document import Annotation, _enumerate_dependencies
 
12
 
13
  from src.document.types import (
14
  RelatedRelation,
15
  TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
16
  )
17
  from src.utils import distance, distance_slices
 
18
  from src.utils.span_utils import get_overlap_len
19
 
20
  logger = logging.getLogger(__name__)
@@ -123,6 +126,69 @@ def remove_partitions_by_labels(
123
  D_text = TypeVar("D_text", bound=Document)
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def replace_substrings_in_text(
127
  document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True
128
  ) -> D_text:
@@ -512,3 +578,236 @@ def add_related_relations_from_binary_relations(
512
  )
513
 
514
  return document
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ import itertools
4
  import logging
5
  from collections import defaultdict
6
  from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
7
 
8
  from pie_modules.utils.span import have_overlap
9
  from pytorch_ie import AnnotationLayer
10
+ from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan, MultiSpan, Span
11
  from pytorch_ie.core import Document
12
  from pytorch_ie.core.document import Annotation, _enumerate_dependencies
13
+ from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
14
 
15
  from src.document.types import (
16
  RelatedRelation,
17
  TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
18
  )
19
  from src.utils import distance, distance_slices
20
+ from src.utils.graph_utils import get_connected_components
21
  from src.utils.span_utils import get_overlap_len
22
 
23
  logger = logging.getLogger(__name__)
 
126
  D_text = TypeVar("D_text", bound=Document)
127
 
128
 
129
+ def remove_annotations_by_label(
130
+ document: D, layer2label_blacklist: Dict[str, List[str]], verbose: bool = False
131
+ ) -> D:
132
+ """Remove annotations with labels in the blacklist from a document.
133
+
134
+ Args:
135
+ document: The document to process.
136
+ layer2label_blacklist: A mapping from layer names to lists of labels to remove.
137
+ verbose: Whether to log the number of removed annotations.
138
+
139
+ Returns:
140
+ The processed document.
141
+ """
142
+
143
+ result = document.copy(with_annotations=False)
144
+ override_annotations: Dict[str, Dict[int, Annotation]] = defaultdict(dict)
145
+ removed_annotations: Dict[str, Set[int]] = defaultdict(set)
146
+ for layer_name, labels in layer2label_blacklist.items():
147
+ # process gold annotations and predictions
148
+ for src_layer, tgt_layer in [
149
+ (document[layer_name], result[layer_name]),
150
+ (document[layer_name].predictions, result[layer_name].predictions),
151
+ ]:
152
+ current_override_annotations = dict()
153
+ current_removed_annotations = set()
154
+ for annotation in src_layer:
155
+ label = getattr(annotation, "label")
156
+ if label is None:
157
+ raise ValueError(
158
+ f"Annotation {annotation} has no label. Please check the annotation type."
159
+ )
160
+ if label not in labels:
161
+ current_override_annotations[annotation._id] = annotation.copy()
162
+ else:
163
+ current_removed_annotations.add(annotation._id)
164
+ tgt_layer.extend(current_override_annotations.values())
165
+
166
+ override_annotations[layer_name].update(current_override_annotations)
167
+ removed_annotations[layer_name].update(current_removed_annotations)
168
+ if verbose:
169
+ num_removed = {
170
+ layer_name: len(removed_ids) for layer_name, removed_ids in removed_annotations.items()
171
+ }
172
+ if len(num_removed) > 0:
173
+ num_total = {
174
+ layer_name: len(kept_ids) + num_removed[layer_name]
175
+ for layer_name, kept_ids in override_annotations.items()
176
+ }
177
+ logger.warning(
178
+ f"doc.id={document.id}: Removed {num_removed} (total: {num_total}) "
179
+ f"annotations with label blacklists {layer2label_blacklist}"
180
+ )
181
+
182
+ result.add_all_annotations_from_other(
183
+ other=document,
184
+ removed_annotations=removed_annotations,
185
+ override_annotations=override_annotations,
186
+ strict=False,
187
+ verbose=False,
188
+ )
189
+ return result
190
+
191
+
192
  def replace_substrings_in_text(
193
  document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True
194
  ) -> D_text:
 
578
  )
579
 
580
  return document
581
+
582
+
583
+ T = TypeVar("T", bound=TextDocumentWithLabeledSpansAndBinaryRelations)
584
+
585
+
586
+ def remove_discontinuous_spans(
587
+ document: T,
588
+ parts_of_same_relation: str,
589
+ verbose: bool = False,
590
+ ) -> T:
591
+ """
592
+ Remove discontinuous spans, i.e. spans linked via the parts_of_same relation, together with all relations attached to them.
593
+
594
+ Args:
595
+ document: The document to process.
596
+ parts_of_same_relation: The label of the relations that link parts of the same (discontinuous) span.
597
+ verbose: Whether to print debug information.
598
+
599
+ Returns:
600
+ The processed document.
601
+ """
602
+ result = document.copy()
603
+ spans = result.labeled_spans.clear()
604
+ rels = result.binary_relations.clear()
605
+
606
+ segment_spans = set()
607
+ segment_rels = set()
608
+ # collect all spans that are linked
609
+ for rel in rels:
610
+ if rel.label == parts_of_same_relation:
611
+ segment_spans.add(rel.head)
612
+ segment_spans.add(rel.tail)
613
+ segment_rels.add(rel)
614
+
615
+ for span in spans:
616
+ if span not in segment_spans:
617
+ result.labeled_spans.append(span)
618
+
619
+ other_rels_dropped = set()
620
+ for rel in rels:
621
+ if rel not in segment_rels:
622
+ if rel.head not in segment_spans and rel.tail not in segment_spans:
623
+ result.binary_relations.append(rel)
624
+ else:
625
+ other_rels_dropped.add(rel)
626
+
627
+ if verbose:
628
+ if len(segment_rels) > 0:
629
+ logger.warning(
630
+ f"doc={document.id}: Dropped {len(segment_rels)} segment rels "
631
+ f"and {len(other_rels_dropped)} other rels "
632
+ f"({round((len(document.binary_relations) - len(result.binary_relations)) * 100 / len(document.binary_relations), 1)}% "
633
+ f"of all relations dropped)"
634
+ )
635
+ return result
636
+
637
+
638
+ def close_clusters_transitively(
639
+ document: D, relation_layer: str, link_relation_label: str, verbose: bool = False
640
+ ) -> D:
641
+ """
642
+ Close clusters transitively by adding relations between all pairs of spans in the same cluster.
643
+
644
+ Args:
645
+ document: The document to process.
646
+ relation_layer: The name of the relation layer.
647
+ link_relation_label: The label of the link relation.
648
+ verbose: Whether to print debug information.
649
+
650
+ Returns:
651
+ The processed document.
652
+ """
653
+ result = document.copy()
654
+
655
+ connected_components: List[List[Annotation]] = get_connected_components(
656
+ relations=result[relation_layer],
657
+ link_relation_label=link_relation_label,
658
+ add_singletons=False,
659
+ )
660
+ # detach from document
661
+ relations = result[relation_layer].clear()
662
+ # use set to speed up membership checks
663
+ relations_set = set(relations)
664
+ n_before = len(relations)
665
+ for cluster in connected_components:
666
+ for head, tail in itertools.combinations(sorted(cluster), 2):
667
+ rel = BinaryRelation(
668
+ head=head,
669
+ tail=tail,
670
+ label=link_relation_label,
671
+ )
672
+ rel_reversed = BinaryRelation(
673
+ head=tail,
674
+ tail=head,
675
+ label=link_relation_label,
676
+ )
677
+ if rel not in relations_set and rel_reversed not in relations_set:
678
+ # append to relations to keep the order
679
+ relations.append(rel)
680
+ relations_set.add(rel)
681
+
682
+ result[relation_layer].extend(relations)
683
+ if verbose:
684
+ num_added = len(relations) - n_before
685
+ if num_added > 0:
686
+ logger.warning(
687
+ f"doc.id={document.id}: added {num_added} relations to {relation_layer} layer"
688
+ )
689
+
690
+ return result
691
+
692
+
693
+ def get_ancestor_layers(children: Dict[str, Set[str]], layer: str) -> Set[str]:
694
+ """
695
+ Get all ancestor layers of a given layer in the dependency graph.
696
+
697
+ Args:
698
+ children: A mapping from layers to their children layers.
699
+ layer: The layer for which to find ancestors.
700
+
701
+ Returns:
702
+ A set of ancestor layers.
703
+ """
704
+ ancestors = set()
705
+
706
+ def _get_ancestors(current_layer: str):
707
+ for parent_layer, child_layers in children.items():
708
+ if current_layer in child_layers:
709
+ ancestors.add(parent_layer)
710
+ _get_ancestors(parent_layer)
711
+
712
+ _get_ancestors(layer)
713
+ # drop the _artificial_root
714
+ ancestors.discard("_artificial_root")
715
+ return ancestors
716
+
717
+
718
+ def remove_binary_relations_by_partition_labels(
719
+ document: D,
720
+ partition_layer: str,
721
+ relation_layer: str,
722
+ partition_label_whitelist: Optional[List[List[str]]] = None,
723
+ partition_label_blacklist: Optional[List[List[str]]] = None,
724
+ verbose: bool = False,
725
+ ) -> D:
726
+ """
727
+ Remove binary relations whose head and tail partition label pair is not in the
728
+ whitelist or is in the blacklist.
729
+
730
+ Args:
731
+ document: The document to process.
732
+ partition_layer: The name of the partition layer.
733
+ relation_layer: The name of the relation layer.
734
+ partition_label_whitelist: The list of head-tail label pairs to keep.
735
+ partition_label_blacklist: The list of head-tail label pairs to remove.
736
+ verbose: Whether to print the removed relations to console.
737
+
738
+ Returns:
739
+ The processed document.
740
+ """
741
+ result = document.copy()
742
+
743
+ relation_annotation_layer = result[relation_layer]
744
+ # get all layers that target the relation layer
745
+ relation_dependent_layers = get_ancestor_layers(
746
+ children=result._annotation_graph, layer=relation_layer
747
+ )
748
+ # clear all layers that depend on the relation layer
749
+ for layer_name in relation_dependent_layers:
750
+ dependent_layer = result[layer_name]
751
+ gold_anns_cleared = dependent_layer.clear()
752
+ pred_anns_cleared = dependent_layer.predictions.clear()
753
+ if len(gold_anns_cleared) > 0 or len(pred_anns_cleared) > 0:
754
+ if verbose:
755
+ logger.warning(
756
+ f"doc.id={document.id}: Cleared {len(gold_anns_cleared)} gold and "
757
+ f"{len(pred_anns_cleared)} predicted annotations from layer {layer_name} "
758
+ f"because it depends on the relation layer {relation_layer}."
759
+ )
760
+
761
+ span2partition = {}
762
+ span_layer: AnnotationLayer
763
+ for span_layer in relation_annotation_layer.target_layers.values():
764
+ for span in list(span_layer) + list(span_layer.predictions):
765
+ if isinstance(span, Span):
766
+ span_start, span_end = span.start, span.end
767
+ elif isinstance(span, MultiSpan):
768
+ span_start, span_end = min(start for start, _ in span.slices), max(
769
+ end for _, end in span.slices
770
+ )
771
+ else:
772
+ raise ValueError(f"Unsupported span type: {type(span)}")
773
+ found_partition = False
774
+ for partition in result[partition_layer]:
775
+ if partition.start <= span_start and span_end <= partition.end:
776
+ span2partition[span] = partition
777
+ found_partition = True
778
+ break
779
+ if not found_partition:
780
+ raise ValueError(f"No partition found for span {span}")
781
+
782
+ if partition_label_whitelist is not None:
783
+ partition_label_whitelist_tuples = [tuple(pair) for pair in partition_label_whitelist]
784
+ else:
785
+ partition_label_whitelist_tuples = None
786
+ if partition_label_blacklist is not None:
787
+ partition_label_blacklist_tuples = [tuple(pair) for pair in partition_label_blacklist]
788
+ else:
789
+ partition_label_blacklist_tuples = None
790
+
791
+ for relation_base_layer in [relation_annotation_layer, relation_annotation_layer.predictions]:
792
+ # get all relations and clear the layer
793
+ relations = relation_base_layer.clear()
794
+ for relation in relations:
795
+ head_partition = span2partition[relation.head]
796
+ tail_partition = span2partition[relation.tail]
797
+ pair = (head_partition.label, tail_partition.label)
798
+ if (
799
+ partition_label_whitelist_tuples is None
800
+ or pair in partition_label_whitelist_tuples
801
+ ) and (
802
+ partition_label_blacklist_tuples is None
803
+ or pair not in partition_label_blacklist_tuples
804
+ ):
805
+ relation_base_layer.append(relation)
806
+ else:
807
+ if verbose:
808
+ logger.info(
809
+ f"Removing relation {relation} because its partitions "
810
+ f"({pair}) are not in the whitelist or are in the blacklist."
811
+ )
812
+
813
+ return result
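A hedged sketch of how the new document-processing helpers compose (layer names follow the PIE document types used above; the labels are placeholders):

    from src.document.processing import (
        close_clusters_transitively,
        remove_annotations_by_label,
        remove_binary_relations_by_partition_labels,
    )

    doc = remove_annotations_by_label(
        doc, layer2label_blacklist={"labeled_spans": ["data"]}, verbose=True  # placeholder label
    )
    doc = close_clusters_transitively(
        doc, relation_layer="binary_relations", link_relation_label="semantically_same"  # assumed label
    )
    doc = remove_binary_relations_by_partition_labels(
        doc,
        partition_layer="labeled_partitions",
        relation_layer="binary_relations",
        partition_label_whitelist=[["abstract", "abstract"], ["abstract", "title"]],  # placeholder pairs
    )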
src/evaluate.py CHANGED
@@ -41,13 +41,13 @@ from omegaconf import DictConfig
41
  from pie_datasets import DatasetDict
42
  from pie_modules.models import * # noqa: F403
43
  from pie_modules.taskmodules import * # noqa: F403
 
44
  from pytorch_ie.core import PyTorchIEModel, TaskModule
45
  from pytorch_ie.models import * # noqa: F403
46
  from pytorch_ie.taskmodules import * # noqa: F403
47
  from pytorch_lightning import Trainer
48
 
49
  from src import utils
50
- from src.datamodules import PieDataModule
51
  from src.models import * # noqa: F403
52
  from src.taskmodules import * # noqa: F403
53
 
@@ -80,8 +80,8 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
80
  log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
81
  taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")
82
 
83
- # auto-convert the dataset if the metric specifies a document type
84
- dataset = taskmodule.convert_dataset(dataset)
85
 
86
  # Init pytorch-ie datamodule
87
  log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
 
41
  from pie_datasets import DatasetDict
42
  from pie_modules.models import * # noqa: F403
43
  from pie_modules.taskmodules import * # noqa: F403
44
+ from pytorch_ie import PieDataModule
45
  from pytorch_ie.core import PyTorchIEModel, TaskModule
46
  from pytorch_ie.models import * # noqa: F403
47
  from pytorch_ie.taskmodules import * # noqa: F403
48
  from pytorch_lightning import Trainer
49
 
50
  from src import utils
 
51
  from src.models import * # noqa: F403
52
  from src.taskmodules import * # noqa: F403
53
 
 
80
  log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
81
  taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")
82
 
83
+ # auto-convert the dataset if the taskmodule specifies a document type
84
+ dataset = dataset.to_document_type(taskmodule, downcast=False)
85
 
86
  # Init pytorch-ie datamodule
87
  log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
src/evaluate_documents.py CHANGED
@@ -73,7 +73,7 @@ def evaluate_documents(cfg: DictConfig) -> Tuple[dict, dict]:
73
  metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial")
74
 
75
  # auto-convert the dataset if the metric specifies a document type
76
- dataset = metric.convert_dataset(dataset)
77
 
78
  # Init lightning loggers
79
  loggers = utils.instantiate_dict_entries(cfg, "logger")
 
73
  metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial")
74
 
75
  # auto-convert the dataset if the metric specifies a document type
76
+ dataset = dataset.to_document_type(metric, downcast=False)
77
 
78
  # Init lightning loggers
79
  loggers = utils.instantiate_dict_entries(cfg, "logger")
src/hydra_callbacks/save_job_return_value.py CHANGED
@@ -3,7 +3,7 @@ import logging
3
  import os
4
  import pickle
5
  from pathlib import Path
6
- from typing import Any, Dict, Generator, List, Optional, Tuple, Union
7
 
8
  import numpy as np
9
  import pandas as pd
@@ -174,6 +174,46 @@ def overrides_to_identifiers(overrides_per_result: List[List[str]], sep: str = "
174
  return identifiers
175
 
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  class SaveJobReturnValueCallback(Callback):
178
  """Save the job return-value in ${output_dir}/{job_return_value_filename}.
179
 
@@ -200,6 +240,10 @@ class SaveJobReturnValueCallback(Callback):
200
  multirun_create_ids_from_overrides: bool (default: True)
201
  Create job identifiers from the overrides of the jobs in a multi-run. If False, the job index is used as
202
  identifier.
 
 
 
 
203
  markdown_round_digits: int (default: 3)
204
  The number of digits to round the values in the markdown file. If None, no rounding is applied.
205
  multirun_job_id_key: str (default: "job_id")
@@ -220,6 +264,8 @@ class SaveJobReturnValueCallback(Callback):
220
  integrate_multirun_result: bool = False,
221
  multirun_aggregator_blacklist: Optional[List[str]] = None,
222
  multirun_create_ids_from_overrides: bool = True,
 
 
223
  markdown_round_digits: Optional[int] = 3,
224
  multirun_job_id_key: str = "job_id",
225
  paths_file: Optional[str] = None,
@@ -234,6 +280,8 @@ class SaveJobReturnValueCallback(Callback):
234
  self.multirun_aggregator_blacklist = multirun_aggregator_blacklist
235
  self.multirun_create_ids_from_overrides = multirun_create_ids_from_overrides
236
  self.multirun_job_id_key = multirun_job_id_key
 
 
237
  self.markdown_round_digits = markdown_round_digits
238
  self.multirun_paths_file = multirun_paths_file
239
  self.multirun_path_id = multirun_path_id
@@ -253,10 +301,21 @@ class SaveJobReturnValueCallback(Callback):
253
 
254
  def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
255
  job_ids: Union[List[str], List[int]]
256
- if self.multirun_create_ids_from_overrides:
257
- job_ids = overrides_to_identifiers([jr.overrides for jr in self.job_returns])
 
 
 
 
 
 
258
  else:
259
- job_ids = list(range(len(self.job_returns)))
 
 
 
 
 
260
 
261
  if self.integrate_multirun_result:
262
  # rearrange the job return-values of all jobs from a multi-run into a dict of lists (maybe nested),
@@ -368,6 +427,10 @@ class SaveJobReturnValueCallback(Callback):
368
  if job_id_column in result.columns:
369
  result = result.set_index(job_id_column)
370
  result.index.name = self.multirun_job_id_key
 
 
 
 
371
  else:
372
  # Otherwise, we have only one value for each key. We convert the dict to a pandas Series.
373
  series = pd.Series(obj_py_flat)
 
3
  import os
4
  import pickle
5
  from pathlib import Path
6
+ from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
7
 
8
  import numpy as np
9
  import pandas as pd
 
174
  return identifiers
175
 
176
 
177
+ def identifier2dict(
178
+ identifier: str, record_sep: str = "-", key_value_sep: str = "="
179
+ ) -> Dict[str, str]:
180
+ """Converts a single identifier to a dict. The identifier is expected to be separated by "-".
181
+ Values are allowed to contain "-" as well, but keys are not. Key and value are separated by "=".
182
+
183
+ Example:
184
+ >>> identifier = "a=1-b=my-stuff"
185
+ >>> identifier2dict(identifier)
186
+ {'a': '1', 'b': 'my-stuff'}
187
+ """
188
+ parts = identifier.split(record_sep)
189
+ result = {}
190
+ last_key = None
191
+ for part in parts:
192
+ if key_value_sep in part:
193
+ last_key, value = part.split(key_value_sep, 1)
194
+ result[last_key] = value
195
+ else:
196
+ if last_key is None:
197
+ raise ValueError(
198
+ f'Invalid identifier: {identifier} (keys must not contain the record_sep="{record_sep}")'
199
+ )
200
+ result[last_key] += record_sep + part
201
+ return result
202
+
203
+
204
+ def identifiers_to_multiindex(identifiers: Iterable[str], **kwargs) -> pd.MultiIndex:
205
+ """Converts a list of identifiers to a MultiIndex. See identifier2dict for the
206
+ format of the identifiers.
207
+
208
+ Example:
209
+ >>> identifiers = ["a=1-b=my-stuff", "a=2-b=yes", "a=3"]
210
+ >>> identifiers_to_multiindex(identifiers, record_sep="-", key_value_sep="=")
211
+ MultiIndex([(1, 'my-stuff'), (2, 'yes'), (3, nan)], names=['a', 'b'])
212
+ """
213
+ frame = pd.DataFrame([identifier2dict(identifier, **kwargs) for identifier in identifiers])
214
+ return pd.MultiIndex.from_frame(frame)
215
+
216
+
217
  class SaveJobReturnValueCallback(Callback):
218
  """Save the job return-value in ${output_dir}/{job_return_value_filename}.
219
 
 
240
  multirun_create_ids_from_overrides: bool (default: True)
241
  Create job identifiers from the overrides of the jobs in a multi-run. If False, the job index is used as
242
  identifier.
243
+ multirun_ids: List[str] or List[int] (default: None)
244
+ If provided, the job identifiers from the config are used instead of the overrides or the job index.
245
+ markdown_split_index: bool (default: False)
246
+ If True, the index of the markdown file is split into multiple columns based on the separator "-".
247
  markdown_round_digits: int (default: 3)
248
  The number of digits to round the values in the markdown file. If None, no rounding is applied.
249
  multirun_job_id_key: str (default: "job_id")
 
264
  integrate_multirun_result: bool = False,
265
  multirun_aggregator_blacklist: Optional[List[str]] = None,
266
  multirun_create_ids_from_overrides: bool = True,
267
+ multirun_ids: Optional[Union[List[str], List[int]]] = None,
268
+ markdown_split_index: bool = False,
269
  markdown_round_digits: Optional[int] = 3,
270
  multirun_job_id_key: str = "job_id",
271
  paths_file: Optional[str] = None,
 
280
  self.multirun_aggregator_blacklist = multirun_aggregator_blacklist
281
  self.multirun_create_ids_from_overrides = multirun_create_ids_from_overrides
282
  self.multirun_job_id_key = multirun_job_id_key
283
+ self.multirun_ids = multirun_ids
284
+ self.markdown_split_index = markdown_split_index
285
  self.markdown_round_digits = markdown_round_digits
286
  self.multirun_paths_file = multirun_paths_file
287
  self.multirun_path_id = multirun_path_id
 
301
 
302
  def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
303
  job_ids: Union[List[str], List[int]]
304
+ if self.multirun_ids is not None:
305
+ # use the job_ids from the config
306
+ if len(self.multirun_ids) != len(self.job_returns):
307
+ raise ValueError(
308
+ f"Number of job_ids ({len(self.multirun_ids)}) does not match number of job returns ({len(self.job_returns)})"
309
+ )
310
+ # convert ListConfig to list
311
+ job_ids = list(self.multirun_ids) # type: ignore
312
  else:
313
+ if self.multirun_create_ids_from_overrides:
314
+ job_ids = overrides_to_identifiers(
315
+ [jr.overrides for jr in self.job_returns], sep="-"
316
+ )
317
+ else:
318
+ job_ids = list(range(len(self.job_returns)))
319
 
320
  if self.integrate_multirun_result:
321
  # rearrange the job return-values of all jobs from a multi-run into a dict of lists (maybe nested),
 
427
  if job_id_column in result.columns:
428
  result = result.set_index(job_id_column)
429
  result.index.name = self.multirun_job_id_key
430
+
431
+ if self.markdown_split_index:
432
+ result.index = identifiers_to_multiindex(result.index, record_sep="-")
433
+ result = result.reset_index()
434
  else:
435
  # Otherwise, we have only one value for each key. We convert the dict to a pandas Series.
436
  series = pd.Series(obj_py_flat)
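A minimal sketch of the new callback options (the id values are placeholders; in practice these fields are set via the hydra config):

    from src.hydra_callbacks.save_job_return_value import SaveJobReturnValueCallback

    callback = SaveJobReturnValueCallback(
        integrate_multirun_result=True,
        multirun_ids=["model=bert", "model=roberta"],  # overrides the ids derived from the job overrides
        markdown_split_index=True,  # split "key=value-key=value" ids into separate index columns
        markdown_round_digits=3,
    )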
src/langchain_modules/basic_pie_document_store.py CHANGED
@@ -52,6 +52,7 @@ class BasicPieDocumentStore(PieDocumentStore):
52
  shutil.rmtree(pie_documents_path)
53
  os.makedirs(pie_documents_path, exist_ok=True)
54
  doc_ids_iter = iter(self.client.yield_keys())
 
55
  while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
56
  all_doc_ids.extend(batch_doc_ids)
57
  docs = self.client.mget(batch_doc_ids)
@@ -63,7 +64,8 @@ class BasicPieDocumentStore(PieDocumentStore):
63
  {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
64
  )
65
  pie_dataset = Dataset.from_documents(pie_docs)
66
- DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path)
 
67
  if len(all_doc_ids) > 0:
68
  doc_ids_path = os.path.join(path, "doc_ids.json")
69
  with open(doc_ids_path, "w") as f:
 
52
  shutil.rmtree(pie_documents_path)
53
  os.makedirs(pie_documents_path, exist_ok=True)
54
  doc_ids_iter = iter(self.client.yield_keys())
55
+ mode = "w"
56
  while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
57
  all_doc_ids.extend(batch_doc_ids)
58
  docs = self.client.mget(batch_doc_ids)
 
64
  {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
65
  )
66
  pie_dataset = Dataset.from_documents(pie_docs)
67
+ DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path, mode=mode)
68
+ mode = "a" # append after the first batch
69
  if len(all_doc_ids) > 0:
70
  doc_ids_path = os.path.join(path, "doc_ids.json")
71
  with open(doc_ids_path, "w") as f:
src/langchain_modules/datasets_pie_document_store.py CHANGED
@@ -118,7 +118,7 @@ class DatasetsPieDocumentStore(PieDocumentStore):
118
  logger.warning(f"Removing existing directory: {pie_documents_path}")
119
  shutil.rmtree(pie_documents_path)
120
  os.makedirs(pie_documents_path, exist_ok=True)
121
- DatasetDict({"train": self._data}).to_json(pie_documents_path)
122
  doc_ids_path = os.path.join(path, "doc_ids.json")
123
  with open(doc_ids_path, "w") as f:
124
  json.dump(all_doc_ids, f)
 
118
  logger.warning(f"Removing existing directory: {pie_documents_path}")
119
  shutil.rmtree(pie_documents_path)
120
  os.makedirs(pie_documents_path, exist_ok=True)
121
+ DatasetDict({"train": self._data}).to_json(pie_documents_path, mode="w")
122
  doc_ids_path = os.path.join(path, "doc_ids.json")
123
  with open(doc_ids_path, "w") as f:
124
  json.dump(all_doc_ids, f)
src/metrics/__init__.py CHANGED
@@ -1,3 +1,9 @@
1
- from .coref_sklearn import CorefMetricsSKLearn
 
 
2
  from .coref_torchmetrics import CorefMetricsTorchmetrics
 
 
3
  from .score_distribution import ScoreDistribution
 
 
 
1
+ from .connected_component_sizes import ConnectedComponentSizes
2
+ from .coref import CorefHoiF1
3
+ from .coref_sklearn import BinaryClassificationMetricsSKLearn
4
  from .coref_torchmetrics import CorefMetricsTorchmetrics
5
+ from .f1_with_threshold import F1WithThresholdMetric
6
+ from .ranking_sklearn import RankingMetricsSKLearn
7
  from .score_distribution import ScoreDistribution
8
+ from .semantically_same_ranking import SemanticallySameRankingMetric
9
+ from .tpfpfn import TPFFPFNMetric
src/metrics/connected_component_sizes.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections import Counter
3
+ from typing import Dict, List, TypeVar
4
+
5
+ from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentStatistic
6
+ from pytorch_ie.annotations import BinaryRelation
7
+
8
+ from src.utils.graph_utils import get_connected_components
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ A = TypeVar("A")
13
+
14
+
15
+ # TODO: remove when "counts" aggregation function is available in DocumentStatistic
16
+ def count_func(values: List[int]) -> Dict[int, int]:
17
+ """Counts the number of occurrences of each value in the list."""
18
+ counter = Counter(values)
19
+ result = {k: counter[k] for k in sorted(counter)}
20
+ return result
21
+
22
+
23
+ class ConnectedComponentSizes(DocumentStatistic):
24
+ # TODO: use "counts" aggregation function when available in DocumentStatistic
25
+ DEFAULT_AGGREGATION_FUNCTIONS = ["src.metrics.connected_component_sizes.count_func"]
26
+
27
+ def __init__(self, relation_layer: str, link_relation_label: str, **kwargs) -> None:
28
+ super().__init__(**kwargs)
29
+ self.relation_layer = relation_layer
30
+ self.link_relation_label = link_relation_label
31
+
32
+ def _collect(self, document: Document) -> List[int]:
33
+ relations: AnnotationLayer[BinaryRelation] = document[self.relation_layer]
34
+ spans: AnnotationLayer[Annotation] = document[self.relation_layer].target_layer
35
+
36
+ connected_components: List[List] = get_connected_components(
37
+ elements=spans,
38
+ relations=relations,
39
+ link_relation_label=self.link_relation_label,
40
+ add_singletons=True,
41
+ )
42
+ new_component_sizes = [len(component) for component in connected_components]
43
+ return new_component_sizes
src/metrics/coref.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+ from typing import Dict, Hashable, List, Optional, Sequence, Tuple, TypeVar
3
+
4
+ import numpy as np
5
+ from pytorch_ie import Annotation, Document, DocumentMetric
6
+ from pytorch_ie.annotations import BinaryRelation
7
+
8
+ from src.utils.graph_utils import get_connected_components
9
+
10
+
11
+ class CorefHoiEvaluator(object):
12
+ def __init__(self, metric, beta=1):
13
+ self.p_num = 0
14
+ self.p_den = 0
15
+ self.r_num = 0
16
+ self.r_den = 0
17
+ self.metric = metric
18
+ self.beta = beta
19
+
20
+ def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
21
+ if self.metric == ceafe_simplified:
22
+ pn, pd, rn, rd = self.metric(predicted, gold)
23
+ else:
24
+ pn, pd = self.metric(predicted, mention_to_gold)
25
+ rn, rd = self.metric(gold, mention_to_predicted)
26
+ self.p_num += pn
27
+ self.p_den += pd
28
+ self.r_num += rn
29
+ self.r_den += rd
30
+
31
+ def f1(self, p_num, p_den, r_num, r_den, beta=1):
32
+ p = 0 if p_den == 0 else p_num / float(p_den)
33
+ r = 0 if r_den == 0 else r_num / float(r_den)
34
+ return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)
35
+
36
+ def get_f1(self):
37
+ return self.f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)
38
+
39
+ def get_recall(self):
40
+ return 0 if self.r_num == 0 else self.r_num / float(self.r_den)
41
+
42
+ def get_precision(self):
43
+ return 0 if self.p_num == 0 else self.p_num / float(self.p_den)
44
+
45
+ def get_prf(self):
46
+ return self.get_precision(), self.get_recall(), self.get_f1()
47
+
48
+ def get_counts(self):
49
+ return self.p_num, self.p_den, self.r_num, self.r_den
50
+
51
+
52
+ def b_cubed_simplified(clusters, mention_to_gold):
53
+ num, dem = 0, 0
54
+ for c in clusters:
55
+ if len(c) == 1:
56
+ continue
57
+
58
+ gold_counts = Counter()
59
+ correct = 0
60
+ for m in c:
61
+ if m in mention_to_gold:
62
+ gold_counts[tuple(mention_to_gold[m])] += 1
63
+ for c2, count in gold_counts.items():
64
+ if len(c2) != 1:
65
+ correct += count * count
66
+
67
+ num += correct / float(len(c))
68
+ dem += len(c)
69
+ return num, dem
70
+
71
+
72
+ def muc_simplified(clusters, mention_to_gold):
73
+ tp, p = 0, 0
74
+ for c in clusters:
75
+ p += len(c) - 1
76
+ tp += len(c)
77
+ linked = set()
78
+ for m in c:
79
+ if m in mention_to_gold:
80
+ linked.add(mention_to_gold[m])
81
+ else:
82
+ tp -= 1
83
+ tp -= len(linked)
84
+ return tp, p
85
+
86
+
87
+ def phi4_simplified(c1, c2):
88
+ return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))
89
+
90
+
91
+ def ceafe_simplified(clusters, gold_clusters):
92
+ # lazy import to not force scipy installation
93
+ from scipy.optimize import linear_sum_assignment as linear_assignment
94
+
95
+ clusters = [c for c in clusters if len(c) != 1]
96
+ scores = np.zeros((len(gold_clusters), len(clusters)))
97
+ for i in range(len(gold_clusters)):
98
+ for j in range(len(clusters)):
99
+ scores[i, j] = phi4_simplified(gold_clusters[i], clusters[j])
100
+ matching = linear_assignment(-scores)
101
+ matching = np.transpose(np.asarray(matching))
102
+ similarity = sum(scores[matching[:, 0], matching[:, 1]])
103
+ return similarity, len(clusters), similarity, len(gold_clusters)
104
+
105
+
106
+ def lea_simplified(clusters, mention_to_gold):
107
+ num, dem = 0, 0
108
+
109
+ for c in clusters:
110
+ if len(c) == 1:
111
+ continue
112
+
113
+ common_links = 0
114
+ all_links = len(c) * (len(c) - 1) / 2.0
115
+ for i, m in enumerate(c):
116
+ if m in mention_to_gold:
117
+ for m2 in c[i + 1 :]:
118
+ if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]:
119
+ common_links += 1
120
+
121
+ num += len(c) * common_links / float(all_links)
122
+ dem += len(c)
123
+
124
+ return num, dem
125
+
126
+
127
+ H = TypeVar("H", bound=Hashable)
128
+
129
+
130
+ class CorefHoiF1(DocumentMetric):
131
+ """
132
+ Coreference evaluation based on official coref-hoi evaluation script, i.e.,
133
+ https://github.com/lxucs/coref-hoi/blob/5ddfc3b64a5519c3555b5a57e47ab2f03c104a60/metrics.py.
134
+
135
+ The metric expects documents with a relation layer that contains binary relations
136
+ between mentions from the same coreference cluster. Works with relations targeting
137
+ mentions from multiple layers (e.g., cross-textual relations).
138
+
139
+ Args:
140
+ relation_layer: The name of the relation layer that contains the link relations.
141
+ include_singletons: If True (default), singletons will be included in the evaluation.
142
+ link_relation_label: If provided, only the relations with this label will be used
143
+ to create the clusters.
144
+ link_relation_relation_score_threshold: If provided, only the relations with a score
145
+ greater than or equal to this threshold will be used to create the clusters.
146
+ """
147
+
148
+ def __init__(
149
+ self,
150
+ relation_layer: str,
151
+ include_singletons: bool = True,
152
+ link_relation_label: Optional[str] = None,
153
+ link_relation_relation_score_threshold: Optional[float] = None,
154
+ ) -> None:
155
+ super().__init__()
156
+ self.relation_layer = relation_layer
157
+ self.link_relation_label = link_relation_label
158
+ self.include_singletons = include_singletons
159
+ self.link_relation_relation_score_threshold = link_relation_relation_score_threshold
160
+
161
+ def reset(self) -> None:
162
+ self.evaluators = [
163
+ CorefHoiEvaluator(m) for m in (muc_simplified, b_cubed_simplified, ceafe_simplified)
164
+ ]
165
+
166
+ def prepare_clusters_with_mapping(
167
+ self, mentions: Sequence[Annotation], relations: Sequence[BinaryRelation]
168
+ ) -> Tuple[List[List[Annotation]], Dict[Annotation, Tuple[Annotation]]]:
169
+
170
+ # get connected components based on binary relations
171
+ connected_components = get_connected_components(
172
+ elements=mentions,
173
+ relations=relations,
174
+ link_relation_label=self.link_relation_label,
175
+ link_relation_relation_score_threshold=self.link_relation_relation_score_threshold,
176
+ add_singletons=self.include_singletons,
177
+ )
178
+
179
+ # store all clustered mentions in a list and
180
+ # create a map from each mention to its cluster
181
+ # (i.e. to the list of spans that includes all other mentions from the same cluster)
182
+ clusters = []
183
+ mention_to_cluster = dict()
184
+ for cluster in connected_components:
185
+ clusters.append(cluster)
186
+ for mention in cluster:
187
+ mention_to_cluster[mention] = tuple(cluster)
188
+
189
+ return clusters, mention_to_cluster
190
+
191
+ def _update(self, doc: Document) -> None:
192
+ relation_layer = doc[self.relation_layer]
193
+ gold_mentions = []
194
+ predicted_mentions = []
195
+ for mention_layer in relation_layer.target_layers.values():
196
+ gold_mentions.extend(mention_layer)
197
+ predicted_mentions.extend(mention_layer.predictions)
198
+
199
+ # prepare the clusters and mention-to-cluster mapping needed for evaluation
200
+ predicted_clusters, mention_to_predicted = self.prepare_clusters_with_mapping(
201
+ mentions=predicted_mentions, relations=relation_layer.predictions
202
+ )
203
+ gold_clusters, mention_to_gold = self.prepare_clusters_with_mapping(
204
+ mentions=gold_mentions, relations=relation_layer
205
+ )
206
+
207
+ for e in self.evaluators:
208
+ e.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)
209
+
210
+ def get_f1(self) -> float:
211
+ return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators)
212
+
213
+ def get_recall(self) -> float:
214
+ return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators)
215
+
216
+ def get_precision(self) -> float:
217
+ return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators)
218
+
219
+ def get_prf(self) -> Tuple[float, float, float]:
220
+ return self.get_precision(), self.get_recall(), self.get_f1()
221
+
222
+ def _compute(self) -> float:
223
+ return self.get_f1()
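
A minimal usage sketch for the new `CorefHoiF1` metric. The relation layer name `binary_coref_relations` and the `documents` iterable are assumptions for illustration; the call pattern relies on the generic `pytorch_ie` `DocumentMetric` interface (`__call__` to update, `compute()` to aggregate).

```python
# Hedged sketch, not part of the PR: evaluating coreference predictions.
from src.metrics.coref import CorefHoiF1  # assumed module path

metric = CorefHoiF1(
    relation_layer="binary_coref_relations",  # assumption: depends on the document type
    include_singletons=True,
    link_relation_relation_score_threshold=0.5,  # keep only confident predicted links
)

for doc in documents:  # `documents`: iterable of documents that already carry predictions
    metric(doc)

# average F1 over the simplified MUC, B-cubed and CEAF-e evaluators
print(metric.compute())
```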
src/metrics/coref_sklearn.py CHANGED
@@ -1,15 +1,13 @@
1
  import logging
2
  import math
3
- from typing import Any, Callable, Dict, List, Optional, Union
4
 
5
  import numpy as np
6
- import torch
7
  from pandas import MultiIndex
8
- from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
9
- from pytorch_ie import DocumentMetric
10
  from pytorch_ie.core.metric import T
11
  from pytorch_ie.utils.hydra import resolve_target
12
- from torchmetrics import Metric, MetricCollection
13
 
14
  from src.hydra_callbacks.save_job_return_value import to_py_obj
15
 
@@ -24,6 +22,14 @@ def get_num_positives(targets: List[int], preds: List[float], positive_idx: int
24
  return len([v for v in targets if v == positive_idx])
25
 
26
 
 
 
 
 
 
 
 
 
27
  def discretize(
28
  values: List[float], threshold: Union[float, List[float], dict]
29
  ) -> Union[List[float], Dict[Any, List[float]]]:
@@ -40,20 +46,97 @@ def discretize(
40
  raise TypeError(f"threshold has unknown type: {threshold}")
41
 
42
 
43
- class CorefMetricsSKLearn(DocumentMetric):
44
- DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def __init__(
47
  self,
48
  metrics: Dict[str, str],
 
 
49
  thresholds: Optional[Dict[str, float]] = None,
50
  default_target_idx: int = 0,
51
  default_prediction_score: float = 0.0,
52
  show_as_markdown: bool = False,
53
  markdown_precision: int = 4,
54
- plot: bool = False,
 
 
 
 
 
55
  ):
56
- self.metrics = {name: resolve_target(metric) for name, metric in metrics.items()}
57
  self.thresholds = thresholds or {}
58
  thresholds_not_in_metrics = {
59
  name: t for name, t in self.thresholds.items() if name not in self.metrics
@@ -62,11 +145,25 @@ class CorefMetricsSKLearn(DocumentMetric):
62
  logger.warning(
63
  f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
64
  )
 
 
65
  self.default_target_idx = default_target_idx
66
  self.default_prediction_score = default_prediction_score
67
  self.show_as_markdown = show_as_markdown
68
  self.markdown_precision = markdown_precision
69
- self.plot = plot
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  super().__init__()
72
 
@@ -74,50 +171,55 @@ class CorefMetricsSKLearn(DocumentMetric):
74
  self._preds: List[float] = []
75
  self._targets: List[int] = []
76
 
77
- def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
78
- target_args2idx = {
79
- (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
 
 
 
80
  }
81
- prediction_args2score = {
82
- (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
 
 
83
  }
84
- all_args = set(target_args2idx) | set(prediction_args2score)
85
  all_targets: List[int] = []
86
  all_predictions: List[float] = []
87
  for args in all_args:
88
- target_idx = target_args2idx.get(args, self.default_target_idx)
89
- prediction_score = prediction_args2score.get(args, self.default_prediction_score)
90
  all_targets.append(target_idx)
91
  all_predictions.append(prediction_score)
92
- # prediction_scores = torch.tensor(all_predictions)
93
- # target_indices = torch.tensor(all_targets)
94
- # self.metrics.update(preds=prediction_scores, target=target_indices)
95
  self._preds.extend(all_predictions)
96
  self._targets.extend(all_targets)
97
 
98
- def do_plot(self):
99
- raise NotImplementedError()
100
 
101
  from matplotlib import pyplot as plt
102
 
103
  # Get the number of metrics
104
- num_metrics = len(self.metrics)
105
 
106
  # Calculate rows and columns for subplots (aim for a square-like layout)
107
- ncols = math.ceil(math.sqrt(num_metrics))
108
- nrows = math.ceil(num_metrics / ncols)
109
 
110
  # Create the subplots
111
  fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))
112
 
113
  # Flatten the ax_list if necessary (in case of multiple rows/columns)
114
- ax_list = ax_list.flatten().tolist() # Ensure it's a list, and flatten it if necessary
 
 
 
115
 
116
- # Ensure that we pass exactly the number of axes required by metrics
117
- ax_list = ax_list[:num_metrics]
118
-
119
- # Plot the metrics using the list of axes
120
- self.metrics.plot(ax=ax_list, together=False)
121
 
122
  # Adjust layout to avoid overlapping plots
123
  plt.tight_layout()
@@ -125,23 +227,35 @@ class CorefMetricsSKLearn(DocumentMetric):
125
 
126
  def _compute(self) -> T:
127
 
128
- if self.plot:
129
- self.do_plot()
130
 
131
  result = {}
132
  for name, metric in self.metrics.items():
133
 
134
  if name in self.thresholds:
135
- preds = discretize(values=self._preds, threshold=self.thresholds[name])
 
 
 
 
 
 
 
 
 
 
136
  else:
137
  preds = self._preds
138
- if isinstance(preds, dict):
139
- metric_results = {
140
- t: metric(self._targets, t_preds) for t, t_preds in preds.items()
141
- }
142
- # just get the max
143
- max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
144
- result[f"{name}-{max_t}"] = max_v
 
 
145
  else:
146
  result[name] = metric(self._targets, preds)
147
 
@@ -149,7 +263,8 @@ class CorefMetricsSKLearn(DocumentMetric):
149
  if self.show_as_markdown:
150
  import pandas as pd
151
 
152
- series = pd.Series(result)
 
153
  if isinstance(series.index, MultiIndex):
154
  if len(series.index.levels) > 1:
155
  # in fact, this is not a series anymore
 
1
  import logging
2
  import math
3
+ from typing import Any, Callable, Dict, List, Optional, Union, overload
4
 
5
  import numpy as np
 
6
  from pandas import MultiIndex
7
+ from pie_modules.utils import flatten_dict
8
+ from pytorch_ie import Document, DocumentMetric
9
  from pytorch_ie.core.metric import T
10
  from pytorch_ie.utils.hydra import resolve_target
 
11
 
12
  from src.hydra_callbacks.save_job_return_value import to_py_obj
13
 
 
22
  return len([v for v in targets if v == positive_idx])
23
 
24
 
25
+ @overload
26
+ def discretize(values: List[float], threshold: float) -> List[float]: ...
27
+
28
+
29
+ @overload
30
+ def discretize(values: List[float], threshold: List[float]) -> Dict[Any, List[float]]: ...
31
+
32
+
33
  def discretize(
34
  values: List[float], threshold: Union[float, List[float], dict]
35
  ) -> Union[List[float], Dict[Any, List[float]]]:
 
46
  raise TypeError(f"threshold has unknown type: {threshold}")
47
 
48
 
49
+ def get_metric_func(name: str) -> Callable:
50
+ if name.endswith("_curve"):
51
+ from sklearn.metrics import auc
52
+
53
+ base_func = resolve_target(name)
54
+
55
+ def wrapper(targets: List[int], preds: List[float], **kwargs):
56
+ x, y, thresholds = base_func(targets, preds, **kwargs)
57
+ return auc(y, x)
58
+
59
+ return wrapper
60
+ else:
61
+ return resolve_target(name)
62
+
63
+
64
+ def bootstrap(
65
+ metric_fn: Callable[[List[int], Union[List[int], List[float]]], float],
66
+ targets: List[int],
67
+ predictions: Union[List[int], List[float]],
68
+ n: int = 1_000,
69
+ random_state: int | None = None,
70
+ alpha: float = 0.95,
71
+ ) -> Dict[str, float]:
72
+ """
73
+ Returns mean and a two–sided (1–alpha) bootstrap CI for any
74
+ pair-wise classification or ranking metric.
75
+
76
+ Parameters
77
+ ----------
78
+ metric_fn Metric function taking (targets, prediction) lists.
79
+ targets Ground-truth 0/1 labels.
80
+ prediction Scores or hard predictions (same length as `targets`).
81
+ n Number of bootstrap replicates (after skipping degenerate ones).
82
+ random_state Seed for reproducibility.
83
+ alpha Confidence level (default 0.95 → 95 % CI).
84
+
85
+ Notes
86
+ -----
87
+ * A replicate that contains only one class is discarded
88
+ because many sklearn metrics are undefined in that case.
89
+ * If all replicates are discarded an exception is raised.
90
+ """
91
+ y = np.asarray(targets)
92
+ yhat = np.asarray(predictions)
93
+ if y.shape[0] != yhat.shape[0]:
94
+ raise ValueError("`targets` and `prediction` must have the same length")
95
+
96
+ rng = np.random.default_rng(random_state)
97
+ idx = np.arange(y.shape[0])
98
+ vals_list: list[float] = []
99
+
100
+ while len(vals_list) < n:
101
+ sample_idx = rng.choice(idx, size=idx.shape[0], replace=True)
102
+ y_samp, yhat_samp = y[sample_idx], yhat[sample_idx]
103
+
104
+ # skip all-positive or all-negative bootstrap samples
105
+ if y_samp.min() == y_samp.max():
106
+ continue
107
+
108
+ vals_list.append(metric_fn(y_samp.tolist(), yhat_samp.tolist()))
109
+
110
+ if not vals_list:
111
+ raise RuntimeError("No valid bootstrap replicate contained both classes.")
112
+
113
+ vals = np.asarray(vals_list, dtype=float)
114
+ lower = np.percentile(vals, (1 - alpha) / 2 * 100)
115
+ upper = np.percentile(vals, (1 + alpha) / 2 * 100)
116
+
117
+ return {"mean": float(vals.mean()), "low": float(lower), "high": float(upper)}
118
+
119
+
120
+ class BinaryClassificationMetricsSKLearn(DocumentMetric):
121
 
122
  def __init__(
123
  self,
124
  metrics: Dict[str, str],
125
+ layer: str,
126
+ label: Optional[str] = None,
127
  thresholds: Optional[Dict[str, float]] = None,
128
  default_target_idx: int = 0,
129
  default_prediction_score: float = 0.0,
130
  show_as_markdown: bool = False,
131
  markdown_precision: int = 4,
132
+ bootstrap: Optional[list[str]] = None,
133
+ bootstrap_n: int = 1_000,
134
+ bootstrap_random_state: int | None = None,
135
+ bootstrap_alpha: float = 0.95,
136
+ create_plots: bool = True,
137
+ plots: Optional[Dict[str, str]] = None,
138
  ):
139
+ self.metrics = {name: get_metric_func(metric) for name, metric in metrics.items()}
140
  self.thresholds = thresholds or {}
141
  thresholds_not_in_metrics = {
142
  name: t for name, t in self.thresholds.items() if name not in self.metrics
 
145
  logger.warning(
146
  f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
147
  )
148
+ self.annotation_layer_name = layer
149
+ self.annotation_label = label
150
  self.default_target_idx = default_target_idx
151
  self.default_prediction_score = default_prediction_score
152
  self.show_as_markdown = show_as_markdown
153
  self.markdown_precision = markdown_precision
154
+ if create_plots:
155
+ self.plots = {
156
+ name: resolve_target(plot_func) for name, plot_func in (plots or {}).items()
157
+ }
158
+ else:
159
+ self.plots = {}
160
+
161
+ self.bootstrap = set(bootstrap or [])
162
+ self.bootstrap_kwargs = {
163
+ "n": bootstrap_n,
164
+ "random_state": bootstrap_random_state,
165
+ "alpha": bootstrap_alpha,
166
+ }
167
 
168
  super().__init__()
169
 
 
171
  self._preds: List[float] = []
172
  self._targets: List[int] = []
173
 
174
+ def _update(self, document: Document) -> None:
175
+ annotation_layer = document[self.annotation_layer_name]
176
+ target2idx = {
177
+ ann: int(ann.score)
178
+ for ann in annotation_layer
179
+ if self.annotation_label is None or ann.label == self.annotation_label
180
  }
181
+ prediction2score = {
182
+ ann: ann.score
183
+ for ann in annotation_layer.predictions
184
+ if self.annotation_label is None or ann.label == self.annotation_label
185
  }
186
+ all_args = set(target2idx) | set(prediction2score)
187
  all_targets: List[int] = []
188
  all_predictions: List[float] = []
189
  for args in all_args:
190
+ target_idx = target2idx.get(args, self.default_target_idx)
191
+ prediction_score = prediction2score.get(args, self.default_prediction_score)
192
  all_targets.append(target_idx)
193
  all_predictions.append(prediction_score)
194
+
 
 
195
  self._preds.extend(all_predictions)
196
  self._targets.extend(all_targets)
197
 
198
+ def create_plots(self):
 
199
 
200
  from matplotlib import pyplot as plt
201
 
202
  # Get the number of metrics
203
+ num_plots = len(self.plots)
204
 
205
  # Calculate rows and columns for subplots (aim for a square-like layout)
206
+ ncols = math.ceil(math.sqrt(num_plots))
207
+ nrows = math.ceil(num_plots / ncols)
208
 
209
  # Create the subplots
210
  fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))
211
 
212
  # Flatten the ax_list if necessary (in case of multiple rows/columns)
213
+ if num_plots > 1:
214
+ ax_list = ax_list.flatten().tolist() # Ensure it's a list, and flatten it if necessary
215
+ else:
216
+ ax_list = [ax_list]
217
 
218
+ # Create each plot
219
+ for ax, (name, plot_func) in zip(ax_list, self.plots.items()):
220
+ # Set the title for each subplot
221
+ ax.set_title(name)
222
+ plot_func(y_true=self._targets, y_pred=self._preds, ax=ax)
223
 
224
  # Adjust layout to avoid overlapping plots
225
  plt.tight_layout()
 
227
 
228
  def _compute(self) -> T:
229
 
230
+ if len(self.plots) > 0:
231
+ self.create_plots()
232
 
233
  result = {}
234
  for name, metric in self.metrics.items():
235
 
236
  if name in self.thresholds:
237
+ preds_dict = discretize(values=self._preds, threshold=self.thresholds[name])
238
+ if isinstance(preds_dict, dict):
239
+ metric_results = {
240
+ t: metric(self._targets, t_preds) for t, t_preds in preds_dict.items()
241
+ }
242
+ # just get the max
243
+ max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
244
+ result[f"{name}_threshold"] = max_t
245
+ preds = discretize(values=self._preds, threshold=max_t)
246
+ else:
247
+ preds = preds_dict
248
  else:
249
  preds = self._preds
250
+
251
+ if name in self.bootstrap:
252
+ # bootstrap the metric
253
+ result[name] = bootstrap(
254
+ metric_fn=metric,
255
+ targets=self._targets,
256
+ predictions=preds,
257
+ **self.bootstrap_kwargs, # type: ignore
258
+ )
259
  else:
260
  result[name] = metric(self._targets, preds)
261
 
 
263
  if self.show_as_markdown:
264
  import pandas as pd
265
 
266
+ result_flat = flatten_dict(result)
267
+ series = pd.Series(result_flat)
268
  if isinstance(series.index, MultiIndex):
269
  if len(series.index.levels) > 1:
270
  # in fact, this is not a series anymore
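
For orientation, a hedged configuration sketch for the renamed `BinaryClassificationMetricsSKLearn` (formerly `CorefMetricsSKLearn`). The sklearn function paths exist as written; the layer name and the particular metric, threshold and bootstrap choices are assumptions.

```python
# Hedged sketch, not part of the PR.
from src.metrics.coref_sklearn import BinaryClassificationMetricsSKLearn

metric = BinaryClassificationMetricsSKLearn(
    layer="binary_coref_relations",  # assumption: any layer with scored binary annotations
    metrics={
        "auroc": "sklearn.metrics.roc_auc_score",
        "auc_pr": "sklearn.metrics.precision_recall_curve",  # *_curve entries are wrapped into an AUC
        "f1": "sklearn.metrics.f1_score",
    },
    thresholds={"f1": 0.5},   # scores are discretized before calling f1_score
    bootstrap=["auroc"],      # report mean plus a 95% bootstrap confidence interval
    bootstrap_n=1000,
    show_as_markdown=True,
)
```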
src/metrics/f1_with_bootstrapping.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
+ from collections import defaultdict
2
+ from functools import partial
3
+ from typing import Callable, Collection, Dict, Hashable, List, Optional, Set, Tuple
4
+
5
+ import pandas as pd
+ from pie_modules.metrics import F1Metric
6
+ from pytorch_ie import Annotation, Document
7
+
+ logger = logging.getLogger(__name__)
+
8
+ def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
9
+ return getattr(ann, label_field) in labels
10
+
11
+
12
+ def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
13
+ return getattr(ann, label_field) == label
14
+
15
+
16
+ class F1WithBootstrappingMetric(F1Metric):
17
+ def __init__(self, *args, bootstrap_n: int = 0, **kwargs):
18
+ super().__init__(*args, **kwargs)
19
+ self.bootstrap_n = bootstrap_n
20
+
21
+
22
+ def reset(self) -> None:
23
+ self.tp: Dict[str, Set[Annotation]] = defaultdict(set)
24
+ self.fp: Dict[str, Set[Annotation]] = defaultdict(set)
25
+ self.fn: Dict[str, Set[Annotation]] = defaultdict(set)
26
+
27
+ def calculate_tp_fp_fn(
28
+ self,
29
+ document: Document,
30
+ annotation_filter: Optional[Callable[[Annotation], bool]] = None,
31
+ annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
32
+ ) -> Tuple[Set[Annotation], Set[Annotation], Set[Annotation]]:
33
+ annotation_processor = annotation_processor or (lambda ann: ann)
34
+ annotation_filter = annotation_filter or (lambda ann: True)
35
+ predicted_annotations = {
36
+ annotation_processor(ann)
37
+ for ann in document[self.layer].predictions
38
+ if annotation_filter(ann)
39
+ }
40
+ gold_annotations = {
41
+ annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
42
+ }
43
+ return predicted_annotations & gold_annotations, predicted_annotations - gold_annotations, gold_annotations - predicted_annotations
44
+
45
+
46
+ def add_tp_fp_fn(self, tp: Set[Annotation], fp: Set[Annotation], fn: Set[Annotation], label: str) -> None:
47
+ self.tp[label].update(tp)
48
+ self.fp[label].update(fp)
49
+ self.fn[label].update(fn)
50
+
51
+ def _update(self, document: Document) -> None:
52
+ new_values = self.calculate_tp_fp_fn(
53
+ document=document,
54
+ annotation_filter=(
55
+ partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
56
+ if self.per_label and not self.infer_labels
57
+ else None
58
+ ),
59
+ annotation_processor=self.annotation_processor,
60
+ )
61
+ self.add_tp_fp_fn(*new_values, label="MICRO")
62
+ if self.infer_labels:
63
+ layer = document[self.layer]
64
+ # collect labels from gold data and predictions
65
+ for ann in list(layer) + list(layer.predictions):
66
+ label = getattr(ann, self.label_field)
67
+ if label not in self.labels:
68
+ self.labels.append(label)
69
+ if self.per_label:
70
+ for label in self.labels:
71
+ new_values = self.calculate_tp_fp_fn(
72
+ document=document,
73
+ annotation_filter=partial(
74
+ has_this_label, label_field=self.label_field, label=label
75
+ ),
76
+ annotation_processor=self.annotation_processor,
77
+ )
78
+ self.add_tp_fp_fn(*new_values, label=label)
79
+
80
+ def _compute(self) -> Dict[str, Dict[str, float]]:
81
+ res = dict()
82
+ if self.per_label:
83
+ res["MACRO"] = {"f1": 0.0, "p": 0.0, "r": 0.0}
84
+ for label in self.tp.keys():
85
+ tp, fp, fn = (
86
+ len(self.tp[label]),
87
+ len(self.fp[label]),
88
+ len(self.fn[label]),
89
+ )
90
+ if tp == 0:
91
+ p, r, f1 = 0.0, 0.0, 0.0
92
+ else:
93
+ p = tp / (tp + fp)
94
+ r = tp / (tp + fn)
95
+ f1 = 2 * p * r / (p + r)
96
+ res[label] = {"f1": f1, "p": p, "r": r, "s": tp + fn}
97
+ if self.per_label and label in self.labels:
98
+ res["MACRO"]["f1"] += f1 / len(self.labels)
99
+ res["MACRO"]["p"] += p / len(self.labels)
100
+ res["MACRO"]["r"] += r / len(self.labels)
101
+ if self.show_as_markdown:
102
+ logger.info(f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}")
103
+ return res
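
Unlike a plain count-based F1, this variant keeps the concrete TP/FP/FN annotation sets per label, presumably so that the stored `bootstrap_n` can later drive resampled confidence intervals. The final scores still reduce to the usual count arithmetic, e.g.:

```python
# Self-contained illustration of the precision/recall/F1 arithmetic in _compute()
# (the counts are made-up numbers, not taken from the PR).
tp, fp, fn = 8, 2, 4

p = tp / (tp + fp)        # 0.8
r = tp / (tp + fn)        # ~0.667
f1 = 2 * p * r / (p + r)  # ~0.727
print(round(p, 3), round(r, 3), round(f1, 3))
```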
src/metrics/f1_with_threshold.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Hashable, Optional, Tuple
2
+
3
+ from pie_modules.metrics import F1Metric
4
+ from pytorch_ie import Annotation, Document
5
+
6
+
7
+ class F1WithThresholdMetric(F1Metric):
8
+ def __init__(self, *args, threshold: float = 0.0, **kwargs):
9
+ super().__init__(*args, **kwargs)
10
+ self.threshold = threshold
11
+
12
+ def calculate_counts(
13
+ self,
14
+ document: Document,
15
+ annotation_filter: Optional[Callable[[Annotation], bool]] = None,
16
+ annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
17
+ ) -> Tuple[int, int, int]:
18
+ annotation_processor = annotation_processor or (lambda ann: ann)
19
+ annotation_filter = annotation_filter or (lambda ann: True)
20
+ predicted_annotations = {
21
+ annotation_processor(ann)
22
+ for ann in document[self.layer].predictions
23
+ if annotation_filter(ann) and getattr(ann, "score", 0.0) >= self.threshold
24
+ }
25
+ gold_annotations = {
26
+ annotation_processor(ann)
27
+ for ann in document[self.layer]
28
+ if annotation_filter(ann) and getattr(ann, "score", 0.0) >= self.threshold
29
+ }
30
+ tp = len([ann for ann in predicted_annotations & gold_annotations])
31
+ fn = len([ann for ann in gold_annotations - predicted_annotations])
32
+ fp = len([ann for ann in predicted_annotations - gold_annotations])
33
+ return tp, fp, fn
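
A hedged usage sketch: the only new argument is `threshold`, which drops low-score annotations on both the gold and the predicted side before counting. The remaining arguments are inherited from pie_modules' `F1Metric`, so their exact names are assumed here.

```python
# Hedged sketch, not part of the PR.
from src.metrics.f1_with_threshold import F1WithThresholdMetric

metric = F1WithThresholdMetric(
    layer="binary_relations",      # assumption: name of the relation layer
    labels=["semantically_same"],  # assumption: inherited F1Metric argument
    threshold=0.9,                 # only annotations with score >= 0.9 are counted
)
```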
src/metrics/ranking_sklearn.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections import defaultdict
3
+ from typing import Callable, Dict, List, Optional, Sequence, Union
4
+
5
+ from pandas import MultiIndex
6
+ from pytorch_ie import Annotation, AnnotationLayer, Document, DocumentMetric
7
+ from pytorch_ie.annotations import BinaryRelation
8
+ from pytorch_ie.core.metric import T
9
+ from pytorch_ie.utils.hydra import resolve_target
10
+
11
+ from src.hydra_callbacks.save_job_return_value import to_py_obj
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class RankingMetricsSKLearn(DocumentMetric):
17
+ """Ranking metrics for documents with binary relations.
18
+
19
+ This metric computes the ranking metrics for retrieval tasks, where
20
+ relation heads are the queries and the relation tails are the candidates.
21
+ The metric is computed for each head and the results are averaged. It is meant to
22
+ be used with Scikit-learn metrics such as `sklearn.metrics.ndcg_score` (Normalized
23
+ Discounted Cumulative Gain), `sklearn.metrics.label_ranking_average_precision_score`
24
+ (LRAP), etc., see
25
+ https://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics.
26
+
27
+ Args:
28
+ metrics (Dict[str, Union[str, Callable]]): A dictionary of metric names and their
29
+ corresponding functions. The function can be a string (name of the function, e.g.,
30
+ sklearn.metrics.ndcg_score) or a callable.
31
+ layer (str): The name of the annotation layer containing the binary relations, e.g.,
32
+ "binary_relations" when applied to TextDocumentsWithLabeledSpansAndBinaryRelations.
33
+ use_manual_average (Optional[List[str]]): A list of metric names to use for manual
34
+ averaging. If provided, the metric scores will be calculated for each
35
+ head and then averaged. Otherwise, all true and predicted scores will be
36
+ passed to the metric function at once.
37
+ exclude_singletons (Optional[List[str]]): A list of metric names to exclude singletons
38
+ from the computation, i.e., entries (heads) where the number of candidates is 1.
39
+ label (Optional[str]): If provided, only the relations with this label will be used
40
+ to compute the metrics. This is useful for filtering out relations that are not
41
+ relevant for the task at hand (e.g., when having multiple relation types in the
42
+ same layer).
43
+ score_threshold (float): If provided, only the relations with a score greater than or
44
+ equal to this threshold will be used to compute the metrics.
45
+ default_score (float): The default score to use for missing relations, either in the
46
+ target or prediction. Default is 0.0.
47
+ use_all_spans (bool): Whether to consider all spans in the document as queries and
48
+ candidates or only the spans that are present in the target and prediction.
49
+ span_label_blacklist (Optional[List[str]]): If provided, ignore the relations with
50
+ heads/tails that are in this list. When using use_all_spans=True, this also
51
+ restricts the candidates to those that are not in the blacklist.
52
+ show_as_markdown (bool): Whether to show the results as markdown. Default is False.
53
+ markdown_precision (int): The precision for displaying the results in markdown.
54
+ Default is 4.
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ metrics: Dict[str, Union[str, Callable]],
60
+ layer: str,
61
+ use_manual_average: Optional[List[str]] = None,
62
+ exclude_singletons: Optional[List[str]] = None,
63
+ label: Optional[str] = None,
64
+ score_threshold: float = 0.0,
65
+ default_score: float = 0.0,
66
+ use_all_spans: bool = False,
67
+ span_label_blacklist: Optional[List[str]] = None,
68
+ show_as_markdown: bool = False,
69
+ markdown_precision: int = 4,
70
+ plot: bool = False,
71
+ ):
72
+ self.metrics = {
73
+ name: resolve_target(metric) if isinstance(metric, str) else metric
74
+ for name, metric in metrics.items()
75
+ }
76
+ self.use_manual_average = set(use_manual_average or [])
77
+ self.exclude_singletons = set(exclude_singletons or [])
78
+ self.annotation_layer_name = layer
79
+ self.annotation_label = label
80
+ self.score_threshold = score_threshold
81
+ self.default_score = default_score
82
+ self.use_all_spans = use_all_spans
83
+ self.span_label_blacklist = span_label_blacklist
84
+ self.show_as_markdown = show_as_markdown
85
+ self.markdown_precision = markdown_precision
86
+ self.plot = plot
87
+
88
+ super().__init__()
89
+
90
+ def reset(self) -> None:
91
+ self._preds: List[List[float]] = []
92
+ self._targets: List[List[float]] = []
93
+
94
+ def get_head2tail2score(
95
+ self, relations: Sequence[BinaryRelation]
96
+ ) -> Dict[Annotation, Dict[Annotation, float]]:
97
+ result: Dict[Annotation, Dict[Annotation, float]] = defaultdict(dict)
98
+ for rel in relations:
99
+ if (
100
+ (self.annotation_label is None or rel.label == self.annotation_label)
101
+ and (rel.score >= self.score_threshold)
102
+ and (
103
+ self.span_label_blacklist is None
104
+ or (
105
+ rel.head.label not in self.span_label_blacklist
106
+ and rel.tail.label not in self.span_label_blacklist
107
+ )
108
+ )
109
+ ):
110
+ result[rel.head][rel.tail] = rel.score
111
+
112
+ return result
113
+
114
+ def _update(self, document: Document) -> None:
115
+ annotation_layer: AnnotationLayer[BinaryRelation] = document[self.annotation_layer_name]
116
+
117
+ target_head2tail2score = self.get_head2tail2score(annotation_layer)
118
+ prediction_head2tail2score = self.get_head2tail2score(annotation_layer.predictions)
119
+ all_spans = set()
120
+ # get spans from all layers targeted by the annotation (relation) layer
121
+ for span_layer in annotation_layer.target_layers.values():
122
+ all_spans.update(span_layer)
123
+
124
+ if self.span_label_blacklist is not None:
125
+ all_spans = {span for span in all_spans if span.label not in self.span_label_blacklist}
126
+
127
+ if self.use_all_spans:
128
+ all_heads = all_spans
129
+ else:
130
+ all_heads = set(target_head2tail2score) | set(prediction_head2tail2score)
131
+
132
+ all_targets: List[List[float]] = []
133
+ all_predictions: List[List[float]] = []
134
+ for head in all_heads:
135
+ target_tail2score = target_head2tail2score.get(head, {})
136
+ prediction_tail2score = prediction_head2tail2score.get(head, {})
137
+ if self.use_all_spans:
138
+ # use all spans as tails
139
+ tails = set(span for span in all_spans if span != head)
140
+ else:
141
+ # use only the tails that are in the target or prediction
142
+ tails = set(target_tail2score) | set(prediction_tail2score)
143
+ target_scores = [target_tail2score.get(t, self.default_score) for t in tails]
144
+ prediction_scores = [prediction_tail2score.get(t, self.default_score) for t in tails]
145
+ all_targets.append(target_scores)
146
+ all_predictions.append(prediction_scores)
147
+
148
+ self._targets.extend(all_targets)
149
+ self._preds.extend(all_predictions)
150
+
151
+ def do_plot(self):
152
+ raise NotImplementedError()
153
+
154
+ def _compute(self) -> T:
155
+
156
+ if self.plot:
157
+ self.do_plot()
158
+
159
+ result = {}
160
+ for name, metric in self.metrics.items():
161
+ targets, preds = self._targets, self._preds
162
+ if name in self.exclude_singletons:
163
+ targets = [t for t in targets if len(t) > 1]
164
+ preds = [p for p in preds if len(p) > 1]
165
+ num_singletons = len(self._targets) - len(targets)
166
+ logger.warning(
167
+ f"Excluding {num_singletons} singletons (out of {len(self._targets)} "
168
+ f"entries) from {name} metric calculation."
169
+ )
170
+
171
+ if name in self.use_manual_average:
172
+ scores = [
173
+ metric(y_true=[tgts], y_score=[prds]) for tgts, prds in zip(targets, preds)
174
+ ]
175
+ result[name] = sum(scores) / len(scores) if len(scores) > 0 else 0.0
176
+ else:
177
+ result[name] = metric(y_true=targets, y_score=preds)
178
+
179
+ result = to_py_obj(result)
180
+ if self.show_as_markdown:
181
+ import pandas as pd
182
+
183
+ series = pd.Series(result)
184
+ if isinstance(series.index, MultiIndex):
185
+ if len(series.index.levels) > 1:
186
+ # in fact, this is not a series anymore
187
+ series = series.unstack(-1)
188
+ else:
189
+ series.index = series.index.get_level_values(0)
190
+ logger.info(
191
+ f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
192
+ )
193
+ return result
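
A hedged configuration sketch for `RankingMetricsSKLearn`. The sklearn function paths exist as written; the layer name and the decision to average per query are illustrative assumptions (manual averaging avoids passing ragged per-query candidate lists to sklearn in a single call).

```python
# Hedged sketch, not part of the PR.
from src.metrics.ranking_sklearn import RankingMetricsSKLearn

metric = RankingMetricsSKLearn(
    layer="binary_relations",  # assumption: heads act as queries, tails as candidates
    metrics={
        "ndcg": "sklearn.metrics.ndcg_score",
        "lrap": "sklearn.metrics.label_ranking_average_precision_score",
    },
    use_manual_average=["ndcg", "lrap"],  # score each query separately, then average
    exclude_singletons=["ndcg"],          # skip queries with a single candidate
    score_threshold=0.5,
    show_as_markdown=True,
)
```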
src/metrics/score_distribution.py CHANGED
@@ -1,9 +1,12 @@
 
1
  from collections import defaultdict
2
  from typing import Any, Dict, List, Optional, Tuple
3
 
4
  import pandas as pd
5
  from pytorch_ie import Document, DocumentMetric
6
 
 
 
7
 
8
  class ScoreDistribution(DocumentMetric):
9
  """Computes the distribution of prediction scores for annotations in a layer. The scores are
@@ -36,7 +39,8 @@ class ScoreDistribution(DocumentMetric):
36
  plotly_use_create_distplot: bool = True,
37
  plotly_barmode: Optional[str] = None,
38
  plotly_marginal: Optional[str] = "violin",
39
- plotly_font_size: int = 18,
 
40
  plotly_font_family: Optional[str] = None,
41
  plotly_background_color: Optional[str] = None,
42
  ):
@@ -52,7 +56,12 @@ class ScoreDistribution(DocumentMetric):
52
  self.plotly_use_create_distplot = plotly_use_create_distplot
53
  self.plotly_barmode = plotly_barmode
54
  self.plotly_marginal = plotly_marginal
55
- self.plotly_font_size = plotly_font_size
 
 
 
 
 
56
  self.plotly_font_family = plotly_font_family
57
  self.plotly_background_color = plotly_background_color
58
  self.scores: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
@@ -231,7 +240,7 @@ class ScoreDistribution(DocumentMetric):
231
  width=800,
232
  title_text=description,
233
  title_x=0.5,
234
- font=dict(size=self.plotly_font_size),
235
  legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
236
  )
237
  if self.plotly_barmode is not None:
@@ -290,7 +299,7 @@ class ScoreDistribution(DocumentMetric):
290
  width=800,
291
  title_text=f"Mean Binned Scores for {self.mapped_layer}",
292
  title_x=0.5,
293
- font=dict(size=self.plotly_font_size),
294
  )
295
  fig.update_layout(
296
  legend=dict(
 
1
+ import logging
2
  from collections import defaultdict
3
  from typing import Any, Dict, List, Optional, Tuple
4
 
5
  import pandas as pd
6
  from pytorch_ie import Document, DocumentMetric
7
 
8
+ logger = logging.getLogger()
9
+
10
 
11
  class ScoreDistribution(DocumentMetric):
12
  """Computes the distribution of prediction scores for annotations in a layer. The scores are
 
39
  plotly_use_create_distplot: bool = True,
40
  plotly_barmode: Optional[str] = None,
41
  plotly_marginal: Optional[str] = "violin",
42
+ plotly_font: Optional[Dict[str, Any]] = None,
43
+ plotly_font_size: Optional[int] = None,
44
  plotly_font_family: Optional[str] = None,
45
  plotly_background_color: Optional[str] = None,
46
  ):
 
56
  self.plotly_use_create_distplot = plotly_use_create_distplot
57
  self.plotly_barmode = plotly_barmode
58
  self.plotly_marginal = plotly_marginal
59
+ self.plotly_font = plotly_font or {}
60
+ if plotly_font_size is not None:
61
+ logger.warning(
62
+ "Parameter 'plotly_font_size' is deprecated. Use 'plotly_font' with 'size' key instead."
63
+ )
64
+ self.plotly_font["size"] = plotly_font_size
65
  self.plotly_font_family = plotly_font_family
66
  self.plotly_background_color = plotly_background_color
67
  self.scores: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
 
240
  width=800,
241
  title_text=description,
242
  title_x=0.5,
243
+ font=self.plotly_font,
244
  legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
245
  )
246
  if self.plotly_barmode is not None:
 
299
  width=800,
300
  title_text=f"Mean Binned Scores for {self.mapped_layer}",
301
  title_x=0.5,
302
+ font=self.plotly_font,
303
  )
304
  fig.update_layout(
305
  legend=dict(
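
The deprecation shim for `plotly_font_size` simply folds the scalar into the new `plotly_font` dict that is later passed to `update_layout(font=...)`; a small self-contained sketch of the resulting behaviour:

```python
# Illustration of the __init__ logic above (values are made up).
plotly_font = {"family": "Serif"}  # new-style argument
plotly_font_size = 18              # deprecated argument, still accepted with a warning

if plotly_font_size is not None:
    plotly_font["size"] = plotly_font_size  # the scalar overrides/extends the dict

print(plotly_font)  # {'family': 'Serif', 'size': 18}
```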
src/metrics/semantically_same_ranking.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import warnings
3
+ from collections import defaultdict
4
+ from functools import partial
5
+ from typing import Callable, Iterable, List, Optional, Set, Tuple
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ from pytorch_ie import DocumentMetric
10
+ from pytorch_ie.annotations import BinaryRelation
11
+ from sklearn.metrics import average_precision_score, ndcg_score
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ NEG_INF = -1e9 # smaller than any real score
16
+
17
+ # metrics
18
+
19
+
20
+ def true_mrr(y_true: np.ndarray, y_score: np.ndarray, k: int | None = None) -> float:
21
+ """
22
+ Macro MRR over *all* queries.
23
+ • Reciprocal rank is 0 when a query has no relevant item.
24
+ • If k is given, restrict the search to the top-k list.
25
+ """
26
+ if y_true.size == 0:
27
+ return np.nan
28
+
29
+ rr = []
30
+ for t, s in zip(y_true, y_score):
31
+ if t.sum() == 0:
32
+ rr.append(0.0)
33
+ continue
34
+
35
+ order = np.argsort(-s)
36
+ if k is not None:
37
+ order = order[:k]
38
+
39
+ # first position where t == 1, +1 for 1-based rank
40
+ first_hit = np.flatnonzero(t[order] > 0)
41
+ rank = first_hit[0] + 1 if first_hit.size else np.inf
42
+ rr.append(0.0 if np.isinf(rank) else 1.0 / rank)
43
+
44
+ return np.mean(rr)
45
+
46
+
47
+ def macro_ndcg(y_true: np.ndarray, y_score: np.ndarray, k: int | None = None) -> float:
48
+ """
49
+ Macro NDCG@k over all queries.
50
+
51
+ ndcg_score returns 0 when a query has no positives, so no masking is required.
52
+ """
53
+ if y_true.size == 0:
54
+ return np.nan
55
+ return ndcg_score(y_true, y_score, k=k)
56
+
57
+
58
+ def macro_map(y_true: np.ndarray, y_score: np.ndarray) -> float:
59
+ """
60
+ Macro MAP: mean of Average-Precision per query.
61
+ Queries without positives contribute AP = 0.
62
+ """
63
+ if y_true.size == 0:
64
+ return np.nan
65
+
66
+ ap = []
67
+ for t, s in zip(y_true, y_score):
68
+ if t.sum() == 0:
69
+ ap.append(0.0)
70
+ else:
71
+ ap.append(average_precision_score(t, s))
72
+ return np.mean(ap)
73
+
74
+
75
+ def ap_micro(y_true: np.ndarray, y_score: np.ndarray) -> float:
76
+ """
77
+ Micro AP over the entire pool (unchanged).
78
+ """
79
+ with warnings.catch_warnings():
80
+ warnings.filterwarnings("ignore", message="No positive class found in y_true")
81
+ return average_precision_score(y_true.ravel(), y_score.ravel())
82
+
83
+
84
+ # ---------------------------
85
+ # Recall@k
86
+ # ---------------------------
87
+
88
+
89
+ def recall_at_k_micro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
90
+ """
91
+ Micro Recall@k (a.k.a. instance-level recall)
92
+
93
+ – Each *positive instance* counts once, regardless of which query it belongs to.
94
+ – Denominator = total #positives across the whole pool.
95
+ """
96
+ total_pos = y_true.sum()
97
+ if total_pos == 0:
98
+ return np.nan
99
+
100
+ topk = np.argsort(-y_score, axis=1)[:, :k] # indices of top-k per query
101
+ rows = np.arange(topk.shape[0])[:, None]
102
+
103
+ hits = (y_true[rows, topk] > 0).sum() # total #hits (instances)
104
+ return hits / total_pos
105
+
106
+
107
+ def recall_at_k_macro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
108
+ """
109
+ Macro Recall@k (query-level recall)
110
+
111
+ – First compute recall per *query* (#hits / #positives in that query).
112
+ – Then average across all queries that actually contain ≥1 positive.
113
+ """
114
+ mask = y_true.sum(axis=1) > 0 # keep only valid queries
115
+ if not mask.any():
116
+ return np.nan
117
+
118
+ Yt, Ys = y_true[mask], y_score[mask]
119
+ topk = np.argsort(-Ys, axis=1)[:, :k]
120
+ rows = np.arange(Yt.shape[0])[:, None]
121
+
122
+ hits_per_q = (Yt[rows, topk] > 0).sum(axis=1) # shape: (n_queries,)
123
+ pos_per_q = Yt.sum(axis=1)
124
+
125
+ return np.mean(hits_per_q / pos_per_q) # average of query recalls
126
+
127
+
128
+ # ---------------------------
129
+ # Precision@k
130
+ # ---------------------------
131
+
132
+
133
+ def precision_at_k_micro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
134
+ """
135
+ Micro Precision@k (pool-level precision)
136
+
137
+ – Numerator = total #hits across all queries.
138
+ – Denominator = total #predictions considered (n_queries · k).
139
+ """
140
+ if y_true.size == 0:
141
+ return np.nan
142
+
143
+ topk = np.argsort(-y_score, axis=1)[:, :k]
144
+ rows = np.arange(topk.shape[0])[:, None]
145
+
146
+ hits = (y_true[rows, topk] > 0).sum()
147
+ total_pred = y_true.shape[0] * k
148
+ return hits / total_pred
149
+
150
+
151
+ def precision_at_k_macro(y_true: np.ndarray, y_score: np.ndarray, k: int = 5) -> float:
152
+ """
153
+ Macro Precision@k (query-level precision)
154
+
155
+ – Compute precision = (#hits / k) for each query, **including those with zero positives**,
156
+ then average.
157
+ """
158
+ if y_true.size == 0:
159
+ return np.nan
160
+
161
+ topk = np.argsort(-y_score, axis=1)[:, :k]
162
+ rows = np.arange(topk.shape[0])[:, None]
163
+
164
+ rel = y_true[rows, topk] > 0 # shape: (n_queries, k)
165
+ precision_per_q = rel.mean(axis=1) # mean over k positions
166
+ return precision_per_q.mean()
167
+
168
+
169
+ # helper methods
170
+
171
+
172
+ def bootstrap(
173
+ metric_fn: Callable[[np.ndarray, np.ndarray], float],
174
+ y_true: np.ndarray,
175
+ y_score: np.ndarray,
176
+ n: int = 1000,
177
+ rng=None,
178
+ ) -> dict[str, float]:
179
+ rng = np.random.default_rng(rng)
180
+ idx = np.arange(len(y_true))
181
+ vals: list[float] = []
182
+
183
+ while len(vals) < n:
184
+ sample = rng.choice(idx, size=len(idx), replace=True)
185
+ t = y_true[sample]
186
+ s = y_score[sample]
187
+ if t.sum() == 0: # no positive at all → resample
188
+ continue
189
+ vals.append(metric_fn(t, s))
190
+
191
+ result = np.asarray(vals)
192
+ # get 95% confidence interval
193
+ lo, hi = np.percentile(result, [2.5, 97.5])
194
+ return {"mean": result.mean(), "low": lo, "high": hi}
195
+
196
+
197
+ def evaluate_with_ranx(
198
+ pred_rels: set[BinaryRelation],
199
+ target_rels: set[BinaryRelation],
200
+ metrics: list[str],
201
+ include_queries_without_gold: bool = True,
202
+ ) -> dict[str, float]:
203
+
204
+ # lazy import to not require ranx via requirements.txt
205
+ import ranx
206
+
207
+ all_rels = set(pred_rels) | set(target_rels)
208
+ all_heads = {rel.head for rel in all_rels}
209
+ head2id = {head: f"q_{idx}" for idx, head in enumerate(sorted(all_heads))}
210
+ tail_and_label2id = {(ann.tail, ann.label): f"d_{idx}" for idx, ann in enumerate(all_rels)}
211
+
212
+ qrels_dict: dict[str, dict[str, int]] = defaultdict(dict) # {query_id: {doc_id: 1}}
213
+ run_dict: dict[str, dict[str, float]] = defaultdict(dict) # {query_id: {doc_id: score}}
214
+
215
+ for target_rel in target_rels:
216
+ query_id = head2id[target_rel.head]
217
+ doc_id = tail_and_label2id[(target_rel.tail, target_rel.label)]
218
+ if target_rel.score != 1.0:
219
+ raise ValueError(
220
+ f"target score must be 1.0, but got {target_rel.score} for {target_rel}"
221
+ )
222
+ qrels_dict[query_id][doc_id] = 1
223
+
224
+ for pred_rel in pred_rels:
225
+ query_id = head2id[pred_rel.head]
226
+ doc_id = tail_and_label2id[(pred_rel.tail, pred_rel.label)]
227
+ run_dict[query_id][doc_id] = pred_rel.score
228
+
229
+ if include_queries_without_gold:
230
+ # add query ids that have no gold entries to qrels_dict (with empty relevance)
231
+ for query_id in set(head2id.values()) - set(qrels_dict):
232
+ qrels_dict[query_id] = {}
233
+
234
+ # evaluate
235
+ qrels = ranx.Qrels(qrels_dict)
236
+ run = ranx.Run(run_dict)
237
+ results = ranx.evaluate(qrels, run, metrics, make_comparable=True)
238
+ return results
239
+
240
+
241
+ def deduplicate_relations(
242
+ relations: Iterable[BinaryRelation], caption: str
243
+ ) -> Set[BinaryRelation]:
244
+ pred2scores = defaultdict(set)
245
+ for ann in relations:
246
+ pred2scores[ann].add(round(ann.score, 4))
247
+ # warning for duplicates
248
+ preds_with_duplicates = [ann for ann, scores in pred2scores.items() if len(scores) > 1]
249
+ if len(preds_with_duplicates) > 0:
250
+ logger.warning(
251
+ f"there are {len(preds_with_duplicates)} {caption} with duplicates: "
252
+ f"{preds_with_duplicates}. We will take the max score for each annotation."
253
+ )
254
+
255
+ # take the max score for each annotation
256
+ result = {ann.copy(score=max(scores)) for ann, scores in pred2scores.items()}
257
+ return result
258
+
259
+
260
+ def construct_y_true_and_score(
261
+ preds: Iterable[BinaryRelation], targets: Iterable[BinaryRelation]
262
+ ) -> Tuple[np.ndarray, np.ndarray]:
263
+
264
+ # helper constructs
265
+ all_anns = set(preds) | set(targets)
266
+ head2relations = defaultdict(list)
267
+ for ann in all_anns:
268
+ head2relations[ann.head].append(ann)
269
+ target2score = {rel: rel.score for rel in targets}
270
+ pred2score = {rel: rel.score for rel in preds}
271
+
272
+ max_len = max(len(relations) for relations in head2relations.values())
273
+ target_rows, pred_rows = [], []
274
+ for query in head2relations:
275
+ relations = head2relations[query]
276
+ # score for pairs missing from the predictions: NEG_INF ranks them last
277
+ missing_pred_score = NEG_INF  # alternatives considered: 0.0 or a tiny random score
278
+ missing_target_score = 0
279
+ query_scores = [
280
+ (target2score.get(ann, missing_target_score), pred2score.get(ann, missing_pred_score))
281
+ for ann in relations
282
+ ]
283
+
284
+ # sort by descending order of prediction score
285
+ query_scores_sorted = np.array(sorted(query_scores, key=lambda x: x[1], reverse=True))
286
+
287
+ # pad with zeros so every row has the same length
288
+ pad_width = max_len - len(query_scores)
289
+ query_target = np.pad(
290
+ query_scores_sorted[:, 0], (0, pad_width), constant_values=missing_target_score
291
+ )
292
+ query_pred = np.pad(
293
+ query_scores_sorted[:, 1], (0, pad_width), constant_values=missing_pred_score
294
+ )
295
+
296
+ target_rows.append(query_target)
297
+ pred_rows.append(query_pred)
298
+
299
+ y_true = np.vstack(target_rows) # shape (n_queries, max_len)
300
+ y_score = np.vstack(pred_rows)
301
+
302
+ return y_true, y_score
303
+
304
+
305
+ class SemanticallySameRankingMetric(DocumentMetric):
306
+
307
+ def __init__(
308
+ self,
309
+ layer: str,
310
+ label: Optional[str] = None,
311
+ add_reversed: bool = False,
312
+ require_positive_gold: bool = False,
313
+ bootstrap_n: Optional[int] = None,
314
+ k_values: Optional[List[int]] = None,
315
+ return_coverage: bool = True,
316
+ show_as_markdown: bool = False,
317
+ use_ranx: bool = False,
318
+ add_stats_to_result: bool = False,
319
+ ) -> None:
320
+ super().__init__()
321
+ self.layer = layer
322
+ self.label = label
323
+ self.add_reversed = add_reversed
324
+ self.require_positive_gold = require_positive_gold
325
+ self.bootstrap_n = bootstrap_n
326
+ self.k_values = k_values if k_values is not None else [1, 5, 10]
327
+ self.return_coverage = return_coverage
328
+ self.show_as_markdown = show_as_markdown
329
+ self.use_ranx = use_ranx
330
+ self.add_stats_to_result = add_stats_to_result
331
+
332
+ self.metrics = {
333
+ "macro_ndcg": macro_ndcg,
334
+ "macro_mrr": true_mrr,
335
+ "macro_map": macro_map,
336
+ "micro_ap": ap_micro,
337
+ }
338
+ for name, func in [
339
+ ("macro_ndcg", macro_ndcg),
340
+ ("micro_recall", recall_at_k_micro),
341
+ ("micro_precision", precision_at_k_micro),
342
+ ("macro_recall", recall_at_k_macro),
343
+ ("macro_precision", precision_at_k_macro),
344
+ ]:
345
+ for k in self.k_values:
346
+ self.metrics[f"{name}@{k}"] = partial(func, k=k) # type: ignore
347
+
348
+ self.ranx_metrics = ["map", "mrr", "ndcg"]
349
+ for name in ["recall", "precision", "ndcg"]:
350
+ for k in self.k_values:
351
+ self.ranx_metrics.append(f"{name}@{k}")
352
+
353
+ def reset(self) -> None:
354
+ """
355
+ Reset the metric to its initial state.
356
+ """
357
+ self._preds: List[BinaryRelation] = []
358
+ self._targets: List[BinaryRelation] = []
359
+
360
+ def _update(self, document):
361
+ layer = document[self.layer]
362
+ ann: BinaryRelation
363
+ for ann in layer:
364
+ if self.label is None or ann.label == self.label:
365
+ if ann.score > 0.0:
366
+ self._targets.append(ann.copy())
367
+ if self.add_reversed:
368
+ self._targets.append(ann.copy(head=ann.tail, tail=ann.head))
369
+
370
+ for ann in layer.predictions:
371
+ if self.label is None or ann.label == self.label:
372
+ if ann.score > 0.0:
373
+ self._preds.append(ann.copy())
374
+ if self.add_reversed:
375
+ self._preds.append(ann.copy(head=ann.tail, tail=ann.head))
376
+
377
+ def _compute(self):
378
+ # take the max score for each annotation
379
+ preds_deduplicated = deduplicate_relations(self._preds, "predictions")
380
+ targets_deduplicated = deduplicate_relations(self._targets, "targets")
381
+
382
+ stats = {
383
+ "gold": len(targets_deduplicated),
384
+ "preds": len(preds_deduplicated),
385
+ "queries": len(
386
+ set(ann.head for ann in targets_deduplicated)
387
+ | set(ann.head for ann in preds_deduplicated)
388
+ ),
389
+ }
390
+
391
+ if self.use_ranx:
392
+ if self.bootstrap_n is not None:
393
+ raise ValueError(
394
+ "Ranx does not support bootstrapping. Please set bootstrap_n=None."
395
+ )
396
+
397
+ scores = evaluate_with_ranx(
398
+ preds_deduplicated,
399
+ targets_deduplicated,
400
+ metrics=self.ranx_metrics,
401
+ include_queries_without_gold=not self.require_positive_gold,
402
+ )
403
+ if self.add_stats_to_result:
404
+ scores.update(stats)
405
+ # logger.info(f"results via ranx:\n{pd.Series(ranx_result).sort_index().round(3).to_markdown()}")
406
+ df = pd.DataFrame.from_records([scores], index=["score"])
407
+ else:
408
+
409
+ y_true, y_score = construct_y_true_and_score(
410
+ preds=preds_deduplicated, targets=targets_deduplicated
411
+ )
412
+
413
+ # original definition ─ share of queries with ≥1 positive
414
+ coverage = (y_true.sum(axis=1) > 0).mean()
415
+
416
+ # keep only queries that actually have at least one gold positive
417
+ if self.require_positive_gold:
418
+ mask = y_true.sum(axis=1) > 0 # shape: (n_queries,)
419
+ y_true = y_true[mask]
420
+ y_score = y_score[mask]
421
+
422
+ if self.bootstrap_n is not None:
423
+ scores = {
424
+ name: bootstrap(fn, y_true, y_score, n=self.bootstrap_n)
425
+ for name, fn in self.metrics.items()
426
+ }
427
+ if self.add_stats_to_result:
428
+ scores["stats"] = stats
429
+ df = pd.DataFrame(scores)
430
+ else:
431
+ scores = {name: fn(y_true, y_score) for name, fn in self.metrics.items()}
432
+ if self.add_stats_to_result:
433
+ scores.update(stats)
434
+ df = pd.DataFrame.from_records([scores], index=["score"])
435
+
436
+ if self.return_coverage and not self.use_ranx:  # coverage is only computed in the non-ranx branch
437
+ scores["coverage"] = coverage
438
+
439
+ if self.show_as_markdown:
440
+ if not self.add_stats_to_result:
441
+ logger.info(
443
+ f'\nstatistics ({self.layer}):\n{pd.Series(stats, name="value").to_markdown()}'
444
+ )
446
+ logger.info(f"\n{self.layer}:\n{df.round(4).T.to_markdown()}")
447
+
448
+ return scores
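
A self-contained toy example contrasting the micro and macro Recall@k helpers defined above (the numbers are invented for illustration; micro divides total hits by total positives, macro averages per-query recalls):

```python
import numpy as np

from src.metrics.semantically_same_ranking import recall_at_k_macro, recall_at_k_micro

# 2 queries with 4 candidates each; 1 marks a relevant candidate
y_true = np.array([[1, 1, 0, 0],
                   [0, 0, 0, 1]])
y_score = np.array([[0.9, 0.2, 0.8, 0.1],   # relevant items ranked 1st and 4th
                    [0.7, 0.6, 0.5, 0.4]])  # the only relevant item ranked last

print(recall_at_k_micro(y_true, y_score, k=2))  # 1 hit / 3 positives ≈ 0.333
print(recall_at_k_macro(y_true, y_score, k=2))  # (1/2 + 0/1) / 2 = 0.25
```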
src/metrics/tpfpfn.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections import defaultdict
3
+ from functools import partial
4
+ from typing import (
5
+ Any,
6
+ Callable,
7
+ Collection,
8
+ Dict,
9
+ Hashable,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ TypeAlias,
14
+ Union,
15
+ )
16
+
17
+ from pytorch_ie.core import Annotation, Document, DocumentMetric
18
+ from pytorch_ie.utils.hydra import resolve_target
19
+
20
+ from src.document.types import RelatedRelation
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def has_one_of_the_labels(ann: Annotation, label_field: str, labels: Collection[str]) -> bool:
26
+ return getattr(ann, label_field) in labels
27
+
28
+
29
+ def has_this_label(ann: Annotation, label_field: str, label: str) -> bool:
30
+ return getattr(ann, label_field) == label
31
+
32
+
33
+ InstanceType: TypeAlias = Tuple[Document, Annotation]
34
+ InstancesType: TypeAlias = Tuple[List[InstanceType], List[InstanceType], List[InstanceType]]
35
+
36
+
37
+ class TPFFPFNMetric(DocumentMetric):
38
+ """Computes the lists of True Positive, False Positive, and False Negative
39
+ annotations for a given layer. If labels are provided, it also computes
40
+ the counts for each label separately.
41
+
42
+ Works only with `RelatedRelation` annotations for now.
43
+
44
+ Args:
45
+ layer: The layer to compute the metrics for.
46
+ labels: If provided, calculate metrics for each label.
47
+ label_field: The field to use for the label. Defaults to "label".
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ layer: str,
53
+ labels: Optional[Union[Collection[str], str]] = None,
54
+ label_field: str = "label",
55
+ annotation_processor: Optional[Union[Callable[[Annotation], Hashable], str]] = None,
56
+ ):
57
+ super().__init__()
58
+ self.layer = layer
59
+ self.label_field = label_field
60
+ self.annotation_processor: Optional[Callable[[Annotation], Hashable]]
61
+ if isinstance(annotation_processor, str):
62
+ self.annotation_processor = resolve_target(annotation_processor)
63
+ else:
64
+ self.annotation_processor = annotation_processor
65
+
66
+ self.per_label = labels is not None
67
+ self.infer_labels = False
68
+ if self.per_label:
69
+ if isinstance(labels, str):
70
+ if labels != "INFERRED":
71
+ raise ValueError(
72
+ "labels can only be 'INFERRED' if per_label is True and labels is a string"
73
+ )
74
+ self.labels = []
75
+ self.infer_labels = True
76
+ elif isinstance(labels, Collection):
77
+ if not all(isinstance(label, str) for label in labels):
78
+ raise ValueError("labels must be a collection of strings")
79
+ if "MICRO" in labels or "MACRO" in labels:
80
+ raise ValueError(
81
+ "labels cannot contain 'MICRO' or 'MACRO' because they are used to capture aggregated metrics"
82
+ )
83
+ if len(labels) == 0:
84
+ raise ValueError("labels cannot be empty")
85
+ self.labels = list(labels)
86
+ else:
87
+ raise ValueError("labels must be a string or a collection of strings")
88
+
89
+ def reset(self):
90
+ self.tp_fp_fn = defaultdict(lambda: (list(), list(), list()))
91
+
92
+ def get_tp_fp_fn(
93
+ self,
94
+ document: Document,
95
+ annotation_filter: Optional[Callable[[Annotation], bool]] = None,
96
+ annotation_processor: Optional[Callable[[Annotation], Hashable]] = None,
97
+ ) -> InstancesType:
98
+ annotation_processor = annotation_processor or (lambda ann: ann)
99
+ annotation_filter = annotation_filter or (lambda ann: True)
100
+ predicted_annotations = {
101
+ annotation_processor(ann)
102
+ for ann in document[self.layer].predictions
103
+ if annotation_filter(ann)
104
+ }
105
+ gold_annotations = {
106
+ annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
107
+ }
108
+ tp = [(document, ann) for ann in predicted_annotations & gold_annotations]
109
+ fn = [(document, ann) for ann in gold_annotations - predicted_annotations]
110
+ fp = [(document, ann) for ann in predicted_annotations - gold_annotations]
111
+ return tp, fp, fn
112
+
113
+ def add_annotations(self, annotations: InstancesType, label: str):
114
+ self.tp_fp_fn[label] = (
115
+ self.tp_fp_fn[label][0] + annotations[0],
116
+ self.tp_fp_fn[label][1] + annotations[1],
117
+ self.tp_fp_fn[label][2] + annotations[2],
118
+ )
119
+
120
+ def _update(self, document: Document):
121
+ new_tp_fp_fn = self.get_tp_fp_fn(
122
+ document=document,
123
+ annotation_filter=(
124
+ partial(has_one_of_the_labels, label_field=self.label_field, labels=self.labels)
125
+ if self.per_label and not self.infer_labels
126
+ else None
127
+ ),
128
+ annotation_processor=self.annotation_processor,
129
+ )
130
+ self.add_annotations(new_tp_fp_fn, label="MICRO")
131
+ if self.infer_labels:
132
+ layer = document[self.layer]
133
+ # collect labels from gold data and predictions
134
+ for ann in list(layer) + list(layer.predictions):
135
+ label = getattr(ann, self.label_field)
136
+ if label not in self.labels:
137
+ self.labels.append(label)
138
+ if self.per_label:
139
+ for label in self.labels:
140
+ new_tp_fp_fn = self.get_tp_fp_fn(
141
+ document=document,
142
+ annotation_filter=partial(
143
+ has_this_label, label_field=self.label_field, label=label
144
+ ),
145
+ annotation_processor=self.annotation_processor,
146
+ )
147
+ self.add_annotations(new_tp_fp_fn, label=label)
148
+
149
+ def format_texts(self, texts: List[str]) -> str:
150
+ return "<SEP>".join(texts)
151
+
152
+ def format_annotation(self, ann: Annotation) -> Dict[str, Any]:
153
+ if isinstance(ann, RelatedRelation):
154
+
155
+ head_resolved = ann.head.resolve()
156
+ tail_resolved = ann.tail.resolve()
157
+ ref_resolved = ann.reference_span.resolve()
158
+ return {
159
+ "related_label": ann.label,
160
+ "related_score": round(ann.score, 3),
161
+ "query_label": head_resolved[0],
162
+ "query_texts": self.format_texts(head_resolved[1]),
163
+ "query_score": round(ann.head.score, 3),
164
+ "ref_label": ref_resolved[0],
165
+ "ref_texts": self.format_texts(ref_resolved[1]),
166
+ "ref_score": round(ann.reference_span.score, 3),
167
+ "rec_label": tail_resolved[0],
168
+ "rec_texts": self.format_texts(tail_resolved[1]),
169
+ "rec_score": round(ann.tail.score, 3),
170
+ }
171
+ else:
172
+ raise NotImplementedError
173
+ # return ann.resolve()
174
+
175
+ def format_instance(self, instance: InstanceType) -> Dict[str, Any]:
176
+ document, annotation = instance
177
+ result = self.format_annotation(annotation)
178
+ if getattr(document, "id", None) is not None:
179
+ result["document_id"] = document.id
180
+ return result
181
+
182
+ def _compute(self) -> Dict[str, Dict[str, list]]:
183
+ res = dict()
184
+ for k, instances in self.tp_fp_fn.items():
185
+ res[k] = {
186
+ "tp": [self.format_instance(instance) for instance in instances[0]],
187
+ "fp": [self.format_instance(instance) for instance in instances[1]],
188
+ "fn": [self.format_instance(instance) for instance in instances[2]],
189
+ }
190
+
191
+ # if self.show_as_markdown:
192
+ # logger.info(f"\n{self.layer}:\n{pd.DataFrame(res).round(3).T.to_markdown()}")
193
+ return res
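
A hedged usage sketch for the new metric: it is meant for error analysis, so the computed TP/FP/FN instances can be dumped for inspection. The layer name is an assumption (the layer must contain `RelatedRelation` annotations), and `labels="INFERRED"` lets the metric collect labels from both gold and predicted annotations.

```python
# Hedged sketch, not part of the PR.
import json

from src.metrics.tpfpfn import TPFFPFNMetric

metric = TPFFPFNMetric(layer="related_relations", labels="INFERRED")
for doc in documents:  # assumption: iterable of documents with predictions
    metric(doc)

with open("tp_fp_fn.json", "w") as f:
    json.dump(metric.compute(), f, indent=2)
```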
src/models/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
  from .sequence_classification import SimpleSequenceClassificationModelWithInputTypeIds
2
  from .sequence_classification_with_pooler import (
 
3
  SequencePairSimilarityModelWithMaxCosineSim,
4
- SequencePairSimilarityModelWithPooler2,
5
  SequencePairSimilarityModelWithPoolerAndAdapter,
6
  )
 
1
  from .sequence_classification import SimpleSequenceClassificationModelWithInputTypeIds
2
  from .sequence_classification_with_pooler import (
3
+ SequencePairSimilarityModelDummy,
4
  SequencePairSimilarityModelWithMaxCosineSim,
5
+ SequencePairSimilarityModelWithMaxCosineSimAndAdapter,
6
  SequencePairSimilarityModelWithPoolerAndAdapter,
7
  )
src/models/sequence_classification_with_pooler.py CHANGED
@@ -1,12 +1,11 @@
 import abc
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, List, Optional
 
 import torch
 import torch.nn.functional as F
 from adapters import AutoAdapterModel
 from pie_modules.models import SequencePairSimilarityModelWithPooler
-from pie_modules.models.components.pooler import MENTION_POOLING
 from pie_modules.models.sequence_classification_with_pooler import (
     InputType,
     OutputType,
@@ -20,31 +19,11 @@ from torch import FloatTensor, Tensor
 from transformers import AutoConfig, PreTrainedModel
 from transformers.modeling_outputs import SequenceClassifierOutput
 
-from src.models.components.pooler import SpanMeanPooler
-
 logger = logging.getLogger(__name__)
 
 
-class SequenceClassificationModelWithPoolerBase2(
-    SequenceClassificationModelWithPoolerBase, abc.ABC
-):
-    def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
-        aggregate = self.pooler_config.get("aggregate", "max")
-        if self.pooler_config["type"] == MENTION_POOLING and aggregate != "max":
-            if aggregate == "mean":
-                pooler_config = dict(self.pooler_config)
-                pooler_config.pop("type")
-                pooler_config.pop("aggregate")
-                pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config)
-                return pooler, pooler.output_dim
-            else:
-                raise ValueError(f"Unknown aggregation method: {aggregate}")
-        else:
-            return super().setup_pooler(input_dim)
-
-
 class SequenceClassificationModelWithPoolerAndAdapterBase(
-    SequenceClassificationModelWithPoolerBase2, abc.ABC
+    SequenceClassificationModelWithPoolerBase, abc.ABC
 ):
     def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
         self.adapter_name_or_path = adapter_name_or_path
@@ -66,13 +45,6 @@ class SequenceClassificationModelWithPoolerAndAdapterBase(
         return model
 
 
-@PyTorchIEModel.register()
-class SequencePairSimilarityModelWithPooler2(
-    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2
-):
-    pass
-
-
 @PyTorchIEModel.register()
 class SequencePairSimilarityModelWithPoolerAndAdapter(
     SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
@@ -164,3 +136,66 @@ class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
     SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
 ):
     pass
+
+
+@PyTorchIEModel.register()
+class SequencePairSimilarityModelDummy(SequencePairSimilarityModelWithPooler):
+
+    def __init__(
+        self,
+        method: str = "random",
+        random_seed: Optional[int] = None,
+        **kwargs,
+    ):
+        self.method = method
+        self.random_seed = random_seed
+        super().__init__(**kwargs)
+
+    def setup_classifier(
+        self, pooler_output_dim: int
+    ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
+        if self.method == "random":
+            generator = torch.Generator(device=self.device)
+            if self.random_seed is not None:
+                generator = generator.manual_seed(self.random_seed)
+
+            def binary_classify_random(
+                inputs: torch.FloatTensor,
+                inputs_pair: torch.FloatTensor,
+            ) -> torch.FloatTensor:
+                """Randomly classifies pairs of inputs as similar or not similar."""
+                # Generate random logits in the range of [0, 1]
+                logits = torch.rand(inputs.size(0), device=self.device, generator=generator)
+                return logits
+
+            return binary_classify_random
+        elif self.method == "zero":
+
+            def binary_classify_zero(
+                inputs: torch.FloatTensor,
+                inputs_pair: torch.FloatTensor,
+            ) -> torch.FloatTensor:
+                """Classifies pairs of inputs as not similar (logit = 0)."""
+                # Return a tensor of zeros with the same batch size
+                logits = torch.zeros(inputs.size(0), device=self.device)
+                return logits
+
+            return binary_classify_zero
+        else:
+            raise ValueError(
+                f"Unknown method: {self.method}. Supported methods are 'random' and 'zero'."
+            )
+
+    def setup_loss_fct(self) -> Callable:
+        def loss_fct(logits: FloatTensor, labels: FloatTensor) -> FloatTensor:
+            raise NotImplementedError(
+                "Dummy model does not support loss function, as it is not used for training."
+            )
+
+        return loss_fct
+
+    def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor:
+        # Just return a tensor of zeros in the shape of the batch size
+        # so that the classifier can construct dummy logits in the correct shape.
+        bs = pooler_inputs["start_indices"].size(0)
+        return torch.zeros(bs, device=self.device)
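
A minimal instantiation sketch for the dummy baseline added above (illustrative, not part of this commit; the encoder keyword is an assumption inherited from SequencePairSimilarityModelWithPooler):

from src.models import SequencePairSimilarityModelDummy

# evaluation-only baseline: emits random similarity logits in [0, 1)
model = SequencePairSimilarityModelDummy(
    model_name_or_path="bert-base-uncased",  # assumed keyword of the parent class
    method="random",
    random_seed=42,
)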
src/predict.py CHANGED
@@ -113,8 +113,8 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
         .to(dtype=pipeline.model.dtype)
     )
 
-    # auto-convert the dataset if the metric specifies a document type
-    dataset = pipeline.taskmodule.convert_dataset(dataset)
+    # auto-convert the dataset if the taskmodule specifies a document type
+    dataset = dataset.to_document_type(pipeline.taskmodule, downcast=False)
 
     # Init the serializer
     serializer: Optional[DocumentSerializer] = None
src/serializer/__init__.py CHANGED
@@ -1 +1,4 @@
-from .json import JsonSerializer, JsonSerializer2
+from .json import JsonSerializer
+
+# backward compatibility
+JsonSerializer2 = JsonSerializer
src/serializer/interface.py CHANGED
@@ -12,5 +12,4 @@ class DocumentSerializer(ABC):
     """
 
     @abstractmethod
-    def __call__(self, documents: Iterable[Document]) -> Any:
-        pass
+    def __call__(self, documents: Iterable[Document], append: bool = False, **kwargs) -> Any: ...
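
For reference, a minimal custom serializer against this extended interface (illustrative only, not part of the commit; the class name and the returned metadata dict are made up):

from typing import Any, Iterable

from pytorch_ie.core import Document

from src.serializer.interface import DocumentSerializer


class CountingSerializer(DocumentSerializer):
    """Illustrative only: counts documents instead of writing them to disk."""

    def __call__(self, documents: Iterable[Document], append: bool = False, **kwargs) -> Any:
        docs = list(documents)
        # a real serializer would persist `docs` here and honor `append`
        return {"num_documents": len(docs), "append": append}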
 
src/serializer/json.py CHANGED
@@ -1,11 +1,7 @@
-import json
-import os
-from typing import Dict, Iterable, List, Optional, Sequence, Type, TypeVar
+from typing import Dict, Iterable, Optional, Sequence, Type, TypeVar
 
 from pie_datasets import Dataset, DatasetDict, IterableDataset
-from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
 from pytorch_ie.core import Document
-from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type
 
 from src.serializer.interface import DocumentSerializer
 from src.utils.logging_utils import get_pylogger
@@ -28,125 +24,13 @@ class JsonSerializer(DocumentSerializer):
     def __init__(self, **kwargs):
         self.default_kwargs = kwargs
 
-    @classmethod
-    def write(
-        cls,
-        documents: Iterable[Document],
-        path: str,
-        file_name: str = "documents.jsonl",
-        metadata_file_name: str = METADATA_FILE_NAME,
-        split: Optional[str] = None,
-        **kwargs,
-    ) -> Dict[str, str]:
-        realpath = os.path.realpath(path)
-        log.info(f'serialize documents to "{realpath}" ...')
-        os.makedirs(realpath, exist_ok=True)
-
-        if not isinstance(documents, Sequence):
-            documents = list(documents)
-
-        # dump metadata including the document_type
-        if len(documents) == 0:
-            raise Exception("cannot serialize empty list of documents")
-        document_type = type(documents[0])
-        metadata = {"document_type": serialize_document_type(document_type)}
-        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
-        if os.path.exists(full_metadata_file_name):
-            # load previous metadata
-            with open(full_metadata_file_name) as f:
-                previous_metadata = json.load(f)
-            if previous_metadata != metadata:
-                raise ValueError(
-                    f"metadata file {full_metadata_file_name} already exists, "
-                    "but the content does not match the current metadata"
-                    "\nprevious metadata: {previous_metadata}"
-                    "\ncurrent metadata: {metadata}"
-                )
-        else:
-            with open(full_metadata_file_name, "w") as f:
-                json.dump(metadata, f, indent=2)
-
-        if split is not None:
-            realpath = os.path.join(realpath, split)
-            os.makedirs(realpath, exist_ok=True)
-        full_file_name = os.path.join(realpath, file_name)
-        if as_json_lines(file_name):
-            # if the file already exists, append to it
-            mode = "a" if os.path.exists(full_file_name) else "w"
-            with open(full_file_name, mode) as f:
-                for doc in documents:
-                    f.write(json.dumps(doc.asdict(), **kwargs) + "\n")
-        else:
-            docs_list = [doc.asdict() for doc in documents]
-            if os.path.exists(full_file_name):
-                # load previous documents
-                with open(full_file_name) as f:
-                    previous_doc_list = json.load(f)
-                docs_list = previous_doc_list + docs_list
-            with open(full_file_name, "w") as f:
-                json.dump(docs_list, fp=f, **kwargs)
-        return {"path": realpath, "file_name": file_name, "metadata_file_name": metadata_file_name}
-
-    @classmethod
-    def read(
-        cls,
-        path: str,
-        document_type: Optional[Type[D]] = None,
-        file_name: str = "documents.jsonl",
-        metadata_file_name: str = METADATA_FILE_NAME,
-        split: Optional[str] = None,
-    ) -> List[D]:
-        realpath = os.path.realpath(path)
-        log.info(f'load documents from "{realpath}" ...')
-
-        # try to load metadata including the document_type
-        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
-        if os.path.exists(full_metadata_file_name):
-            with open(full_metadata_file_name) as f:
-                metadata = json.load(f)
-            document_type = resolve_optional_document_type(metadata.get("document_type"))
-
-        if document_type is None:
-            raise Exception("document_type is required to load serialized documents")
-
-        if split is not None:
-            realpath = os.path.join(realpath, split)
-        full_file_name = os.path.join(realpath, file_name)
-        documents = []
-        if as_json_lines(str(file_name)):
-            with open(full_file_name) as f:
-                for line in f:
-                    json_dict = json.loads(line)
-                    documents.append(document_type.fromdict(json_dict))
-        else:
-            with open(full_file_name) as f:
-                json_list = json.load(f)
-            for json_dict in json_list:
-                documents.append(document_type.fromdict(json_dict))
-        return documents
-
-    def read_with_defaults(self, **kwargs) -> List[D]:
-        all_kwargs = {**self.default_kwargs, **kwargs}
-        return self.read(**all_kwargs)
-
-    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
-        all_kwargs = {**self.default_kwargs, **kwargs}
-        return self.write(**all_kwargs)
-
-    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
-        return self.write_with_defaults(documents=documents, **kwargs)
-
-
-class JsonSerializer2(DocumentSerializer):
-    def __init__(self, **kwargs):
-        self.default_kwargs = kwargs
-
     @classmethod
     def write(
         cls,
         documents: Iterable[Document],
         path: str,
         split: str = "train",
+        append: bool = False,
     ) -> Dict[str, str]:
         if not isinstance(documents, (Dataset, IterableDataset)):
             if not isinstance(documents, Sequence):
@@ -154,7 +38,7 @@ class JsonSerializer2(DocumentSerializer):
             else:
                 documents = Dataset.from_documents(documents)
         dataset_dict = DatasetDict({split: documents})
-        dataset_dict.to_json(path=path)
+        dataset_dict.to_json(path=path, mode="a" if append else "w")
         return {"path": path, "split": split}
 
     @classmethod
@@ -181,5 +65,7 @@ class JsonSerializer2(DocumentSerializer):
         all_kwargs = {**self.default_kwargs, **kwargs}
         return self.write(**all_kwargs)
 
-    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
-        return self.write_with_defaults(documents=documents, **kwargs)
+    def __call__(
+        self, documents: Iterable[Document], append: bool = False, **kwargs
+    ) -> Dict[str, str]:
+        return self.write_with_defaults(documents=documents, append=append, **kwargs)
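
Usage sketch for the new append flag (illustrative, not part of this commit; constructor arguments, the batch variables, and the on-disk layout produced by DatasetDict.to_json are assumptions):

from src.serializer import JsonSerializer

serializer = JsonSerializer(path="predictions/out", split="test")
serializer(first_batch)                 # first call overwrites the split on disk
serializer(second_batch, append=True)   # later calls append to the same split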
src/start_demo.py CHANGED
@@ -331,8 +331,9 @@ def main(cfg: DictConfig) -> None:
                 visible=pdf_fulltext_extractor is not None,
             )
 
-            enable_acl_venue_loading = pdf_fulltext_extractor is not None and cfg.get(
-                "acl_anthology_pdf_dir"
+            enable_acl_venue_loading = (
+                pdf_fulltext_extractor is not None
+                and cfg.get("acl_anthology_data_dir") is not None
             )
             acl_anthology_venues = gr.Textbox(
                 label="ACL Anthology Venues",
src/train.py CHANGED
@@ -45,7 +45,7 @@ from pie_modules.models import SimpleGenerativeModel
 from pie_modules.models.interface import RequiresTaskmoduleConfig
 from pie_modules.taskmodules import *  # noqa: F403
 from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
-from pytorch_ie import Pipeline
+from pytorch_ie import PieDataModule, Pipeline
 from pytorch_ie.core import PyTorchIEModel, TaskModule
 from pytorch_ie.models import *  # noqa: F403
 from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
@@ -55,7 +55,6 @@ from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import Logger
 
 from src import utils
-from src.datamodules import PieDataModule
 from src.models import *  # noqa: F403
 from src.serializer.interface import DocumentSerializer
 from src.taskmodules import *  # noqa: F403
@@ -135,7 +134,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     )
 
     # auto-convert the dataset if the taskmodule specifies a document type
-    dataset = taskmodule.convert_dataset(dataset)
+    dataset = dataset.to_document_type(taskmodule, downcast=False)
 
     # Init pytorch-ie datamodule
     log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
src/utils/graph_utils.py ADDED
@@ -0,0 +1,47 @@
+from typing import Hashable, List, Optional, Sequence, TypeVar
+
+from pytorch_ie.annotations import BinaryRelation
+
+H = TypeVar("H", bound=Hashable)
+
+
+def get_connected_components(
+    relations: Sequence[BinaryRelation],
+    elements: Optional[Sequence[H]] = None,
+    link_relation_label: Optional[str] = None,
+    link_relation_relation_score_threshold: Optional[float] = None,
+    add_singletons: bool = False,
+) -> List[List[H]]:
+    try:
+        import networkx as nx
+    except ImportError:
+        raise ImportError(
+            "NetworkX must be installed to use the SpansViaRelationMerger. "
+            "You can install NetworkX with `pip install networkx`."
+        )
+
+    # convert list of relations to a graph to easily calculate connected components to merge
+    g = nx.Graph()
+    link_relations = []
+    other_relations = []
+    elem2edge_relation = {}
+    for rel in relations:
+        if (link_relation_label is None or rel.label == link_relation_label) and (
+            link_relation_relation_score_threshold is None
+            or rel.score >= link_relation_relation_score_threshold
+        ):
+            link_relations.append(rel)
+            g.add_edge(rel.head, rel.tail)
+            elem2edge_relation[rel.head] = rel
+            elem2edge_relation[rel.tail] = rel
+        else:
+            other_relations.append(rel)
+
+    if add_singletons:
+        if elements is None:
+            raise ValueError("elements must be provided if add_singletons is True")
+        # add singletons to the graph
+        for elem in elements:
+            if elem not in elem2edge_relation:
+                g.add_node(elem)
+    return list(nx.connected_components(g))
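
Usage sketch for get_connected_components (illustrative, not part of this commit; the layer names and the "link" label are assumptions about the calling code):

# spans connected transitively via "link" relations end up in one component;
# add_singletons=True also returns unlinked spans as single-element clusters.
components = get_connected_components(
    relations=document["binary_relations"].predictions,
    elements=document["labeled_spans"],
    link_relation_label="link",
    link_relation_relation_score_threshold=0.5,
    add_singletons=True,
)
for cluster in components:
    print([str(span) for span in cluster])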
src/utils/inference_utils.py CHANGED
@@ -50,6 +50,8 @@ def predict_and_serialize(
         batch_iter = [dataset]
     else:
         batch_iter = document_batch_iter(dataset=dataset, batch_size=document_batch_size)
+
+    append = False
     for docs_batch in batch_iter:
        if pipeline is not None:
            t_start = timeit.default_timer()
@@ -60,13 +62,14 @@
         if serializer is not None:
             # the serializer should not return the serialized documents, but write them to disk
             # and instead return some metadata such as the path to the serialized documents
-            serializer_result = serializer(docs_batch)
+            serializer_result = serializer(docs_batch, append=append)
             if "serializer" in result and result["serializer"] != serializer_result:
                 log.warning(
                     f"serializer result changed from {result['serializer']} to {serializer_result}"
                     " during prediction. Only the last result is returned."
                 )
             result["serializer"] = serializer_result
+            append = True
 
     if prediction_time is not None:
         result["prediction_time"] = prediction_time
src/utils/pdf_utils/process_pdf.py CHANGED
@@ -138,7 +138,7 @@ def process_pdf_file(
     os.makedirs(output_dir, exist_ok=True)
 
     # get paper id as the name of the file
-    paper_id = ".".join(input_file.split("/")[-1].split(".")[:-1])
+    paper_id = os.path.splitext(os.path.basename(input_file))[0]
     tei_file = os.path.join(temp_dir, f"{paper_id}.tei.xml")
     output_file = os.path.join(output_dir, f"{paper_id}.json")
 