Spaces:
Running
Running
import json
from glob import escape, glob
from pathlib import Path

import tyro

def raw_params_to_readable(params: int) -> str:
    """Render a raw parameter count in billions with one decimal place.

    Example: ``2_600_000_000`` -> ``"2.6B"``; ``860_000_000`` -> ``"0.9B"``.
    """
    in_billions = params / 1e9
    return "{:.1f}B".format(in_billions)


def main(results_dir: Path, output_file: Path) -> None:
    """Collect per-model benchmark results into a single JSON catalogue.

    ``results_dir`` is scanned two levels deep (``<results_dir>/<org>/<model>``).
    Each model directory must contain exactly one
    ``bs1+*+steps25+results.json`` file. That file's ``"num_parameters"``
    mapping (module name -> parameter count) is summarised into an entry
    with the Hugging Face URL, a display nickname, the total parameter
    count and the denoising-module parameter count. All entries are written
    to ``output_file`` as a JSON object keyed by ``"<org>/<model>"``.

    Args:
        results_dir: Root directory holding ``<org>/<model>`` subdirectories.
        output_file: Destination JSON file; its parent directories are created.

    Raises:
        ValueError: If a model directory does not contain exactly one matching
            results file, or the same model name is seen twice.
        KeyError: If a results file lacks ``"num_parameters"`` or the
            denoiser entry looked up in it.
    """
    output_file.parent.mkdir(parents=True, exist_ok=True)
    print(f"{results_dir} -> {output_file}")

    models = {}
    # Escape the user-supplied root so glob metacharacters in it (e.g. "[")
    # are matched literally; only the "*/*" suffix is meant as a pattern.
    for model_path in sorted(glob(f"{escape(str(results_dir))}/*/*")):
        model_dir = Path(model_path)
        if not model_dir.is_dir():
            # Stray files at the <org>/<model> level are not result directories.
            continue

        # Build "<org>/<model>" from path components rather than splitting on
        # "/", so it is independent of the OS path separator.
        model_name = f"{model_dir.parent.name}/{model_dir.name}"
        print(f" {model_name}")

        # Path.glob treats model_dir literally; only the file name is a pattern.
        result_file_cand = sorted(model_dir.glob("bs1+*+steps25+results.json"))
        if len(result_file_cand) != 1:
            raise ValueError(
                f"{model_name}: expected exactly one bs1+*+steps25+results.json, "
                f"found {len(result_file_cand)}"
            )
        with open(result_file_cand[0], encoding="utf-8") as fp:
            results_data = json.load(fp)

        num_parameters = results_data["num_parameters"]
        # The denoiser is the "unet" entry when present, else the "transformer" one.
        denoising_module_name = "unet" if "unet" in num_parameters else "transformer"
        model_info = dict(
            url=f"https://huggingface.co/{model_name}",
            nickname=model_dir.name.replace("-", " ").title(),
            total_params=raw_params_to_readable(sum(num_parameters.values())),
            denoising_params=raw_params_to_readable(num_parameters[denoising_module_name]),
            resolution="NA",
        )
        if model_name in models:
            raise ValueError(f"duplicate model entry: {model_name}")
        models[model_name] = model_info

    with open(output_file, "w", encoding="utf-8") as fp:
        json.dump(models, fp, indent=2)


if __name__ == "__main__":
    # tyro builds a command-line interface from main()'s typed signature
    # (results_dir, output_file) and calls main with the parsed values.
    tyro.cli(main)