Spaces:
Running
Running
/** | |
* | |
* Copyright 2023-2024 InspectorRAGet Team | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
* | |
**/ | |
import { countBy, isNumber } from 'lodash'; | |
import { Metric, MetricValue } from '@/src/types'; | |
/**
 * Descriptive text for each known metric, keyed by lower-case metric name.
 *
 * Note that `coherence`/`naturalness` and `specificity`/`appropriateness`
 * intentionally carry the same wording here.
 */
export const MetricDefinitions = {
  // Fluency-oriented metrics
  coherence: 'The response is coherent, natural, and not dismissive.',
  naturalness: 'The response is coherent, natural, and not dismissive.',

  // Informativeness-oriented metrics
  specificity: 'The response provides appropriate amount of useful information.',
  appropriateness: 'The response provides appropriate amount of useful information.',

  // Grounding
  faithfulness: 'The response is faithful and grounded on the context.',

  // Free-form annotator notes
  feedback: "Annotator's comments about quality of response, potential issues etc.",
};
/**
 * Numeric codes describing how strongly annotators agree on a value,
 * ordered from strongest (3) down to none (0).
 */
export const AgreementLevels = {
  ABSOLUTE_AGREEMENT: 3, // strongest
  HIGH_AGREEMENT: 2,
  LOW_AGREEMENT: 1,
  NO_AGREEMENT: 0, // weakest
};
/**
 * Human-readable explanation of each agreement level, keyed by the level's
 * short name (`Absolute`, `High`, `Low`, `No`).
 */
export const AgreementLevelDefinitions = {
  Absolute: 'All annotators selected a same value for a given metric.',
  // Typo fix: the text previously read "less that 2 units apart".
  High: 'Majority of annotators selected a same value for a given metric and the most common value and the 2nd most common value were less than 2 units apart.',
  Low: 'Majority of annotators selected a same value for a given metric.',
  No: 'Majority of annotators selected different values for a given metric.',
};
/**
 * Build the text to show for a metric value.
 *
 * - Numbers are rounded to at most two decimals (trailing zeros dropped).
 * - Strings are replaced by the `displayValue` of the reference whose `value`
 *   strictly equals the string, when such a reference exists and its
 *   `displayValue` is non-empty; otherwise the string is returned unchanged.
 *
 * @param value raw metric value
 * @param references optional list of known values used to look up a display label
 * @returns text representation of `value`
 */
export function extractMetricDisplayValue(
  value: string | number,
  references?: MetricValue[],
): string {
  // Numeric input: format and exit early
  if (typeof value !== 'string') {
    return parseFloat(value.toFixed(2)).toString();
  }

  // String input: look for a reference carrying a display label
  const match = references?.find((entry) => entry.value === value);
  if (match?.displayValue) {
    return match.displayValue;
  }

  // No references, no match, or no label: fall back to the raw string
  return value;
}
/**
 * Resolve the label to show for a metric.
 *
 * @param metric metric to label
 * @returns `metric.displayName` when it is set (truthy); otherwise
 *   `metric.name` with its first character upper-cased and the rest lower-cased
 */
export function extractMetricDisplayName(metric: Metric): string {
  if (metric.displayName) {
    return metric.displayName;
  }

  // Fall back to a capitalised version of the technical name
  const { name } = metric;
  const head = name.charAt(0).toUpperCase();
  const tail = name.slice(1).toLowerCase();
  return head + tail;
}
/**
 * Convert a metric value to a number.
 *
 * Resolution order for string input:
 * 1. When `references` is supplied and the entry whose `value` strictly equals
 *    the string has a numeric `numericValue`, that number is returned.
 * 2. When `references` is supplied but yields no numeric mapping, the string
 *    is parsed with `parseFloat` (may be `NaN`).
 * 3. When `references` is absent, `'N/A'` and `''` map to `0`; anything else
 *    is parsed with `parseFloat`.
 *
 * Non-string input is returned unchanged.
 *
 * @param value raw metric value
 * @param references optional list of known values used to map labels to numbers
 * @returns numeric form of `value`; `NaN` when the string is not parseable
 */
export function castToNumber(
  value: string | number,
  references?: MetricValue[],
): number {
  // Already numeric: nothing to do
  if (typeof value !== 'string') {
    return value;
  }

  if (references) {
    // Look up the numeric equivalent declared for this label
    const reference = references.find((entry) => entry.value === value);
    if (reference && typeof reference.numericValue === 'number') {
      return reference.numericValue;
    }

    // No numeric mapping available. Parse as a float so decimal strings such
    // as "0.75" are not truncated (this fallback previously used `parseInt`),
    // matching the parsing used when references are absent.
    return parseFloat(value);
  }

  // References absent: treat placeholders as zero, parse everything else
  if (value === 'N/A' || value === '') {
    return 0;
  }
  return parseFloat(value);
}
/**
 * Aggregate a distribution of values by averaging and grade the agreement.
 *
 * The average is the count-weighted sum of every distinct value's numeric
 * form (via `castToNumber` with `metric.values`), divided by
 * `numberOfAnnotators` and rounded to two decimals.
 *
 * Agreement level:
 * - ABSOLUTE_AGREEMENT when the most frequent value was chosen by everyone,
 * - NO_AGREEMENT when there are as many distinct values as annotators,
 * - HIGH_AGREEMENT otherwise.
 *
 * @param metric metric whose `values` are used to map labels to numbers
 * @param counter distribution of values (value -> occurrence count)
 * @param numberOfAnnotators number of annotators
 * @returns agreement level together with the rounded average
 */
function computeAverage(
  metric: Metric,
  counter: { [key: string]: number },
  numberOfAnnotators: number,
): { level: number; value: number | string } {
  // Rank distinct values by how often they were chosen (most frequent first)
  const ranked = Object.entries(counter).sort(
    ([, leftCount], [, rightCount]) => rightCount - leftCount,
  );
  const distinctValues = ranked.length;
  const [, topCount] = ranked[0];

  // Weighted sum, accumulated in the counter's own key order.
  // Keys coming out of Object.entries are always strings.
  const total = Object.entries(counter).reduce(
    (running, [label, occurrences]) =>
      running + castToNumber(label, metric.values) * occurrences,
    0,
  );

  // Round to two decimals; EPSILON nudges values sitting right on a .xx5 edge
  const average =
    Math.round((total / numberOfAnnotators + Number.EPSILON) * 100) / 100;

  // Grade agreement: unanimous, fully split, or anything in between
  let level = AgreementLevels.HIGH_AGREEMENT;
  if (topCount === numberOfAnnotators) {
    level = AgreementLevels.ABSOLUTE_AGREEMENT;
  } else if (distinctValues === numberOfAnnotators) {
    level = AgreementLevels.NO_AGREEMENT;
  }

  return { level, value: average };
}
/**
 * Aggregate a distribution of values by majority vote and grade the agreement.
 *
 * Decision order:
 * 1. Everyone chose the most frequent value -> ABSOLUTE_AGREEMENT, that value.
 * 2. As many distinct values as annotators -> NO_AGREEMENT, 'Indeterminate'.
 * 3. More distinct values than half the annotators (rounded up), or the top
 *    value has fewer votes than that half while the number of distinct values
 *    equals it and the two most frequent values are more than 1 unit apart
 *    -> NO_AGREEMENT, 'Indeterminate'.
 * 4. Exactly two distinct values that are less than 2 units apart
 *    -> HIGH_AGREEMENT, most frequent value.
 * 5. Otherwise -> LOW_AGREEMENT, most frequent value.
 *
 * Distances are measured on the numeric form produced by `castToNumber`
 * with `metric.values`.
 *
 * @param metric metric whose `values` are used to map labels to numbers
 * @param counter distribution of values (value -> occurrence count)
 * @param numberOfAnnotators number of annotators
 * @returns agreement level together with the winning value or 'Indeterminate'
 */
function computeMajority(
  metric: Metric,
  counter: { [key: string]: number },
  numberOfAnnotators: number,
): { level: number; value: number | string } {
  // Rank distinct values by vote count, most voted first
  const ranked = Object.entries(counter).sort(
    (left, right) => right[1] - left[1],
  );
  const distinctValues = ranked.length;
  const [winner, winnerVotes] = ranked[0];
  const half = Math.ceil(numberOfAnnotators / 2);

  // Distance between winner and runner-up. Kept lazy on purpose: it is only
  // evaluated behind the same short-circuit conditions that guard the
  // runner-up lookup.
  const topTwoGap = (): number =>
    Math.abs(
      castToNumber(winner, metric.values) -
        castToNumber(ranked[1][0], metric.values),
    );

  const indeterminate = {
    level: AgreementLevels.NO_AGREEMENT,
    value: 'Indeterminate',
  };

  // Unanimous vote
  if (winnerVotes === numberOfAnnotators) {
    return { level: AgreementLevels.ABSOLUTE_AGREEMENT, value: winner };
  }

  // Every annotator picked something different
  if (distinctValues === numberOfAnnotators) {
    return indeterminate;
  }

  // Votes too scattered, or a weak winner that is far from the runner-up
  if (
    distinctValues > half ||
    (winnerVotes < half && distinctValues === half && topTwoGap() > 1)
  ) {
    return indeterminate;
  }

  // Two camps only, and they sit close to each other
  if (distinctValues === 2 && topTwoGap() < 2) {
    return { level: AgreementLevels.HIGH_AGREEMENT, value: winner };
  }

  // A winner exists but agreement is weak
  return { level: AgreementLevels.LOW_AGREEMENT, value: winner };
}
/**
 * Aggregate the per-annotator (or per-run) values recorded for a metric.
 *
 * The `value` field of every entry is collected, counted with lodash
 * `countBy`, and handed to the aggregator selected by `metric.aggregator`:
 * `'average'` uses `computeAverage`, any other non-empty aggregator uses
 * `computeMajority`. Without an aggregator the result is NO_AGREEMENT with an
 * undefined value.
 *
 * The previous implementation branched on `metric.author === 'algorithm'`,
 * but both branches were identical, so the check has been removed; behaviour
 * is unchanged for every author.
 *
 * @param metric metric definition; `aggregator` and `values` are read
 * @param entries map whose values each expose a `value` field
 * @returns agreement level and aggregated value
 */
export function calculateAggregateValue(
  metric: Metric,
  entries: { [key: string]: any },
) {
  // Nothing to aggregate with
  if (!metric.aggregator) {
    return {
      level: AgreementLevels.NO_AGREEMENT,
      value: undefined,
    };
  }

  const scores: string[] | number[] = Object.values(entries).map(
    (entry) => entry.value,
  );
  const distribution = countBy(scores);

  return metric.aggregator === 'average'
    ? computeAverage(metric, distribution, scores.length)
    : computeMajority(metric, distribution, scores.length);
}
/**
 * Merge a two-level tally (`group -> key -> value`) from `source` into
 * `target`, mutating `target` in place.
 *
 * For every `group`/`key` pair in `source`:
 * - if `target[group][key]` already exists, the source value is added with `+=`;
 * - otherwise the key (and, when missing, the group) is created with the
 *   source value.
 *
 * A falsy `source` is a no-op. Groups in `source` with no keys do not create
 * an empty group in `target`.
 *
 * Ownership checks go through `Object.prototype.hasOwnProperty.call`, so the
 * merge also works for targets created with `Object.create(null)` and for
 * groups/keys that happen to be called `hasOwnProperty`.
 *
 * @param source tally to read from
 * @param target tally to update in place
 */
export function mergeAgreementObjects({
  source,
  target,
}: {
  source: object;
  target: object;
}): void {
  if (!source) {
    return;
  }

  // Typed views over the loosely typed parameters.
  // NOTE(review): values are assumed to be numeric counts — confirm with callers.
  const incoming = source as Record<string, Record<string, number>>;
  const merged = target as Record<string, Record<string, number>>;

  const owns = (holder: object, property: string): boolean =>
    Object.prototype.hasOwnProperty.call(holder, property);

  for (const [group, entry] of Object.entries(incoming)) {
    for (const [key, value] of Object.entries(entry)) {
      if (!owns(merged, group)) {
        // First time this group is seen in the target
        merged[group] = { [key]: value };
      } else if (owns(merged[group], key)) {
        merged[group][key] += value;
      } else {
        merged[group][key] = value;
      }
    }
  }
}
/**
 * Map a numeric value onto the label of the bin that contains it.
 *
 * Binning only applies when `value` is a number, `metric.type` is
 * `'numerical'` and `metric.range` has exactly three elements, read as
 * `[min, max, step]`. Bins are `[min + i*step, min + (i+1)*step]` for as long
 * as the upper edge does not exceed `max`; both edges are inclusive and bins
 * are tested in ascending order, so a value sitting exactly on an edge lands
 * in the lower bin. Edges are rounded to two decimals for the label.
 *
 * In every other case, or when no bin contains the value, `value` is
 * returned unchanged.
 *
 * Fixes over the previous version:
 * - the upper-edge check tolerates floating-point accumulation error
 *   (e.g. `0.2 + 0.1 > 0.3`), which used to drop the last bin silently;
 * - a non-positive or non-finite range no longer loops forever.
 *
 * @param value value to bin
 * @param metric metric definition; `type` and `range` are read
 * @param n currently unused; kept so existing call sites keep compiling
 * @returns bin label of the form `"start-end"`, or `value` itself
 */
export function bin(
  value: number | string,
  metric: Metric,
  n?: number,
): number | string {
  if (typeof value !== 'number' || metric.type !== 'numerical') {
    return value;
  }
  if (!metric.range || metric.range.length !== 3) {
    return value;
  }

  const [min, max, step] = metric.range;

  // A usable range needs finite bounds and a strictly positive step,
  // otherwise the scan below could never terminate.
  if (![min, max, step].every((bound) => Number.isFinite(bound)) || step <= 0) {
    return value;
  }

  // Slack for rounding noise in `min + idx * step + step`
  const tolerance = step * 1e-9;

  for (let idx = 0; min + idx * step + step <= max + tolerance; idx++) {
    const start = parseFloat((min + idx * step).toFixed(2));
    const end = parseFloat((min + idx * step + step).toFixed(2));
    if (start <= value && value <= end) {
      return `${start}-${end}`;
    }
  }

  return value;
}
/**
 * Comparator for ordering aggregated metric values (usable with `Array.sort`).
 *
 * - `'average'` aggregator: numeric keys are compared by difference, string
 *   keys by the difference of their `parseFloat` values; mixed key types
 *   compare as equal.
 * - `'majority'` aggregator (string keys only, anything else compares as
 *   equal): `'Indeterminate'` sorts before every other key. Otherwise both
 *   keys are looked up in `metric.values`; when both are found they are
 *   compared by `numericValue` if both have one, else by the entries' own
 *   numeric `value`. Every remaining case falls back to `localeCompare`.
 * - No (or any other) aggregator: keys are compared with `>`; equal keys now
 *   yield `0` (previously `-1`, which broke the comparator contract that
 *   `compare(x, x)` must be `0`).
 *
 * @param a first aggregated entry
 * @param b second aggregated entry
 * @param metric metric definition; `aggregator` and `values` are read
 * @returns negative when `a` sorts first, positive when `b` does, `0` on a tie
 */
export function compareMetricAggregatedValues(
  a: { key: string | number; value: number },
  b: { key: string | number; value: number },
  metric: Metric,
): number {
  if (metric.aggregator === 'average') {
    if (typeof a.key === 'number' && typeof b.key === 'number') {
      return a.key - b.key;
    }
    if (typeof a.key === 'string' && typeof b.key === 'string') {
      return parseFloat(a.key) - parseFloat(b.key);
    }
    // Mixed key types: preserve the existing order
    return 0;
  }

  if (metric.aggregator === 'majority') {
    const aKey = a.key;
    const bKey = b.key;

    // Only string keys are ordered; anything else preserves the existing order
    if (typeof aKey !== 'string' || typeof bKey !== 'string') {
      return 0;
    }

    // 'Indeterminate' always sorts first
    if (aKey === 'Indeterminate' || bKey === 'Indeterminate') {
      if (aKey === bKey) {
        return 0;
      }
      return aKey === 'Indeterminate' ? -1 : 1;
    }

    // NOTE(review): loose equality is kept from the original; presumably it
    // lets numeric reference values match string keys — confirm before
    // tightening to ===.
    const aValue = metric.values?.find((entry) => entry.value == aKey);
    const bValue = metric.values?.find((entry) => entry.value == bKey);

    if (aValue && bValue) {
      // Prefer the declared numeric equivalents when both exist
      if (isNumber(aValue.numericValue) && isNumber(bValue.numericValue)) {
        return aValue.numericValue - bValue.numericValue;
      }
      // Otherwise compare the entries' own numeric values
      if (typeof a.value === 'number' && typeof b.value === 'number') {
        return a.value - b.value;
      }
    }

    // String comparison with non-ASCII support
    return aKey.localeCompare(bKey);
  }

  // No recognised aggregator: plain key ordering, with a proper tie result
  if (a.key === b.key) {
    return 0;
  }
  return a.key > b.key ? 1 : -1;
}