/*
 * Source file: demo/frontend/src/components/BenchmarkGenerator.jsx
 * Snapshot: commit 970eef1 ("first commit", author tfrere), 12.7 kB.
 */
import React, { useState, useEffect, useRef } from "react";
import { Box, Typography, CircularProgress, Alert, Paper } from "@mui/material";
import PlayArrowIcon from "@mui/icons-material/PlayArrow";
import AccessTimeIcon from "@mui/icons-material/AccessTime";
import LogDisplay from "./LogDisplay";
// Ordered list of benchmark pipeline stage identifiers. The position of a
// stage in this array is used to compute the active step and the progress
// percentage, so the order matters. Frozen so that shared module state cannot
// be mutated by accident.
const BENCHMARK_STEPS = Object.freeze([
  "ingestion",
  "upload_ingest_to_hub",
  "summarization",
  "chunking",
  "single_shot_question_generation",
  "multi_hop_question_generation",
  "lighteval",
]);
// Human-readable display label for each pipeline stage identifier.
// Frozen so that the shared lookup table cannot be mutated by accident.
const STEP_LABELS = Object.freeze({
  ingestion: "Ingestion",
  upload_ingest_to_hub: "Upload to Hub",
  summarization: "Summarization",
  chunking: "Chunking",
  single_shot_question_generation: "Single-shot QG",
  multi_hop_question_generation: "Multi-hop QG",
  lighteval: "LightEval",
});
/**
* Component to handle benchmark generation and display logs
*
* @param {Object} props - Component props
* @param {string} props.sessionId - The session ID for the uploaded file
* @param {Function} props.onComplete - Function to call when generation is complete
* @returns {JSX.Element} Benchmark generator component
*/
const BenchmarkGenerator = ({ sessionId, onComplete }) => {
const [generating, setGenerating] = useState(false);
const [generationComplete, setGenerationComplete] = useState(false);
const [generationLogs, setGenerationLogs] = useState([]);
const [error, setError] = useState(null);
const [currentPhase, setCurrentPhase] = useState("initializing");
const [completedSteps, setCompletedSteps] = useState([]);
const [activeStep, setActiveStep] = useState(0);
const [elapsedTime, setElapsedTime] = useState(0);
// Reference to keep track of the polling interval
const pollingIntervalRef = useRef(null);
// Reference to keep track of the timer interval
const timerIntervalRef = useRef(null);
// Reference for starting time
const startTimeRef = useRef(null);
// Start generation on component mount
useEffect(() => {
// Set start time
startTimeRef.current = Date.now();
// Start timer
timerIntervalRef.current = setInterval(() => {
const timeElapsed = Math.floor(
(Date.now() - startTimeRef.current) / 1000
);
setElapsedTime(timeElapsed);
}, 1000);
generateBenchmark();
// Clean up the polling interval and timer when the component unmounts
return () => {
if (pollingIntervalRef.current) {
clearInterval(pollingIntervalRef.current);
}
if (timerIntervalRef.current) {
clearInterval(timerIntervalRef.current);
}
};
}, []);
// Determine the current phase and completed steps based on logs
useEffect(() => {
if (generationLogs.length === 0) return;
// Check all logs for completed stages
const newCompletedSteps = [...completedSteps];
let newActiveStep = activeStep;
generationLogs.forEach((log) => {
const match = log.match(/\[SUCCESS\] Stage completed: (\w+)/);
if (match && match[1]) {
const completedStep = match[1].trim();
if (
BENCHMARK_STEPS.includes(completedStep) &&
!newCompletedSteps.includes(completedStep)
) {
newCompletedSteps.push(completedStep);
// Set active step to the index of the next step
const stepIndex = BENCHMARK_STEPS.indexOf(completedStep);
if (stepIndex >= 0 && stepIndex + 1 > newActiveStep) {
newActiveStep = stepIndex + 1;
if (newActiveStep >= BENCHMARK_STEPS.length) {
newActiveStep = BENCHMARK_STEPS.length;
}
}
}
}
});
// Update state if there are new completed steps
if (newCompletedSteps.length > completedSteps.length) {
setCompletedSteps(newCompletedSteps);
setActiveStep(newActiveStep);
}
// Check the latest logs to determine the current phase
const recentLogs = generationLogs.slice(-10); // Check more logs
// Detect completion conditions
const isComplete =
recentLogs.some((log) =>
log.includes("[SUCCESS] Ingestion process completed successfully")
) ||
recentLogs.some((log) =>
log.includes(
"[SUCCESS] Configuration and ingestion completed successfully"
)
) ||
completedSteps.includes("lighteval") ||
newCompletedSteps.includes("lighteval");
if (isComplete) {
setCurrentPhase("complete");
setGenerationComplete(true);
// Stop polling when benchmark is complete
if (pollingIntervalRef.current) {
clearInterval(pollingIntervalRef.current);
}
// Notify parent component that generation is complete
if (onComplete) {
console.log("Notifying parent that generation is complete");
onComplete({
success: true,
sessionId,
logs: generationLogs,
});
}
} else if (
recentLogs.some((log) => log.includes("starting benchmark creation"))
) {
setCurrentPhase("benchmarking");
} else if (
recentLogs.some((log) => log.includes("Generating base configuration"))
) {
setCurrentPhase("configuring");
}
}, [generationLogs, completedSteps, activeStep, sessionId, onComplete]);
const generateBenchmark = async () => {
if (!sessionId) {
setError("Missing session ID");
return;
}
setGenerating(true);
setGenerationLogs([]);
setError(null);
setCurrentPhase("initializing");
setCompletedSteps([]);
setActiveStep(0);
try {
// Call the API to generate the benchmark
const response = await fetch("http://localhost:3001/generate-benchmark", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
session_id: sessionId,
}),
});
const result = await response.json();
if (response.ok) {
setGenerationLogs(result.logs || []);
// D'abord, on commence par interroger les logs de configuration
const pollConfigLogs = async () => {
try {
// Call the API to get the config logs
const configLogsResponse = await fetch(
`http://localhost:3001/config-logs/${sessionId}`
);
if (configLogsResponse.ok) {
const configLogsResult = await configLogsResponse.json();
// Update logs if there are new ones
if (
configLogsResult.logs &&
configLogsResult.logs.length > generationLogs.length
) {
setGenerationLogs(configLogsResult.logs);
}
// If config task is completed, switch to polling benchmark logs
if (configLogsResult.is_completed) {
// Attendre un court instant pour permettre au serveur de démarrer le benchmark
setTimeout(() => {
console.log(
"Configuration completed, switching to benchmark polling"
);
clearInterval(configPollingIntervalRef.current);
pollBenchmarkLogs();
}, 1000);
}
}
} catch (error) {
console.log("Error polling for config logs:", error);
// Don't stop polling on network errors
}
};
// Fonction pour interroger les logs du benchmark
const pollBenchmarkLogs = async () => {
// Set up polling for benchmark logs
pollingIntervalRef.current = setInterval(async () => {
// Check if we already completed
if (generationComplete) {
clearInterval(pollingIntervalRef.current);
return;
}
try {
// Call the API to get the latest benchmark logs
const logsResponse = await fetch(
`http://localhost:3001/benchmark-logs/${sessionId}`
);
if (logsResponse.ok) {
const logsResult = await logsResponse.json();
// Update logs if there are new ones
if (
logsResult.logs &&
logsResult.logs.length > generationLogs.length
) {
setGenerationLogs(logsResult.logs);
}
// Check if the task is completed
if (logsResult.is_completed) {
setGenerationComplete(true);
clearInterval(pollingIntervalRef.current);
// Notification is now handled in the useEffect above
}
}
} catch (error) {
console.log("Error polling for benchmark logs:", error);
// Don't stop polling on network errors
}
}, 3000); // Poll every 3 seconds
};
// Démarrer le polling des logs de configuration
const configPollingIntervalRef = { current: null };
configPollingIntervalRef.current = setInterval(pollConfigLogs, 1000); // Poll config logs more frequently (every second)
} else {
// Handle error
setGenerationLogs([`Error: ${result.error || "Unknown error"}`]);
setError(result.error || "Benchmark generation failed");
}
} catch (error) {
console.error("Error generating benchmark:", error);
setGenerationLogs([`Error: ${error.message || "Unknown error"}`]);
setError("Server connection error");
} finally {
setGenerating(false);
}
};
// Get title based on current phase
const getPhaseTitle = () => {
switch (currentPhase) {
case "initializing":
return "Benchmark generation...";
case "configuring":
return "Generating configuration file...";
case "benchmarking":
return "Creating benchmark...";
case "complete":
return "Benchmark generated successfully!";
default:
return "Processing...";
}
};
// Get the current step information for display
const getCurrentStepInfo = () => {
const totalSteps = BENCHMARK_STEPS.length;
const currentStepIndex = activeStep;
// If there's no active step yet
if (currentStepIndex === 0 && completedSteps.length === 0) {
return `Starting... (0%)`;
}
// If all steps are completed
if (currentStepIndex >= totalSteps) {
return `Complete (100%)`;
}
// Calculate percentage
const percentage = Math.round((currentStepIndex / totalSteps) * 100);
// Get current step name
const currentStepName =
STEP_LABELS[BENCHMARK_STEPS[currentStepIndex]] || "Processing";
return `${currentStepName} (${percentage}%)`;
};
// Format elapsed time in HH:MM:SS
const formatElapsedTime = () => {
const hours = Math.floor(elapsedTime / 3600);
const minutes = Math.floor((elapsedTime % 3600) / 60);
const seconds = elapsedTime % 60;
return [
hours.toString().padStart(2, "0"),
minutes.toString().padStart(2, "0"),
seconds.toString().padStart(2, "0"),
].join(":");
};
// If complete, stop the timer
useEffect(() => {
if (generationComplete && timerIntervalRef.current) {
clearInterval(timerIntervalRef.current);
}
}, [generationComplete]);
return (
<Paper
elevation={3}
sx={{
p: 4,
mt: 3,
mb: 3,
display: "flex",
flexDirection: "column",
alignItems: "center",
justifyContent: "center",
minHeight: 200,
}}
>
{error ? (
<Alert severity="error" sx={{ width: "100%" }}>
{error}
</Alert>
) : (
<>
<CircularProgress size={60} sx={{ mb: 2 }} />
<Typography variant="h6" component="div" gutterBottom>
{getPhaseTitle()}
</Typography>
{/* Step progress indicator */}
<Typography variant="body1" color="text.secondary">
{getCurrentStepInfo()}
</Typography>
{/* Timer display */}
<Box
sx={{
display: "flex",
alignItems: "center",
mt: 1,
color: "text.secondary",
}}
>
<Typography variant="body2" sx={{ opacity: 0.5 }}>
{formatElapsedTime()}
</Typography>
</Box>
</>
)}
{/* Use the LogDisplay component */}
<LogDisplay logs={generationLogs} height={300} />
</Paper>
);
};
export default BenchmarkGenerator;