File size: 3,239 Bytes
69ad385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import sys
import shutil
from typing import List

# Append the directory two levels above this script to the import path.
# Presumably this is the repository root containing the `toolbox` package
# imported below -- verify if the script is ever moved.
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch

from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel


def get_args():
    """Parse the command-line options of the export script.

    Returns:
        argparse.Namespace with the string attributes ``vocabulary_dir``,
        ``model_dir`` and ``serialization_dir``.
    """
    parser = argparse.ArgumentParser()
    # (flag, default) pairs; every option is a plain string.
    option_table = (
        ("--vocabulary_dir", "vocabulary"),
        ("--model_dir", "best"),
        ("--serialization_dir", "serialization_dir"),
    )
    for flag, default_value in option_table:
        parser.add_argument(flag, default=default_value, type=str)
    return parser.parse_args()


def logging_config():
    """Configure root logging and return this module's logger.

    Calls ``logging.basicConfig`` at DEBUG level with a timestamped
    ``file:line`` format (``basicConfig`` is a no-op when the root logger
    already has handlers).

    Returns:
        The ``logging.Logger`` named after this module.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s  %(filename)s:%(lineno)d >  %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.DEBUG)
    # Only the root handler installed by basicConfig is used. A second
    # StreamHandler attached here would duplicate every record, because this
    # logger propagates to the root logger.
    return logging.getLogger(__name__)


def _make_example_waveform(num_samples: int = 16000,
                           device: torch.device = torch.device("cpu")) -> torch.Tensor:
    """Build a dummy ``(1, num_samples)`` float32 waveform used as tracing input.

    Gaussian noise with standard deviation 25 is truncated to int16 and then
    divided by ``2**15``, mimicking int16 PCM samples scaled into [-1, 1).
    The values themselves are arbitrary. The 16000-sample default presumably
    corresponds to 1 s of 16 kHz audio -- TODO confirm.
    """
    pcm = (25 * np.random.randn(num_samples)).astype(np.int16)
    waveform = torch.tensor(pcm / (1 << 15), dtype=torch.float32)
    # Add a leading dimension of size 1 (presumably the batch axis).
    return waveform.unsqueeze(0).to(device)


def main():
    """Export the pretrained WaveClassifier model as TorchScript archives.

    Reads the label vocabulary from ``--vocabulary_dir`` (to get the number
    of labels) and loads the model from ``--model_dir``. It then writes four
    files into ``--serialization_dir``, creating the directory if missing:

    * ``trace_model.zip``        -- ``torch.jit.trace`` of the float model
    * ``trace_quant_model.zip``  -- trace of the dynamically int8-quantized model
    * ``script_model.zip``       -- ``torch.jit.script`` of the float model
    * ``script_quant_model.zip`` -- script of the dynamically int8-quantized model

    Everything runs on CPU.
    """
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    # ScriptModule.save() does not create missing parent directories.
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config()

    logger.info("export models on CPU")
    device = torch.device("cpu")

    logger.info("prepare vocabulary, model")
    vocabulary = Vocabulary.from_files(args.vocabulary_dir)

    model = WaveClassifierPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
        num_labels=vocabulary.get_vocab_size(namespace="labels")
    )
    model.to(device)
    model.eval()

    example_inputs = (_make_example_waveform(device=device),)

    logger.info("export jit models")

    # Dynamic int8 quantization of the Linear layers (does not work on GPU).
    # quantize_dynamic returns a quantized *copy* (inplace=False by default),
    # so `model` stays float and this single copy serves both trace and script.
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

    # trace models
    trace_model = torch.jit.trace(func=model, example_inputs=example_inputs, strict=False)
    trace_model.save(str(serialization_dir / "trace_model.zip"))

    trace_quant_model = torch.jit.trace(func=quantized_model, example_inputs=example_inputs, strict=False)
    trace_quant_model.save(str(serialization_dir / "trace_quant_model.zip"))

    # script models
    script_model = torch.jit.script(obj=model)
    script_model.save(str(serialization_dir / "script_model.zip"))

    script_quant_model = torch.jit.script(quantized_model)
    script_quant_model.save(str(serialization_dir / "script_quant_model.zip"))


if __name__ == "__main__":
    # Run the export only when executed as a script, never on import.
    main()