/**
 *
 * Copyright 2023-2025 InspectorRAGet Team
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 **/

'use client';

import { sampleSize, isEmpty } from 'lodash';
import DOMPurify from 'dompurify';
import parse from 'html-react-parser';
import { useMemo, useState, useEffect } from 'react';
import {
  FilterableMultiSelect,
  Tag,
  Toggle,
  DataTable,
  TableContainer,
  Table,
  TableHead,
  TableRow,
  TableHeader,
  TableBody,
  TableCell,
  Pagination,
} from '@carbon/react';
import { WarningAlt, WarningAltFilled } from '@carbon/icons-react';

import { Task, Model, TaskEvaluation } from '@/src/types';
import { truncate } from '@/src/utilities/strings';
import { areObjectsIntersecting } from '@/src/utilities/objects';
import Filters from '@/src/components/filters/Filters';

import classes from './PredictionsTable.module.scss';

const MAX_NUM_ROWS = 150;

// Maximum number of characters of task text shown in the "Task" column.
const TASK_TEXT_PREVIEW_LENGTH = 80;

// Roles for which a chat-style turn's text is read from its `content` field.
const CHAT_ROLES = new Set(['system', 'developer', 'user', 'assistant']);

// ===================================================================================
//                               COMPUTE FUNCTIONS
// ===================================================================================
/**
 * Read `key` from `turn` when it is an own property holding a non-empty string.
 * @param turn any value (non-objects and null yield `undefined`)
 * @param key property name to read
 * @returns the string value, or `undefined` when absent, empty, or not a string
 */
function readOwnNonEmptyString(turn: unknown, key: string): string | undefined {
  if (
    typeof turn !== 'object' ||
    turn === null ||
    !Object.prototype.hasOwnProperty.call(turn, key)
  ) {
    return undefined;
  }
  const value = (turn as Record<string, unknown>)[key];
  return typeof value === 'string' && value !== '' ? value : undefined;
}

/**
 * Extract the text to preview for a task input.
 *
 * - A string input is used as-is.
 * - For an array input, only the last turn is considered. Its own non-empty `text`
 *   wins. Otherwise its own non-empty `content` is used when its `role` is one of
 *   CHAT_ROLES.
 *
 * @param input the task's `input` value
 * @returns the text, or `undefined` when none can be extracted (including an empty
 *          array, whose last element does not exist)
 */
function extractTaskText(input: unknown): string | undefined {
  if (typeof input === 'string') {
    return input;
  }
  if (!Array.isArray(input) || input.length === 0) {
    return undefined;
  }

  const lastTurn: unknown = input[input.length - 1];
  const text = readOwnNonEmptyString(lastTurn, 'text');
  if (text !== undefined) {
    return text;
  }

  const role = readOwnNonEmptyString(lastTurn, 'role');
  if (role !== undefined && CHAT_ROLES.has(role)) {
    return readOwnNonEmptyString(lastTurn, 'content');
  }
  return undefined;
}

/**
 * Helper function to compute the predictions table rows.
 *
 * Each row describes one eligible task and carries:
 *  - `id`: the task ID (used as the row key)
 *  - `task`: a truncated preview of the task text, falling back to the task ID
 *  - `targets`: the texts of the task's targets, when any are present
 *  - one column per model ID holding that model's response for the task
 *
 * @param tasks all tasks; only those whose ID is in `eligibleTaskIDs` produce a row
 * @param evaluations full set of evaluations
 * @param eligibleTaskIDs IDs of the tasks to include
 * @returns one row per eligible task, in the order of `tasks`
 */
function populateTableRows(
  tasks: Task[],
  evaluations: TaskEvaluation[],
  eligibleTaskIDs: Set<string>,
) {
  // Step 1: Group evaluations per task, skipping tasks that are not eligible
  const evaluationsPerTask = new Map<string, TaskEvaluation[]>();
  evaluations.forEach((evaluation) => {
    if (!eligibleTaskIDs.has(evaluation.taskId)) {
      return;
    }
    const evaluationsForTask = evaluationsPerTask.get(evaluation.taskId);
    if (evaluationsForTask) {
      // Append in place; re-spreading the array on every hit is quadratic per task
      evaluationsForTask.push(evaluation);
    } else {
      evaluationsPerTask.set(evaluation.taskId, [evaluation]);
    }
  });

  // Step 2: Formulate rows
  const rows: { [key: string]: string }[] = [];
  tasks.forEach((task) => {
    if (!eligibleTaskIDs.has(task.taskId)) {
      return;
    }

    // Step 2.a: Add task text, falling back to the task ID when none is found
    const row: { [key: string]: string | string[] } = {
      id: task.taskId,
      task: task.taskId,
    };
    const taskText = extractTaskText(task.input);
    if (taskText !== undefined) {
      row['task'] = truncate(taskText, TASK_TEXT_PREVIEW_LENGTH);
    }

    // Step 2.b: Add all target texts, if present
    if (task.targets && !isEmpty(task.targets)) {
      const targetTexts = task.targets
        .map((target) => target.text)
        .filter((text): text is string => typeof text === 'string');
      if (targetTexts.length > 0) {
        row['targets'] = targetTexts;
      }
    }

    // Step 2.c: Add model responses, keyed by model ID
    const taskEvaluations = evaluationsPerTask.get(task.taskId);
    if (taskEvaluations) {
      taskEvaluations.forEach((evaluation) => {
        row[evaluation.modelId] = evaluation.modelResponse;
      });
    }

    // Step 2.d: Add formulated row.
    // The declared row type stays string-valued for compatibility with existing
    // consumers. `targets` is the one array-valued column, and the table renderer
    // branches on it.
    rows.push(row as { [key: string]: string });
  });

  // Step 3: Return
  return rows;
}

// ===================================================================================
//                               MAIN FUNCTION
// ===================================================================================
/**
 * Renders model predictions for the eligible tasks in a paginated table.
 *
 * Users can choose which models' predictions are shown and toggle the targets
 * column. Tasks are narrowed by the selected filters. At most `MAX_NUM_ROWS`
 * tasks are shown. When more are eligible, a random sample of that size is
 * displayed together with a warning.
 */
export default function
PredictionsTable({
  tasks,
  models,
  evaluations,
  filters,
}: {
  tasks: Task[];
  models: Model[];
  evaluations: TaskEvaluation[];
  filters: { [key: string]: string[] };
}) {
  // Step 1: Initialize state and necessary variables
  // NOTE(review): `models` only seeds the initial selection; later changes to the
  // `models` prop do not update `selectedModels`.
  const [selectedModels, setSelectedModels] = useState(models);
  const [showTargets, setShowTargets] = useState(true);
  const [showWarning, setShowWarning] = useState(false);
  // `page` is 1-based (see the slice in Step 2.d)
  const [page, setPage] = useState(1);
  const [pageSize, setPageSize] = useState(10);
  const [selectedFilters, setSelectedFilters] = useState<{
    [key: string]: string[];
  }>({});
  const [visibleRows, setVisibleRows] = useState<{ [key: string]: string }[]>(
    [],
  );

  // Step 2: Run effects
  // Step 2.a: Identify eligible task IDs based on selected filters
  // (every task is eligible when no filter is selected)
  const eligibleTaskIDs = useMemo(() => {
    if (!isEmpty(selectedFilters)) {
      // NOTE(review): `Set` is generic; this annotation presumably should be
      // `Set<string>` (the type of `task.taskId`) — confirm.
      const taskIds: Set = new Set();
      tasks.forEach((task) => {
        if (areObjectsIntersecting(selectedFilters, task)) {
          taskIds.add(task.taskId);
        }
      });
      return taskIds;
    } else {
      return new Set(tasks.map((task) => task.taskId));
    }
  }, [tasks, selectedFilters]);

  // Step 2.b: Populate table rows, capped at MAX_NUM_ROWS by random sampling
  // NOTE(review): this memo calls setShowWarning during render and also lists
  // `showWarning` as a dependency. Flipping the warning therefore re-runs the memo,
  // and `sampleSize` draws a different random sample right after the first render.
  // Deriving the warning from the row count, instead of storing it in state,
  // would avoid both problems.
  const rows = useMemo(() => {
    const tableRows = populateTableRows(tasks, evaluations, eligibleTaskIDs);
    if (tableRows.length > MAX_NUM_ROWS) {
      // Add warning to indicate that only limited rows are shown, if not visible already
      if (!showWarning) {
        setShowWarning(true);
      }
      // Limit number of rows
      return sampleSize(tableRows, MAX_NUM_ROWS);
    } else {
      // Remove previously set warning, if necessary
      if (showWarning) {
        setShowWarning(false);
      }
      return tableRows;
    }
  }, [tasks, evaluations, showWarning, eligibleTaskIDs]);

  // Step 2.c: Adjust headers based on selectedModels and the showTargets flag
  // (`filter(Boolean)` drops the null placeholder when targets are hidden)
  const headers = useMemo(() => {
    return [
      {
        key: 'task',
        header: 'Task',
      },
      showTargets
        ? {
            key: 'targets',
            header: 'Targets',
          }
        : null,
      ...selectedModels.map((model) => {
        return { key: model.modelId, header: `${model.name} prediction` };
      }),
    ].filter(Boolean);
  }, [showTargets, selectedModels]);

  // Step 2.d: Set visible rows (the current page's slice of `rows`)
  // NOTE(review): `page` is not reset here when `rows` changes (e.g. after
  // filtering), so it may point past the new row count — confirm that the
  // pagination control keeps it in range.
  useEffect(() => {
    // Set visible rows
    setVisibleRows(
      rows.slice((page - 1) * pageSize, (page - 1) * pageSize + pageSize),
    );
  }, [rows, page, pageSize]);

  // Step 3: Render
  // NOTE(review): the JSX below has lost its element tags in this copy; the
  // remaining attribute/text fragments are left untouched.
  return (
    <>
      {headers && rows && (
item.name} onChange={(event) => { setSelectedModels(event.selectedItems); }} invalid={selectedModels.length === 0} invalidText={'You must select a model to view predictions.'} >
{selectedModels.map((model) => { return ( {model.name} ); })}
{ setShowTargets(!showTargets); }} />
{!isEmpty(filters) ? ( ) : null} {showWarning ? (
{`Only showing predictions for ${MAX_NUM_ROWS} out of ${eligibleTaskIDs.size} tasks`}
) : null} {eligibleTaskIDs.size ? ( <> {({ rows, headers, getTableProps, getHeaderProps, getRowProps, }) => ( {headers.map((header, index) => ( {header.header} ))} {rows.map((row, index) => ( {row.cells.map((cell) => ( {cell.info.header === 'targets' && cell.value ? cell.value.length > 1 ? cell.value.map( (targetText, targetIdx) => ( <> Target {targetIdx + 1} :  {parse( DOMPurify.sanitize(targetText), )}
), ) : parse(DOMPurify.sanitize(cell.value[0])) : parse(DOMPurify.sanitize(cell.value))} ))} ))}
)}
{
  // Step 1: Update page size
  setPageSize(event.pageSize);
  // Step 2: Update page
  setPage(event.page);
}} > ) : (
{`No matching tasks found. ${!isEmpty(selectedFilters) ? 'Please try again by removing one or more additional filters.' : ''}`}
)}
)} ); }