Spaces:
Runtime error
Runtime error
| # Copyright 2024 MIT Han Lab | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| # SPDX-License-Identifier: Apache-2.0 | |
| import os | |
| import warnings | |
| from typing import Any, Tuple | |
| import torch | |
def export_onnx(
    model: torch.nn.Module,
    input_shape: Tuple[int, ...],
    export_path: str,
    opset: int,
    export_dtype: torch.dtype,
    export_device: torch.device,
) -> None:
    """Export ``model`` to an ONNX file at ``export_path``.

    The model is put into eval mode and exported with a single random dummy
    input named ``"x"`` whose dimension 0 is marked dynamic as
    ``"batch_size"``. The graph output is named ``"image_embeddings"``.

    Args:
        model: Module to export. ``model.eval()`` is called on it in place.
        input_shape: Shape of the random dummy input tensor used for export.
        export_path: Destination ONNX file. Missing parent directories are
            created; a bare file name is written to the current directory.
        opset: ONNX opset version forwarded to ``torch.onnx.export``.
        export_dtype: dtype of the dummy input tensor.
        export_device: Device the dummy input tensor is created on
            (presumably it must match the device ``model`` lives on — TODO
            confirm against callers).

    Raises:
        Any exception from ``torch.onnx.export`` is re-raised after the
        partially written file at ``export_path`` has been removed.
    """
    model.eval()
    dummy_input = {"x": torch.randn(input_shape, dtype=export_dtype, device=export_device)}
    # Only the input's batch dimension is declared dynamic; no dynamic axes
    # are declared for the output here.
    dynamic_axes = {"x": {0: "batch_size"}}
    output_names = ["image_embeddings"]

    # os.path.dirname("model.onnx") == "" and os.makedirs("") raises, so only
    # create a directory when the path actually contains one. exist_ok avoids
    # the race between checking for the directory and creating it.
    export_dir = os.path.dirname(export_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)

    with warnings.catch_warnings():
        # Silence tracer and generic user warnings emitted during export.
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        print(f"Exporting onnx model to {export_path}...")
        # Open outside the try: if opening fails, nothing was truncated and
        # there is nothing of ours to clean up.
        f = open(export_path, "wb")
        try:
            with f:
                torch.onnx.export(
                    model,
                    tuple(dummy_input.values()),
                    f,
                    export_params=True,
                    verbose=False,
                    opset_version=opset,
                    do_constant_folding=True,
                    input_names=list(dummy_input.keys()),
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                )
        except BaseException:
            # "wb" already truncated the file; don't leave an empty or
            # half-written model behind that a later load could mistake for
            # a valid export.
            try:
                os.remove(export_path)
            except OSError:
                pass
            raise