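# NOTE: the lines below appear to be repository-viewer residue captured with
# the file (author, commit message, commit hash, viewer links, file size);
# they are kept as comments so the module remains valid Python.
# Soutrik
# added gitignore
# 3e4ba8b
# raw
# history blame
# 3.11 kB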
import os
import glob
import yaml
import pandas as pd
import json
from datetime import datetime
from loguru import logger
def _latest_directory(pattern: str):
    """Return the most recently modified directory matching *pattern*, or None."""
    # Only directories are candidates: a newer stray file must not hide them.
    directories = [path for path in glob.glob(pattern) if os.path.isdir(path)]
    # default=None avoids the ValueError that max() raises on an empty sequence.
    return max(directories, key=os.path.getmtime, default=None)


def _mean_or_none(metrics_df, column: str, source: str):
    """Return the mean of the non-NaN values of *column* as a float, or None."""
    if column not in metrics_df.columns:
        logger.warning(f"Column '{column}' not found in {source}. Recording null.")
        return None
    mean_value = metrics_df[column].dropna().mean()
    # A column with no usable values averages to NaN, which json.dump would
    # emit as the non-standard token NaN; store None (JSON null) instead.
    return None if pd.isna(mean_value) else float(mean_value)


def multirun_artifact_producer(base_path: str, output_path: str):
    """Aggregate data from the latest run's csv folders and save it to a JSON file.

    The most recently modified directory directly under ``base_path`` is taken
    as the latest run. For each sub-directory of it, the most recently modified
    ``csv/version_*`` directory is read: ``hparams.yaml`` is loaded as-is and
    the ``train_acc`` / ``val_acc`` columns of ``metrics.csv`` are averaged
    (ignoring NaN rows). Sub-runs that lack any of these pieces are skipped
    with a warning.

    The result is written to
    ``<output_path>/aggregated_data_<YYYYmmdd_HHMMSS>.json`` as a mapping of
    ``"run_<sub-directory name>"`` to ``{"hparams": ..., "metrics": {...}}``.
    A metric that is missing or has no non-NaN values is stored as ``null``.

    Args:
        base_path: Directory that contains one folder per run.
        output_path: Directory for the JSON file; created if it does not exist.

    Returns:
        None. If ``base_path`` contains no directories, an error is logged and
        no file is written.
    """
    # Find the latest top-level run folder
    latest_folder = _latest_directory(os.path.join(base_path, "*"))
    if latest_folder is None:
        logger.error("No valid run folders found!")
        return

    output_data = {}

    # Sorted so the key order of the resulting JSON is deterministic.
    for run_dir in sorted(os.listdir(latest_folder)):
        run_path = os.path.join(latest_folder, run_dir)
        if not os.path.isdir(run_path):
            continue

        csv_base_path = os.path.join(run_path, "csv")
        if not os.path.isdir(csv_base_path):
            logger.warning(f"No csv directory found in {run_path}. Skipping.")
            continue

        # Find the latest version folder in csv
        latest_csv_folder = _latest_directory(os.path.join(csv_base_path, "version_*"))
        if latest_csv_folder is None:
            logger.warning(
                f"No valid version folder found in {csv_base_path}. Skipping."
            )
            continue

        hparams_path = os.path.join(latest_csv_folder, "hparams.yaml")
        metrics_path = os.path.join(latest_csv_folder, "metrics.csv")
        if not os.path.isfile(hparams_path) or not os.path.isfile(metrics_path):
            logger.warning(
                f"Missing hparams.yaml or metrics.csv in {latest_csv_folder}. Skipping."
            )
            continue

        with open(hparams_path, "r", encoding="utf-8") as file:
            hparams = yaml.safe_load(file)

        try:
            metrics_df = pd.read_csv(metrics_path)
        except pd.errors.EmptyDataError:
            # A zero-byte metrics.csv has nothing to average.
            logger.warning(f"Empty metrics.csv in {latest_csv_folder}. Skipping.")
            continue

        output_data[f"run_{run_dir}"] = {
            "hparams": hparams,
            "metrics": {
                "avg_train_acc": _mean_or_none(metrics_df, "train_acc", metrics_path),
                "avg_val_acc": _mean_or_none(metrics_df, "val_acc", metrics_path),
            },
        }

    # Save aggregated data to JSON
    os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(
        output_path, f"aggregated_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    )
    logger.info(f"Saving aggregated data to {output_file}")
    with open(output_file, "w", encoding="utf-8") as json_file:
        json.dump(output_data, json_file, indent=4)
if __name__ == "__main__":
    # Script entry point: aggregate the newest run found under the default
    # training-log directory and write the JSON artifact to ./artifacts.
    multirun_artifact_producer(
        base_path="./logs/train/runs",
        output_path="./artifacts",
    )