# Author: Soutrik — commit "added gitignore" (3e4ba8b)
import os
import glob
import yaml
import pandas as pd
import json
from datetime import datetime
from loguru import logger
def multirun_artifact_producer(base_path: str, output_path: str) -> None:
    """Aggregate metrics from the latest multirun and save them to a JSON file.

    Scans ``base_path`` for the most recently modified run folder, then for each
    sub-run inside it reads ``csv/version_*/hparams.yaml`` and ``metrics.csv``,
    averaging the ``train_acc`` and ``val_acc`` columns. The aggregate is written
    to ``output_path`` as a timestamped ``aggregated_data_*.json`` file.

    Args:
        base_path: Directory containing one folder per (multi)run.
        output_path: Directory where the aggregated JSON artifact is written
            (created if it does not exist).

    Returns:
        None. Logs an error and returns early if no run folders exist;
        sub-runs with missing/invalid data are logged and skipped.
    """
    # Guard the empty case explicitly: max() on an empty sequence raises
    # ValueError, so check before selecting the newest folder by mtime.
    run_folders = glob.glob(os.path.join(base_path, "*"))
    if not run_folders:
        logger.error("No valid run folders found!")
        return
    latest_folder = max(run_folders, key=os.path.getmtime)
    if not os.path.isdir(latest_folder):
        logger.error("No valid run folders found!")
        return

    # Aggregated result, keyed "run_<sub-run name>".
    output_data = {}

    # Process each sub-run directory within the latest run folder.
    for run_dir in os.listdir(latest_folder):
        run_path = os.path.join(latest_folder, run_dir)
        if not os.path.isdir(run_path):
            continue

        # Each sub-run is expected to hold CSV-logger output under csv/.
        csv_base_path = os.path.join(run_path, "csv")
        if not os.path.isdir(csv_base_path):
            logger.warning(f"No csv directory found in {run_path}. Skipping.")
            continue

        # Pick the newest version_* folder; guard against an empty glob so a
        # csv/ dir with no versions is skipped instead of crashing max().
        version_folders = glob.glob(os.path.join(csv_base_path, "version_*"))
        if not version_folders:
            logger.warning(
                f"No valid version folder found in {csv_base_path}. Skipping."
            )
            continue
        latest_csv_folder = max(version_folders, key=os.path.getmtime)
        if not os.path.isdir(latest_csv_folder):
            logger.warning(
                f"No valid version folder found in {csv_base_path}. Skipping."
            )
            continue

        # Both the hyperparameter dump and the metrics log must be present.
        hparams_path = os.path.join(latest_csv_folder, "hparams.yaml")
        metrics_path = os.path.join(latest_csv_folder, "metrics.csv")
        if not os.path.isfile(hparams_path) or not os.path.isfile(metrics_path):
            logger.warning(
                f"Missing hparams.yaml or metrics.csv in {latest_csv_folder}. Skipping."
            )
            continue

        # safe_load: never yaml.load on files we didn't author the schema for.
        with open(hparams_path, "r") as file:
            hparams = yaml.safe_load(file)

        # Average the accuracy columns, ignoring NaN rows (the CSV logger
        # writes train/val metrics on separate rows). Skip sub-runs whose
        # metrics file lacks the expected columns rather than raising KeyError.
        metrics_df = pd.read_csv(metrics_path)
        if "train_acc" not in metrics_df.columns or "val_acc" not in metrics_df.columns:
            logger.warning(
                f"Missing train_acc/val_acc columns in {metrics_path}. Skipping."
            )
            continue
        avg_train_acc = metrics_df["train_acc"].dropna().mean()
        avg_val_acc = metrics_df["val_acc"].dropna().mean()

        output_data[f"run_{run_dir}"] = {
            "hparams": hparams,
            "metrics": {"avg_train_acc": avg_train_acc, "avg_val_acc": avg_val_acc},
        }

    # Persist the aggregate under a timestamped filename so repeated
    # invocations never overwrite earlier artifacts.
    os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(
        output_path, f"aggregated_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    )
    logger.info(f"Saving aggregated data to {output_file}")
    with open(output_file, "w") as json_file:
        json.dump(output_data, json_file, indent=4)
if __name__ == "__main__":
    # Aggregate the newest run's metrics into the artifacts directory.
    multirun_artifact_producer(
        base_path="./logs/train/runs",
        output_path="./artifacts",
    )