# leaderboard/benchmark/diffusion/text-to-image/scripts/aggregate_leaderboard_models.py
# Last commit by Jae-Won Chung: "Updated diffusion benchmark and data" (c97bae1)
import json
from glob import glob
from pathlib import Path
import tyro
def raw_params_to_readable(params: int) -> str:
    """Render a raw parameter count in billions with one decimal place.

    Example: ``2_600_000_000`` becomes ``"2.6B"``.
    """
    in_billions = params / 1e9
    return "{:.1f}B".format(in_billions)
def main(results_dir: Path, output_file: Path) -> None:
    """Aggregate per-model benchmark results into a leaderboard model index.

    Every directory at ``results_dir/<org>/<model>/`` must contain exactly one
    ``bs1+*+steps25+results.json`` file. Its ``num_parameters`` mapping is
    summarized into a leaderboard entry. All entries are written as a JSON
    object keyed by ``"<org>/<model>"``.

    Args:
        results_dir: Root directory whose second-level subdirectories hold
            per-model results. Non-directory entries at that depth are skipped.
        output_file: Destination JSON file. Missing parent directories are
            created.

    Raises:
        ValueError: If a model directory does not contain exactly one matching
            results file, or if a model name would be recorded twice.
        KeyError: If a results file has no ``num_parameters`` entry, or has
            neither a ``unet`` nor a ``transformer`` entry in it.
    """
    from glob import escape

    output_file.parent.mkdir(parents=True, exist_ok=True)
    print(f"{results_dir} -> {output_file}")

    models = {}
    # Escape the root so glob metacharacters in the path are matched literally.
    for model_dir in sorted(glob(f"{escape(str(results_dir))}/*/*")):
        model_path = Path(model_dir)
        if not model_path.is_dir():
            # Stray files at the <org>/<model> level hold no results; skip them.
            continue

        # Model IDs are "<org>/<model>": the last two path components. Using
        # Path avoids hard-coding "/" as the separator.
        model_name = f"{model_path.parent.name}/{model_path.name}"
        print(f" {model_name}")

        result_file_cand = glob(f"{escape(model_dir)}/bs1+*+steps25+results.json")
        # Raise instead of assert so the check survives `python -O`.
        if len(result_file_cand) != 1:
            raise ValueError(
                f"{model_name}: expected exactly one results file, "
                f"found {sorted(result_file_cand)}"
            )
        with open(result_file_cand[0], encoding="utf-8") as f:
            results_data = json.load(f)

        num_parameters = results_data["num_parameters"]
        # The denoiser's count is stored under "unet" when present; otherwise
        # the "transformer" entry is used.
        denoising_module_name = "unet" if "unet" in num_parameters else "transformer"
        model_info = dict(
            url=f"https://huggingface.co/{model_name}",
            nickname=model_name.split("/")[-1].replace("-", " ").title(),
            total_params=raw_params_to_readable(sum(num_parameters.values())),
            denoising_params=raw_params_to_readable(num_parameters[denoising_module_name]),
            resolution="NA",
        )
        if model_name in models:
            raise ValueError(f"Duplicate model name: {model_name}")
        models[model_name] = model_info

    # Close (and therefore flush) the output file deterministically.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(models, f, indent=2)
if __name__ == "__main__":
    # Run main() as a command-line program; tyro builds the CLI arguments
    # from main()'s signature.
    tyro.cli(main)