Spaces:
Runtime error
Runtime error
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import os | |
import warnings | |
from typing import Any, Tuple | |
import torch | |
def export_onnx( | |
model: torch.nn.Module, | |
input_shape: Tuple[int], | |
export_path: str, | |
opset: int, | |
export_dtype: torch.dtype, | |
export_device: torch.device, | |
) -> None: | |
model.eval() | |
dummy_input = {"x": torch.randn(input_shape, dtype=export_dtype, device=export_device)} | |
dynamic_axes = { | |
"x": {0: "batch_size"}, | |
} | |
# _ = model(**dummy_input) | |
output_names = ["image_embeddings"] | |
export_dir = os.path.dirname(export_path) | |
if not os.path.exists(export_dir): | |
os.makedirs(export_dir) | |
with warnings.catch_warnings(): | |
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) | |
warnings.filterwarnings("ignore", category=UserWarning) | |
print(f"Exporting onnx model to {export_path}...") | |
with open(export_path, "wb") as f: | |
torch.onnx.export( | |
model, | |
tuple(dummy_input.values()), | |
f, | |
export_params=True, | |
verbose=False, | |
opset_version=opset, | |
do_constant_folding=True, | |
input_names=list(dummy_input.keys()), | |
output_names=output_names, | |
dynamic_axes=dynamic_axes, | |
) | |