File size: 2,320 Bytes
c09c4de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import onnx
import pathlib
import glob
import os
import json
import tempfile
import shutil
import typer
from typing import Annotated
# To use:
# pip install onnx typer
# python convert_to_external_data.py --base-path "/path/to/directory"
def convert(model_path, save_path):
    """Re-save an ONNX model using ONNX's external data format.

    The model is loaded from ``model_path`` and written to ``save_path``;
    tensor data that ONNX stores externally goes into a sibling file named
    ``<stem of save_path>.onnx_data``.

    Args:
        model_path: Path of the ONNX model to load.
        save_path: Path the converted model is written to. It may be the same
            as ``model_path`` to convert a model in place. Missing parent
            directories are created.
    """
    model = onnx.load(model_path)

    # Derive both output file names from save_path (not model_path) so the
    # converted model really ends up at save_path, with a matching data file
    # next to it, even when the two paths use different file names.
    destination = pathlib.Path(save_path)
    external_data_name = f"{destination.stem}.onnx_data"

    # Write the new model into a temporary directory and copy all of its
    # contents to the destination afterwards. Writing directly over
    # model_path (when save_path is the same file) would make onnx append to
    # the existing external data file, growing it more than expected.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx.save_model(
            model,
            os.path.join(tmp_dir, destination.name),
            save_as_external_data=True,
            location=external_data_name,
        )

        target_dir = str(destination.parent)
        os.makedirs(target_dir, exist_ok=True)
        # Copy whatever onnx produced: the model file plus its data file.
        for file_name in os.listdir(tmp_dir):
            shutil.copy2(os.path.join(tmp_dir, file_name), target_dir)
def main(base_path: Annotated[str, typer.Option()]):
    """
    Recursively convert all ONNX models under the given directory to the external data format.

    Every config.json found under the same directory is then updated so that
    its "transformers.js_config" section has "use_external_data_format" enabled.

    Args:
        base_path: Directory that is searched recursively for *.onnx models
            and config.json files.
    """
    # Convert every model in place (source and destination are the same file).
    for model_path in glob.glob(
        os.path.join(base_path, "**/*.onnx"),
        recursive=True,
    ):
        print("Converting", model_path)
        convert(model_path, model_path)

    # Find every config.json and enable use_external_data_format in it.
    for config_path in glob.glob(
        os.path.join(base_path, "**/config.json"),
        recursive=True,
    ):
        print("Modifying", config_path)

        # Read and write explicitly as UTF-8: the file is dumped with
        # ensure_ascii=False, so relying on the platform's default encoding
        # could corrupt non-ASCII characters.
        with open(config_path, "r", encoding="utf-8") as infile:
            config_data = json.load(infile)

        # Create the "transformers.js_config" section if it is missing, then
        # switch the flag on while keeping any other settings it holds.
        js_config = config_data.setdefault("transformers.js_config", {})
        js_config["use_external_data_format"] = True

        # Save the JSON file with the additional config.
        with open(config_path, "w", encoding="utf-8") as outfile:
            json.dump(config_data, outfile, indent=4, ensure_ascii=False)
# When executed as a script (not imported), hand `main` to typer.run so it is
# invoked from the command line (see the `--base-path` usage note).
if __name__ == "__main__":
    typer.run(main)
|