Spaces:
Running
Running
feat (aggregator): Added support for median aggregator.
Browse files
Signed-off-by: Kshitij Fadnis <[email protected]>
src/utilities/aggregators.ts
CHANGED
@@ -24,11 +24,11 @@ import {
|
|
24 |
AggregationStatistics,
|
25 |
MetricValue,
|
26 |
} from '@/src/types';
|
27 |
-
import { castToNumber } from '@/src/utilities/metrics';
|
28 |
|
29 |
-
export const
|
30 |
-
name: '
|
31 |
-
displayName: '
|
32 |
apply: (
|
33 |
scores: number[] | string[],
|
34 |
references: MetricValue[],
|
@@ -64,6 +64,50 @@ export const averageAggregator: Aggregator = {
|
|
64 |
},
|
65 |
};
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
export const majorityAggregator: Aggregator = {
|
68 |
name: 'majority',
|
69 |
displayName: 'Majority',
|
|
|
24 |
AggregationStatistics,
|
25 |
MetricValue,
|
26 |
} from '@/src/types';
|
27 |
+
import { castToNumber, castToValue } from '@/src/utilities/metrics';
|
28 |
|
29 |
+
export const meanAggregator: Aggregator = {
|
30 |
+
name: 'mean',
|
31 |
+
displayName: 'Mean',
|
32 |
apply: (
|
33 |
scores: number[] | string[],
|
34 |
references: MetricValue[],
|
|
|
64 |
},
|
65 |
};
|
66 |
|
67 |
+
/**
 * Aggregator that summarises a list of scores by their median.
 *
 * String scores are converted with `castToNumber(score, references)` before
 * sorting; the selected median is converted back with
 * `castToValue(median, references)`.
 *
 * For an even number of scores the upper of the two middle elements is used
 * (no averaging), so the result is always one of the observed scores and can
 * be mapped back through `references`.
 *
 * Returned statistics:
 * - `value`: the median, passed through `castToValue`.
 * - `std`: root-mean-square deviation of the scores from the median, rounded
 *   to two decimals.
 * - `confidence`: HIGH when every score is identical, LOW when every score is
 *   distinct, MEDIUM otherwise.
 *
 * @throws Error when `scores` is empty, because a median is undefined.
 */
export const medianAggregator: Aggregator = {
  name: 'median',
  displayName: 'Median',
  apply: (
    scores: number[] | string[],
    references: MetricValue[],
  ): AggregationStatistics => {
    // Step 0: A median of nothing is undefined; fail with a clear message
    // instead of an opaque TypeError further down.
    if (scores.length === 0) {
      throw new Error(
        'medianAggregator: cannot aggregate an empty list of scores',
      );
    }

    // Step 1: Cast scores to numbers
    const rawScores: ReadonlyArray<string | number> = scores;
    const numericScores = rawScores.map((score) =>
      typeof score === 'string' ? castToNumber(score, references) : score,
    );

    // Step 2: Sort a copy numerically. The default comparator orders elements
    // as strings (e.g. 100 before 15), which is wrong for numbers.
    const sortedNumericScores = [...numericScores].sort((a, b) => a - b);

    // Step 3: Calculate aggregate value & deviation around it.
    // floor(n / 2) is the exact middle for odd n and the upper-middle for even n.
    const median = sortedNumericScores[Math.floor(sortedNumericScores.length / 2)];
    const squaredDeviationSum = sortedNumericScores.reduce(
      (total, score) => total + Math.pow(score - median, 2),
      0,
    );
    const std = Math.sqrt(squaredDeviationSum / sortedNumericScores.length);

    // Step 4: Calculate confidence level from the number of distinct scores
    const numberOfUniqueValues = Object.keys(countBy(scores)).length;

    return {
      value: castToValue(median, references),
      std: Math.round((std + Number.EPSILON) * 100) / 100,
      confidence:
        numberOfUniqueValues === 1
          ? AggregationConfidenceLevels.HIGH
          : numberOfUniqueValues === scores.length
            ? AggregationConfidenceLevels.LOW
            : AggregationConfidenceLevels.MEDIUM,
    };
  },
};
|
110 |
+
|
111 |
export const majorityAggregator: Aggregator = {
|
112 |
name: 'majority',
|
113 |
displayName: 'Majority',
|
src/utilities/metrics.ts
CHANGED
@@ -77,6 +77,33 @@ export function extractMetricDisplayName(metric: Metric): string {
|
|
77 |
: metric.name.charAt(0).toUpperCase() + metric.name.slice(1).toLowerCase();
|
78 |
}
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
export function castToNumber(
|
81 |
value: string | number,
|
82 |
references?: MetricValue[],
|
@@ -116,61 +143,121 @@ export function castToNumber(
|
|
116 |
}
|
117 |
|
118 |
/**
|
119 |
-
* Compute
|
120 |
-
* @param
|
121 |
-
* @param
|
122 |
* @returns
|
123 |
*/
|
124 |
-
function
|
125 |
metric: Metric,
|
126 |
-
|
127 |
-
numberOfAnnotators: number,
|
128 |
): { level: number; value: number | string } {
|
129 |
-
// Step
|
|
|
|
|
|
|
130 |
const sorted_counter = Object.entries(counter);
|
131 |
sorted_counter.sort((x, y) => {
|
132 |
return y[1] - x[1];
|
133 |
});
|
134 |
|
135 |
-
// Step
|
136 |
const numberOfUniqueValues = sorted_counter.length;
|
137 |
const mostCommonValueCount = sorted_counter[0][1];
|
138 |
|
139 |
-
// Step
|
140 |
let sum: number = 0;
|
141 |
for (const [value, count] of Object.entries(counter)) {
|
142 |
sum +=
|
143 |
(typeof value === 'string' ? castToNumber(value, metric.values) : value) *
|
144 |
count;
|
145 |
}
|
146 |
-
const
|
147 |
-
Math.round((sum / numberOfAnnotators + Number.EPSILON) * 100) / 100;
|
148 |
|
149 |
-
// Step
|
150 |
-
// Step
|
151 |
-
if (mostCommonValueCount ===
|
152 |
return {
|
153 |
level: AgreementLevels.ABSOLUTE_AGREEMENT,
|
154 |
-
value:
|
155 |
};
|
156 |
|
157 |
-
// Step
|
158 |
-
if (numberOfUniqueValues ===
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
return {
|
160 |
level: AgreementLevels.NO_AGREEMENT,
|
161 |
-
value:
|
162 |
};
|
163 |
|
164 |
-
// Step
|
165 |
return {
|
166 |
level: AgreementLevels.HIGH_AGREEMENT,
|
167 |
-
value:
|
168 |
};
|
169 |
}
|
170 |
|
171 |
/**
|
172 |
* Compute majority value
|
173 |
-
* @param metric
|
174 |
* @param counter distribution of values
|
175 |
* @param numberOfAnnotators number of annotators
|
176 |
* @returns
|
@@ -257,8 +344,10 @@ export function calculateAggregateValue(
|
|
257 |
let scores: string[] | number[] = Object.values(entries).map(
|
258 |
(entry) => entry.value,
|
259 |
);
|
260 |
-
if (metric.aggregator === 'average') {
|
261 |
-
return
|
|
|
|
|
262 |
} else {
|
263 |
return computeMajority(metric, countBy(scores), scores.length);
|
264 |
}
|
@@ -273,8 +362,10 @@ export function calculateAggregateValue(
|
|
273 |
let scores: string[] | number[] = Object.values(entries).map(
|
274 |
(entry) => entry.value,
|
275 |
);
|
276 |
-
if (metric.aggregator === 'average') {
|
277 |
-
return
|
|
|
|
|
278 |
} else {
|
279 |
return computeMajority(metric, countBy(scores), scores.length);
|
280 |
}
|
|
|
77 |
: metric.name.charAt(0).toUpperCase() + metric.name.slice(1).toLowerCase();
|
78 |
}
|
79 |
|
80 |
+
/**
 * Converts a numeric value back to a metric value using references
 * (e.g. for 'categorical' metrics).
 *
 * Looks for the reference whose `numericValue` strictly equals `value` and
 * returns that reference's `value`. Falsy-but-defined reference values such as
 * `0` or `''` are returned as-is; only a missing (null/undefined) reference
 * value falls back to the numeric input.
 *
 * @param value numeric value to convert
 * @param references reference metric values
 * @returns the matching reference's `value`, or `value` itself when no
 * references are given or none matches
 */
export function castToValue(
  value: number,
  references?: MetricValue[],
): string | number {
  // Step 1: Find the reference by comparing "numeric" values
  const reference = references?.find((entry) => entry.numericValue === value);

  // Step 2: Prefer the reference's value; otherwise return the input unchanged
  return reference?.value ?? value;
}
|
106 |
+
|
107 |
export function castToNumber(
|
108 |
value: string | number,
|
109 |
references?: MetricValue[],
|
|
|
143 |
}
|
144 |
|
145 |
/**
 * Compute mean value
 *
 * The scores are first tallied with `countBy`, so every distinct score becomes
 * a string key; each key is converted with `castToNumber(key, metric.values)`
 * and weighted by its count. The mean is rounded to two decimals.
 *
 * Agreement level:
 * - ABSOLUTE_AGREEMENT when every score is identical,
 * - NO_AGREEMENT when every score is distinct,
 * - HIGH_AGREEMENT otherwise.
 *
 * @param metric metric under consideration
 * @param scores values to aggregate
 * @returns agreement level and the rounded mean
 * @throws Error when `scores` is empty, because a mean is undefined
 */
function computeMean(
  metric: Metric,
  scores: string[] | number[],
): { level: number; value: number | string } {
  // Step 0: Fail with a clear message instead of an opaque TypeError
  if (scores.length === 0) {
    throw new Error('computeMean: cannot aggregate an empty list of scores');
  }

  // Step 1: Create counter
  const counter: { [key: string]: number } = countBy(scores);

  // Step 2: Number of unique values
  const numberOfUniqueValues = Object.keys(counter).length;

  // Step 3: Calculate mean. Counter keys are always strings, so they are
  // always converted through castToNumber.
  let sum = 0;
  for (const [value, count] of Object.entries(counter)) {
    sum += castToNumber(value, metric.values) * count;
  }
  const mean = Math.round((sum / scores.length + Number.EPSILON) * 100) / 100;

  // Step 4: Common patterns
  // A single unique value means absolute agreement; as many unique values as
  // scores means no agreement; anything in between is high agreement.
  const level =
    numberOfUniqueValues === 1
      ? AgreementLevels.ABSOLUTE_AGREEMENT
      : numberOfUniqueValues === scores.length
        ? AgreementLevels.NO_AGREEMENT
        : AgreementLevels.HIGH_AGREEMENT;

  return { level, value: mean };
}
|
198 |
+
|
199 |
+
/**
 * Compute median value
 *
 * String scores are converted with `castToNumber(score, metric.values)`, the
 * numeric scores are sorted in ascending numeric order, and the middle element
 * is selected. For an even number of scores the upper of the two middle
 * elements is used (no averaging), so the result is always an observed score
 * and can be mapped back with `castToValue(median, metric.values)`.
 *
 * Agreement level:
 * - ABSOLUTE_AGREEMENT when every score is identical,
 * - NO_AGREEMENT when every score is distinct,
 * - HIGH_AGREEMENT otherwise.
 *
 * @param metric metric under consideration
 * @param scores values to aggregate
 * @returns agreement level and the median converted back to a metric value
 * @throws Error when `scores` is empty, because a median is undefined
 */
function computeMedian(
  metric: Metric,
  scores: string[] | number[],
): { level: number; value: number | string } {
  // Step 0: Fail with a clear message instead of an opaque TypeError
  if (scores.length === 0) {
    throw new Error('computeMedian: cannot aggregate an empty list of scores');
  }

  // Step 1: Number of unique values
  const numberOfUniqueValues = Object.keys(countBy(scores)).length;

  // Step 2: Cast scores to numbers
  const rawScores: ReadonlyArray<string | number> = scores;
  const numericScores = rawScores.map((score) =>
    typeof score === 'string' ? castToNumber(score, metric.values) : score,
  );

  // Step 3: Sort a copy numerically. The default comparator orders elements
  // as strings (e.g. 100 before 15), which is wrong for numbers.
  const sortedNumericScores = [...numericScores].sort((a, b) => a - b);

  // Step 4: Calculate median
  // floor(n / 2) is the exact middle for odd n and the upper-middle for even n.
  const median = sortedNumericScores[Math.floor(sortedNumericScores.length / 2)];

  // Step 5: Common patterns
  // A single unique value means absolute agreement; as many unique values as
  // scores means no agreement; anything in between is high agreement.
  const level =
    numberOfUniqueValues === 1
      ? AgreementLevels.ABSOLUTE_AGREEMENT
      : numberOfUniqueValues === scores.length
        ? AgreementLevels.NO_AGREEMENT
        : AgreementLevels.HIGH_AGREEMENT;

  return { level, value: castToValue(median, metric.values) };
}
|
257 |
|
258 |
/**
|
259 |
* Compute majority value
|
260 |
+
* @param metric metric under consideration
|
261 |
* @param counter distribution of values
|
262 |
* @param numberOfAnnotators number of annotators
|
263 |
* @returns
|
|
|
344 |
let scores: string[] | number[] = Object.values(entries).map(
|
345 |
(entry) => entry.value,
|
346 |
);
|
347 |
+
if (metric.aggregator === 'average' || metric.aggregator === 'mean') {
|
348 |
+
return computeMean(metric, scores);
|
349 |
+
} else if (metric.aggregator === 'median') {
|
350 |
+
return computeMedian(metric, scores);
|
351 |
} else {
|
352 |
return computeMajority(metric, countBy(scores), scores.length);
|
353 |
}
|
|
|
362 |
let scores: string[] | number[] = Object.values(entries).map(
|
363 |
(entry) => entry.value,
|
364 |
);
|
365 |
+
if (metric.aggregator === 'average' || metric.aggregator === 'mean') {
|
366 |
+
return computeMean(metric, scores);
|
367 |
+
} else if (metric.aggregator === 'median') {
|
368 |
+
return computeMedian(metric, scores);
|
369 |
} else {
|
370 |
return computeMajority(metric, countBy(scores), scores.length);
|
371 |
}
|
src/views/model-behavior/ModelBehavior.tsx
CHANGED
@@ -140,7 +140,8 @@ function prepareGroupBarChartData(
|
|
140 |
return {
|
141 |
...entry,
|
142 |
key:
|
143 |
-
metric.aggregator &&
|
|
|
144 |
? extractMetricDisplayValue(entry.key, metric.values)
|
145 |
: entry.key,
|
146 |
};
|
|
|
140 |
return {
|
141 |
...entry,
|
142 |
key:
|
143 |
+
metric.aggregator &&
|
144 |
+
(metric.aggregator === 'majority' || metric.aggregator === 'median')
|
145 |
? extractMetricDisplayValue(entry.key, metric.values)
|
146 |
: entry.key,
|
147 |
};
|
src/views/performance-overview/PerformanceOverview.tsx
CHANGED
@@ -53,7 +53,8 @@ import {
|
|
53 |
castToNumber,
|
54 |
} from '@/src/utilities/metrics';
|
55 |
import {
|
56 |
-
|
|
|
57 |
majorityAggregator,
|
58 |
} from '@/src/utilities/aggregators';
|
59 |
import { areObjectsIntersecting } from '@/src/utilities/objects';
|
@@ -372,7 +373,11 @@ export default function PerformanceOverview({
|
|
372 |
const [WindowHeight, setWindowHeight] = useState<number>(
|
373 |
global?.window && window.innerHeight,
|
374 |
);
|
375 |
-
const aggregators: Aggregator[] = [
|
|
|
|
|
|
|
|
|
376 |
const [selectedAggregators, setSelectedAggregators] = useState<{
|
377 |
[key: string]: Aggregator;
|
378 |
}>(
|
@@ -383,7 +388,9 @@ export default function PerformanceOverview({
|
|
383 |
metric.name,
|
384 |
metric.aggregator === 'majority'
|
385 |
? majorityAggregator
|
386 |
-
:
|
|
|
|
|
387 |
]),
|
388 |
),
|
389 |
);
|
@@ -460,7 +467,7 @@ export default function PerformanceOverview({
|
|
460 |
for (const [metric, evaluations] of Object.entries(
|
461 |
evaluationsPerMetric,
|
462 |
)) {
|
463 |
-
const aggregator = selectedAggregators[metric] ||
|
464 |
|
465 |
// Select evaluations based on selected filters
|
466 |
const selectedEvaluations = !isEmpty(selectedFilters)
|
|
|
53 |
castToNumber,
|
54 |
} from '@/src/utilities/metrics';
|
55 |
import {
|
56 |
+
meanAggregator,
|
57 |
+
medianAggregator,
|
58 |
majorityAggregator,
|
59 |
} from '@/src/utilities/aggregators';
|
60 |
import { areObjectsIntersecting } from '@/src/utilities/objects';
|
|
|
373 |
const [WindowHeight, setWindowHeight] = useState<number>(
|
374 |
global?.window && window.innerHeight,
|
375 |
);
|
376 |
+
const aggregators: Aggregator[] = [
|
377 |
+
meanAggregator,
|
378 |
+
medianAggregator,
|
379 |
+
majorityAggregator,
|
380 |
+
];
|
381 |
const [selectedAggregators, setSelectedAggregators] = useState<{
|
382 |
[key: string]: Aggregator;
|
383 |
}>(
|
|
|
388 |
metric.name,
|
389 |
metric.aggregator === 'majority'
|
390 |
? majorityAggregator
|
391 |
+
: metric.aggregator === 'median'
|
392 |
+
? medianAggregator
|
393 |
+
: meanAggregator,
|
394 |
]),
|
395 |
),
|
396 |
);
|
|
|
467 |
for (const [metric, evaluations] of Object.entries(
|
468 |
evaluationsPerMetric,
|
469 |
)) {
|
470 |
+
const aggregator = selectedAggregators[metric] || meanAggregator;
|
471 |
|
472 |
// Select evaluations based on selected filters
|
473 |
const selectedEvaluations = !isEmpty(selectedFilters)
|