File size: 3,973 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0


import glob
import os
import unittest
from os.path import join

import pytest

from onnx import ModelProto, hub


@pytest.mark.skipif(
    not os.environ.get("TEST_HUB"),
    reason="Conserving Git LFS quota",
)
class TestModelHub(unittest.TestCase):
    """Tests for the ``onnx.hub`` model loading / listing API.

    The suite pulls models from the ``onnx/models`` repository, which draws on
    its Git LFS quota, so it is skipped unless the ``TEST_HUB`` environment
    variable is set to a non-empty value.
    """

    def setUp(self) -> None:
        self.name = "MNIST"
        self.repo = "onnx/models:main"
        # NOTE(review): not read by any test in this class; kept in case
        # subclasses or external tooling rely on the attribute.
        self.opset = 7

    @staticmethod
    def _cached_onnx_files(directory: str) -> list:
        """Return all ``*.onnx`` paths found recursively under *directory*."""
        return glob.glob(join(directory, "**", "*.onnx"), recursive=True)

    def test_force_reload(self) -> None:
        model = hub.load(self.name, self.repo, force_reload=True)
        self.assertIsInstance(model, ModelProto)

        # The loaded model should be present in the hub cache directory.
        self.assertGreaterEqual(len(self._cached_onnx_files(hub.get_dir())), 1)

    def test_listing_models(self) -> None:
        model_info_list_1 = hub.list_models(self.repo, model="mnist", tags=["vision"])
        model_info_list_2 = hub.list_models(self.repo, tags=["vision"])
        model_info_list_3 = hub.list_models(self.repo)

        # Each additional filter must narrow the result set:
        # model+tag  <  tag only  <  unfiltered.
        self.assertGreater(len(model_info_list_1), 1)
        self.assertGreater(len(model_info_list_2), len(model_info_list_1))
        self.assertGreater(len(model_info_list_3), len(model_info_list_2))

    def test_basic_usage(self) -> None:
        model = hub.load(self.name, self.repo)
        self.assertIsInstance(model, ModelProto)

        self.assertGreaterEqual(len(self._cached_onnx_files(hub.get_dir())), 1)

    def test_custom_cache(self) -> None:
        old_cache = hub.get_dir()
        new_cache = join(old_cache, "custom")
        hub.set_dir(new_cache)
        # hub's cache dir is global state: restore it even if an assertion
        # below fails, otherwise subsequent tests would use the wrong cache.
        self.addCleanup(hub.set_dir, old_cache)

        model = hub.load(self.name, self.repo)
        self.assertIsInstance(model, ModelProto)

        self.assertGreaterEqual(len(self._cached_onnx_files(new_cache)), 1)

    def test_download_with_opset(self) -> None:
        model = hub.load(self.name, self.repo, opset=8)
        self.assertIsInstance(model, ModelProto)

    def test_opset_error(self) -> None:
        # A negative opset is invalid and must be rejected.
        with self.assertRaises(AssertionError):
            hub.load(self.name, self.repo, opset=-1)

    def test_manifest_not_found(self) -> None:
        # Loading from a ref that does not exist must fail.
        with self.assertRaises(AssertionError):
            hub.load(self.name, "onnx/models:unknown", silent=True)

    def test_verify_repo_ref(self) -> None:
        # Not trusted repo:
        verified = hub._verify_repo_ref("mhamilton723/models")
        self.assertFalse(verified)

        # Not trusted repo:
        verified = hub._verify_repo_ref("onnx/models:unknown")
        self.assertFalse(verified)

        # Trusted repo:
        verified = hub._verify_repo_ref(self.repo)
        self.assertTrue(verified)

    def test_get_model_info(self) -> None:
        # Both an explicit opset and the default lookup must succeed...
        hub.get_model_info("mnist", self.repo, opset=8)
        hub.get_model_info("mnist", self.repo)
        # ...while an invalid opset is rejected.
        with self.assertRaises(AssertionError):
            hub.get_model_info("mnist", self.repo, opset=-1)

    def test_download_model_with_test_data(self) -> None:
        directory = hub.download_model_with_test_data("mnist")
        # Check the type before using the value: os.listdir(None) would
        # silently list the current working directory instead of failing.
        self.assertIsInstance(directory, str)
        files = os.listdir(directory)
        self.assertIn(member="model.onnx", container=files, msg="Onnx model not found")
        self.assertIn(
            member="test_data_set_0", container=files, msg="Test data not found"
        )

    def test_model_with_preprocessing(self) -> None:
        model = hub.load_composite_model(
            "ResNet50-fp32", preprocessing_model="ResNet-preproc"
        )
        self.assertIsInstance(model, ModelProto)


if __name__ == "__main__":
    # Allow running this file directly with the stdlib test runner; the
    # explicit module argument is unittest.main's documented default.
    unittest.main(module="__main__")