#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import logging
import os
from pathlib import Path
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch

from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
    parser.add_argument("--model_dir", default="best", type=str)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    args = parser.parse_args()
    return args


def logging_config():
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    # basicConfig already installs a console handler on the root logger,
    # so configure format and level once here; the previously created
    # StreamHandler was never attached to any logger and had no effect.
    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    logger = logging.getLogger(__name__)
    return logger


def main():
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    # ensure the output directory exists before saving any archives
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config()

    logger.info("export models on CPU")
    device = torch.device("cpu")

    logger.info("prepare vocabulary, model")
    vocabulary = Vocabulary.from_files(args.vocabulary_dir)

    model = WaveClassifierPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
        num_labels=vocabulary.get_vocab_size(namespace="labels"),
    )
    model.to(device)
    model.eval()

    # one second of fake 16 kHz audio: gaussian noise quantized to int16,
    # then rescaled to [-1, 1) the same way a real PCM waveform would be
    waveform = 0 + 25 * np.random.randn(16000,)
    waveform = np.array(waveform, dtype=np.int16)
    waveform = waveform / (1 << 15)
    waveform = torch.tensor(waveform, dtype=torch.float32)
    waveform = torch.unsqueeze(waveform, dim=0)
    waveform = waveform.to(device)

    logger.info("export jit models")
    example_inputs = (waveform,)

    # trace model
    trace_model = torch.jit.trace(func=model, example_inputs=example_inputs, strict=False)
    trace_model.save(serialization_dir / "trace_model.zip")

    # dynamic quantization of the Linear layers (CPU only, does not work on GPU);
    # computed once and reused for both the traced and scripted exports below
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

    # quantized trace model
    trace_quant_model = torch.jit.trace(func=quantized_model, example_inputs=example_inputs, strict=False)
    trace_quant_model.save(serialization_dir / "trace_quant_model.zip")

    # script model
    script_model = torch.jit.script(obj=model)
    script_model.save(serialization_dir / "script_model.zip")

    # quantized script model
    script_quant_model = torch.jit.script(quantized_model)
    script_quant_model.save(serialization_dir / "script_quant_model.zip")
    return


if __name__ == '__main__':
    main()
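

# Usage sketch (an assumption, not part of the original script): the exported
# TorchScript archives can be reloaded for inference without the Python model
# class or the toolbox package, using only torch.jit.load. The paths below
# match the argument defaults above; the output shape depends on the model's
# forward, assumed here to be per-label logits.
#
#     import torch
#
#     model = torch.jit.load("serialization_dir/trace_model.zip", map_location="cpu")
#     model.eval()
#
#     # (batch, samples) float waveform in [-1, 1), as prepared above
#     waveform = torch.zeros(1, 16000, dtype=torch.float32)
#     with torch.no_grad():
#         logits = model(waveform)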