Spaces:
Running
Running
File size: 1,350 Bytes
b10121d c97bae1 b10121d c97bae1 b10121d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import json
from glob import glob
from pathlib import Path
import tyro
def raw_params_to_readable(params: int) -> str:
return f"{params/1e9:.1f}B"
def main(results_dir: Path, output_file: Path) -> None:
output_file.parent.mkdir(parents=True, exist_ok=True)
print(f"{results_dir} -> {output_file}")
models = {}
for model_dir in sorted(glob(f"{results_dir}/*/*")):
model_name = "/".join(model_dir.split("/")[-2:])
print(f" {model_name}")
result_file_cand = glob(f"{model_dir}/bs1+*+steps25+results.json")
assert len(result_file_cand) == 1, model_name
results_data = json.load(open(result_file_cand[0]))
denosing_module_name = "unet" if "unet" in results_data["num_parameters"] else "transformer"
model_info = dict(
url=f"https://huggingface.co/{model_name}",
nickname=model_name.split("/")[-1].replace("-", " ").title(),
total_params=raw_params_to_readable(sum(results_data["num_parameters"].values())),
denoising_params=raw_params_to_readable(results_data["num_parameters"][denosing_module_name]),
resolution="NA",
)
assert model_name not in models
models[model_name] = model_info
json.dump(models, open(output_file, "w"), indent=2)
if __name__ == "__main__":
tyro.cli(main)
|