Spaces:
Runtime error
Runtime error
import os | |
import glob | |
import yaml | |
import pandas as pd | |
import json | |
from datetime import datetime | |
from loguru import logger | |
def _latest_subdir(pattern: str):
    """Return the most recently modified directory matching *pattern*, or None.

    Only directories are considered, so a stray file with a newer mtime cannot
    shadow the real latest folder. Returns None when nothing matches.
    """
    candidates = [path for path in glob.glob(pattern) if os.path.isdir(path)]
    # default=None avoids the ValueError that max() raises on an empty sequence.
    return max(candidates, key=os.path.getmtime, default=None)


def _column_mean(metrics_df: pd.DataFrame, column: str, metrics_path: str):
    """Return the NaN-ignoring mean of *column* as a plain float, or None.

    None is returned when the column is absent or has no non-NaN values.
    Otherwise json.dump would emit the non-standard ``NaN`` token.
    """
    if column not in metrics_df.columns:
        logger.warning(f"Column '{column}' not found in {metrics_path}.")
        return None
    mean = metrics_df[column].dropna().mean()
    if pd.isna(mean):
        return None
    # Convert numpy scalar -> builtin float for clean JSON serialization.
    return float(mean)


def multirun_artifact_producer(base_path: str, output_path: str):
    """Aggregate data from the latest run's csv folder and save to a JSON file.

    The most recently modified directory under ``base_path`` is treated as the
    latest multirun. For each sub-run directory inside it, the newest
    ``csv/version_*`` directory is used:

    * ``hparams.yaml`` is loaded with ``yaml.safe_load`` and stored verbatim.
    * ``train_acc`` and ``val_acc`` in ``metrics.csv`` are averaged, ignoring
      NaN. Each average is ``None`` if its column is missing or entirely NaN.

    Sub-runs lacking a ``csv`` directory, a ``version_*`` directory, or either
    file are skipped with a warning.

    Args:
        base_path: Directory containing top-level run folders.
        output_path: Directory the JSON file is written into (created if needed).

    Returns:
        Path of the written JSON file, or None if no run folder was found.
    """
    latest_folder = _latest_subdir(os.path.join(base_path, "*"))
    if latest_folder is None:
        logger.error("No valid run folders found!")
        return None

    output_data = {}
    # Sorted so the JSON key order is deterministic; os.listdir order is arbitrary.
    for run_dir in sorted(os.listdir(latest_folder)):
        run_path = os.path.join(latest_folder, run_dir)
        if not os.path.isdir(run_path):
            continue

        csv_base_path = os.path.join(run_path, "csv")
        if not os.path.isdir(csv_base_path):
            logger.warning(f"No csv directory found in {run_path}. Skipping.")
            continue

        latest_csv_folder = _latest_subdir(os.path.join(csv_base_path, "version_*"))
        if latest_csv_folder is None:
            logger.warning(
                f"No valid version folder found in {csv_base_path}. Skipping."
            )
            continue

        hparams_path = os.path.join(latest_csv_folder, "hparams.yaml")
        metrics_path = os.path.join(latest_csv_folder, "metrics.csv")
        if not os.path.isfile(hparams_path) or not os.path.isfile(metrics_path):
            logger.warning(
                f"Missing hparams.yaml or metrics.csv in {latest_csv_folder}. Skipping."
            )
            continue

        # safe_load: never construct arbitrary Python objects from the YAML file.
        with open(hparams_path, "r", encoding="utf-8") as file:
            hparams = yaml.safe_load(file)

        metrics_df = pd.read_csv(metrics_path)
        output_data[f"run_{run_dir}"] = {
            "hparams": hparams,
            "metrics": {
                "avg_train_acc": _column_mean(metrics_df, "train_acc", metrics_path),
                "avg_val_acc": _column_mean(metrics_df, "val_acc", metrics_path),
            },
        }

    os.makedirs(output_path, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_path, f"aggregated_data_{timestamp}.json")
    logger.info(f"Saving aggregated data to {output_file}")
    with open(output_file, "w", encoding="utf-8") as json_file:
        json.dump(output_data, json_file, indent=4)
    return output_file
if __name__ == "__main__":
    # Script entry point: aggregate the newest multirun found under
    # ./logs/train/runs and write the JSON summary into ./artifacts.
    multirun_artifact_producer(
        base_path="./logs/train/runs",
        output_path="./artifacts",
    )