kpfadnis commited on
Commit
5db7074
·
1 Parent(s): 988e116

feat (aggregator): Added support for median aggregator.

Browse files

Signed-off-by: Kshitij Fadnis <[email protected]>

src/utilities/aggregators.ts CHANGED
@@ -24,11 +24,11 @@ import {
24
  AggregationStatistics,
25
  MetricValue,
26
  } from '@/src/types';
27
- import { castToNumber } from '@/src/utilities/metrics';
28
 
29
- export const averageAggregator: Aggregator = {
30
- name: 'averagae',
31
- displayName: 'Average',
32
  apply: (
33
  scores: number[] | string[],
34
  references: MetricValue[],
@@ -64,6 +64,50 @@ export const averageAggregator: Aggregator = {
64
  },
65
  };
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  export const majorityAggregator: Aggregator = {
68
  name: 'majority',
69
  displayName: 'Majority',
 
24
  AggregationStatistics,
25
  MetricValue,
26
  } from '@/src/types';
27
+ import { castToNumber, castToValue } from '@/src/utilities/metrics';
28
 
29
+ export const meanAggregator: Aggregator = {
30
+ name: 'mean',
31
+ displayName: 'Mean',
32
  apply: (
33
  scores: number[] | string[],
34
  references: MetricValue[],
 
64
  },
65
  };
66
 
67
+ export const medianAggregator: Aggregator = {
68
+ name: 'median',
69
+ displayName: 'Median',
70
+ apply: (
71
+ scores: number[] | string[],
72
+ references: MetricValue[],
73
+ ): AggregationStatistics => {
74
+ // Step 1: Cast score to numbers
75
+ const numericScores = scores.map((score) =>
76
+ typeof score === 'string' ? castToNumber(score, references) : score,
77
+ );
78
+
79
+ // Step 2: Sort the numeric scores
80
+ const sortedNumericScores = numericScores.toSorted();
81
+
82
+ // Step 3: Calculate aggregate value & standard deviation
83
+ const median =
84
+ sortedNumericScores.length % 2 == 0
85
+ ? sortedNumericScores[sortedNumericScores.length / 2]
86
+ : sortedNumericScores[(sortedNumericScores.length + 1) / 2];
87
+ const std = Math.sqrt(
88
+ sortedNumericScores
89
+ .map((score) => Math.pow(score - median, 2))
90
+ .reduce((a, b) => a + b) / sortedNumericScores.length,
91
+ );
92
+
93
+ // Step 4: Calculate confidence level
94
+ const sorted_counter = Object.entries(countBy(scores));
95
+ const numberOfUniqueValues = sorted_counter.length;
96
+ const mostCommonValueCount = sorted_counter[0][1];
97
+
98
+ return {
99
+ value: castToValue(median, references),
100
+ std: Math.round((std + Number.EPSILON) * 100) / 100,
101
+ confidence:
102
+ mostCommonValueCount === scores.length
103
+ ? AggregationConfidenceLevels.HIGH
104
+ : numberOfUniqueValues === scores.length
105
+ ? AggregationConfidenceLevels.LOW
106
+ : AggregationConfidenceLevels.MEDIUM,
107
+ };
108
+ },
109
+ };
110
+
111
  export const majorityAggregator: Aggregator = {
112
  name: 'majority',
113
  displayName: 'Majority',
src/utilities/metrics.ts CHANGED
@@ -77,6 +77,33 @@ export function extractMetricDisplayName(metric: Metric): string {
77
  : metric.name.charAt(0).toUpperCase() + metric.name.slice(1).toLowerCase();
78
  }
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  export function castToNumber(
81
  value: string | number,
82
  references?: MetricValue[],
@@ -116,61 +143,121 @@ export function castToNumber(
116
  }
117
 
118
  /**
119
- * Compute average value
120
- * @param counter distribution of values
121
- * @param numberOfAnnotators number of annotators
122
  * @returns
123
  */
124
- function computeAverage(
125
  metric: Metric,
126
- counter: { [key: string]: number },
127
- numberOfAnnotators: number,
128
  ): { level: number; value: number | string } {
129
- // Step 0: Sort counter values
 
 
 
130
  const sorted_counter = Object.entries(counter);
131
  sorted_counter.sort((x, y) => {
132
  return y[1] - x[1];
133
  });
134
 
135
- // Step 1: Number of unique values, most common value and its count
136
  const numberOfUniqueValues = sorted_counter.length;
137
  const mostCommonValueCount = sorted_counter[0][1];
138
 
139
- // Step 2: Calculate average
140
  let sum: number = 0;
141
  for (const [value, count] of Object.entries(counter)) {
142
  sum +=
143
  (typeof value === 'string' ? castToNumber(value, metric.values) : value) *
144
  count;
145
  }
146
- const average =
147
- Math.round((sum / numberOfAnnotators + Number.EPSILON) * 100) / 100;
148
 
149
- // Step 3: Common patterns
150
- // Step 3.a: Absolute agreement
151
- if (mostCommonValueCount === numberOfAnnotators)
152
  return {
153
  level: AgreementLevels.ABSOLUTE_AGREEMENT,
154
- value: average,
155
  };
156
 
157
- // Step 3.b: Absolute disagreement/No agreement
158
- if (numberOfUniqueValues === numberOfAnnotators)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  return {
160
  level: AgreementLevels.NO_AGREEMENT,
161
- value: average,
162
  };
163
 
164
- // Step 4: Default return
165
  return {
166
  level: AgreementLevels.HIGH_AGREEMENT,
167
- value: average,
168
  };
169
  }
170
 
171
  /**
172
  * Compute majority value
173
- * @param metric
174
  * @param counter distribution of values
175
  * @param numberOfAnnotators number of annotators
176
  * @returns
@@ -257,8 +344,10 @@ export function calculateAggregateValue(
257
  let scores: string[] | number[] = Object.values(entries).map(
258
  (entry) => entry.value,
259
  );
260
- if (metric.aggregator === 'average') {
261
- return computeAverage(metric, countBy(scores), scores.length);
 
 
262
  } else {
263
  return computeMajority(metric, countBy(scores), scores.length);
264
  }
@@ -273,8 +362,10 @@ export function calculateAggregateValue(
273
  let scores: string[] | number[] = Object.values(entries).map(
274
  (entry) => entry.value,
275
  );
276
- if (metric.aggregator === 'average') {
277
- return computeAverage(metric, countBy(scores), scores.length);
 
 
278
  } else {
279
  return computeMajority(metric, countBy(scores), scores.length);
280
  }
 
77
  : metric.name.charAt(0).toUpperCase() + metric.name.slice(1).toLowerCase();
78
  }
79
 
80
+ /**
81
+ * Converts numeric value to metric value using references in case of 'categorical' metrics
82
+ * @param value numeric value to convert
83
+ * @param references reference metric values
84
+ * @returns metric value
85
+ */
86
+ export function castToValue(
87
+ value: number,
88
+ references?: MetricValue[],
89
+ ): string | number {
90
+ // Step 1: Check if references are provided to convert "numeric" value to "string" value
91
+ if (references) {
92
+ // Step 1.a: Find appropriate reference by comparing "string" values
93
+ const reference = references.find((entry) => entry.numericValue === value);
94
+
95
+ // Step 1.b: If value exists in reference, then return it
96
+ if (reference && reference.value) {
97
+ return reference.value;
98
+ } else {
99
+ return value;
100
+ }
101
+ }
102
+
103
+ // Default return
104
+ return value;
105
+ }
106
+
107
  export function castToNumber(
108
  value: string | number,
109
  references?: MetricValue[],
 
143
  }
144
 
145
  /**
146
+ * Compute mean value
147
+ * @param metric metric under consideration
148
+ * @param scores distribution of values
149
  * @returns
150
  */
151
+ function computeMean(
152
  metric: Metric,
153
+ scores: string[] | number[],
 
154
  ): { level: number; value: number | string } {
155
+ // Step 1: Create counter
156
+ const counter: { [key: string]: number } = countBy(scores);
157
+
158
+ // Step 2: Sort counter values
159
  const sorted_counter = Object.entries(counter);
160
  sorted_counter.sort((x, y) => {
161
  return y[1] - x[1];
162
  });
163
 
164
+ // Step 3: Number of unique values, most common value and its count
165
  const numberOfUniqueValues = sorted_counter.length;
166
  const mostCommonValueCount = sorted_counter[0][1];
167
 
168
+ // Step 4: Calculate mean
169
  let sum: number = 0;
170
  for (const [value, count] of Object.entries(counter)) {
171
  sum +=
172
  (typeof value === 'string' ? castToNumber(value, metric.values) : value) *
173
  count;
174
  }
175
+ const mean = Math.round((sum / scores.length + Number.EPSILON) * 100) / 100;
 
176
 
177
+ // Step 5: Common patterns
178
+ // Step 5.a: Absolute agreement
179
+ if (mostCommonValueCount === scores.length)
180
  return {
181
  level: AgreementLevels.ABSOLUTE_AGREEMENT,
182
+ value: mean,
183
  };
184
 
185
+ // Step 5.b: Absolute disagreement/No agreement
186
+ if (numberOfUniqueValues === scores.length)
187
+ return {
188
+ level: AgreementLevels.NO_AGREEMENT,
189
+ value: mean,
190
+ };
191
+
192
+ // Step 6: Default return
193
+ return {
194
+ level: AgreementLevels.HIGH_AGREEMENT,
195
+ value: mean,
196
+ };
197
+ }
198
+
199
+ /**
200
+ * Compute median value
201
+ * @param metric metric under consideration
202
+ * @param counter distribution of values
203
+ * @returns
204
+ */
205
+ function computeMedian(
206
+ metric: Metric,
207
+ scores: string[] | number[],
208
+ ): { level: number; value: number | string } {
209
+ // Step 1: Create counter
210
+ const counter: { [key: string]: number } = countBy(scores);
211
+
212
+ // Step 2: Sort counter values
213
+ const sorted_counter = Object.entries(counter);
214
+ sorted_counter.sort((x, y) => {
215
+ return y[1] - x[1];
216
+ });
217
+
218
+ // Step 3: Number of unique values, most common value and its count
219
+ const numberOfUniqueValues = sorted_counter.length;
220
+ const mostCommonValueCount = sorted_counter[0][1];
221
+
222
+ // Step 4: Cast score to numbers
223
+ const numericScores = scores.map((score) =>
224
+ typeof score === 'string' ? castToNumber(score, metric.values) : score,
225
+ );
226
+
227
+ // Step 5: Sort the numeric scores
228
+ const sortedNumericScores = numericScores.toSorted();
229
+
230
+ // Step 6: Calculate median
231
+ const median =
232
+ sortedNumericScores.length % 2 == 0
233
+ ? sortedNumericScores[sortedNumericScores.length / 2]
234
+ : sortedNumericScores[(sortedNumericScores.length + 1) / 2 - 1];
235
+
236
+ // Step 7: Common patterns
237
+ // Step 7.a: Absolute agreement
238
+ if (mostCommonValueCount === scores.length)
239
+ return {
240
+ level: AgreementLevels.ABSOLUTE_AGREEMENT,
241
+ value: castToValue(median, metric.values),
242
+ };
243
+
244
+ // Step 7.b: Absolute disagreement/No agreement
245
+ if (numberOfUniqueValues === scores.length)
246
  return {
247
  level: AgreementLevels.NO_AGREEMENT,
248
+ value: castToValue(median, metric.values),
249
  };
250
 
251
+ // Step 8: Default return
252
  return {
253
  level: AgreementLevels.HIGH_AGREEMENT,
254
+ value: castToValue(median, metric.values),
255
  };
256
  }
257
 
258
  /**
259
  * Compute majority value
260
+ * @param metric metric under consideration
261
  * @param counter distribution of values
262
  * @param numberOfAnnotators number of annotators
263
  * @returns
 
344
  let scores: string[] | number[] = Object.values(entries).map(
345
  (entry) => entry.value,
346
  );
347
+ if (metric.aggregator === 'average' || metric.aggregator === 'mean') {
348
+ return computeMean(metric, scores);
349
+ } else if (metric.aggregator === 'median') {
350
+ return computeMedian(metric, scores);
351
  } else {
352
  return computeMajority(metric, countBy(scores), scores.length);
353
  }
 
362
  let scores: string[] | number[] = Object.values(entries).map(
363
  (entry) => entry.value,
364
  );
365
+ if (metric.aggregator === 'average' || metric.aggregator === 'mean') {
366
+ return computeMean(metric, scores);
367
+ } else if (metric.aggregator === 'median') {
368
+ return computeMedian(metric, scores);
369
  } else {
370
  return computeMajority(metric, countBy(scores), scores.length);
371
  }
src/views/model-behavior/ModelBehavior.tsx CHANGED
@@ -140,7 +140,8 @@ function prepareGroupBarChartData(
140
  return {
141
  ...entry,
142
  key:
143
- metric.aggregator && metric.aggregator === 'majority'
 
144
  ? extractMetricDisplayValue(entry.key, metric.values)
145
  : entry.key,
146
  };
 
140
  return {
141
  ...entry,
142
  key:
143
+ metric.aggregator &&
144
+ (metric.aggregator === 'majority' || metric.aggregator === 'median')
145
  ? extractMetricDisplayValue(entry.key, metric.values)
146
  : entry.key,
147
  };
src/views/performance-overview/PerformanceOverview.tsx CHANGED
@@ -53,7 +53,8 @@ import {
53
  castToNumber,
54
  } from '@/src/utilities/metrics';
55
  import {
56
- averageAggregator,
 
57
  majorityAggregator,
58
  } from '@/src/utilities/aggregators';
59
  import { areObjectsIntersecting } from '@/src/utilities/objects';
@@ -372,7 +373,11 @@ export default function PerformanceOverview({
372
  const [WindowHeight, setWindowHeight] = useState<number>(
373
  global?.window && window.innerHeight,
374
  );
375
- const aggregators: Aggregator[] = [averageAggregator, majorityAggregator];
 
 
 
 
376
  const [selectedAggregators, setSelectedAggregators] = useState<{
377
  [key: string]: Aggregator;
378
  }>(
@@ -383,7 +388,9 @@ export default function PerformanceOverview({
383
  metric.name,
384
  metric.aggregator === 'majority'
385
  ? majorityAggregator
386
- : averageAggregator,
 
 
387
  ]),
388
  ),
389
  );
@@ -460,7 +467,7 @@ export default function PerformanceOverview({
460
  for (const [metric, evaluations] of Object.entries(
461
  evaluationsPerMetric,
462
  )) {
463
- const aggregator = selectedAggregators[metric] || averageAggregator;
464
 
465
  // Select evaluations based on selected filters
466
  const selectedEvaluations = !isEmpty(selectedFilters)
 
53
  castToNumber,
54
  } from '@/src/utilities/metrics';
55
  import {
56
+ meanAggregator,
57
+ medianAggregator,
58
  majorityAggregator,
59
  } from '@/src/utilities/aggregators';
60
  import { areObjectsIntersecting } from '@/src/utilities/objects';
 
373
  const [WindowHeight, setWindowHeight] = useState<number>(
374
  global?.window && window.innerHeight,
375
  );
376
+ const aggregators: Aggregator[] = [
377
+ meanAggregator,
378
+ medianAggregator,
379
+ majorityAggregator,
380
+ ];
381
  const [selectedAggregators, setSelectedAggregators] = useState<{
382
  [key: string]: Aggregator;
383
  }>(
 
388
  metric.name,
389
  metric.aggregator === 'majority'
390
  ? majorityAggregator
391
+ : metric.aggregator === 'median'
392
+ ? medianAggregator
393
+ : meanAggregator,
394
  ]),
395
  ),
396
  );
 
467
  for (const [metric, evaluations] of Object.entries(
468
  evaluationsPerMetric,
469
  )) {
470
+ const aggregator = selectedAggregators[metric] || meanAggregator;
471
 
472
  // Select evaluations based on selected filters
473
  const selectedEvaluations = !isEmpty(selectedFilters)