HoneyTian's picture
update
69ad385
raw
history blame
3.24 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import sys
import shutil
from typing import List
# Resolve the directory that holds this script and expose the directory two
# levels above it on sys.path, so the `toolbox` imports below can be resolved
# when the script is launched directly.
pwd = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(pwd, "../../"))
import numpy as np
import torch
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
def get_args():
    """Parse the command-line options of the export script.

    Returns:
        argparse.Namespace with the string attributes ``vocabulary_dir``,
        ``model_dir`` and ``serialization_dir``.
    """
    # (flag, default) pairs; every option is a plain string.
    option_defaults = (
        ("--vocabulary_dir", "vocabulary"),
        ("--model_dir", "best"),
        ("--serialization_dir", "serialization_dir"),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value in option_defaults:
        parser.add_argument(flag, default=default_value, type=str)

    return parser.parse_args()
def logging_config():
    """Configure root logging and return this module's logger.

    ``logging.basicConfig`` sets the root logger to DEBUG and, when the root
    logger has no handlers yet, installs a stream handler using the format
    below. The returned logger has no handlers of its own; its records
    propagate to the root logger.

    Returns:
        logging.Logger named after this module (``__name__``).
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(
        format=fmt,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.DEBUG,
    )

    # NOTE: an extra INFO-level StreamHandler used to be built here but was
    # never attached to any logger, so it had no effect. It is dropped rather
    # than attached, because attaching it would print every record twice
    # (once here, once through the root handler).
    return logging.getLogger(__name__)
def main():
    """Export the pretrained wave classifier as TorchScript archives on CPU.

    Reads the options from ``get_args()``, loads the vocabulary and the
    pretrained model, and writes four files into ``--serialization_dir``:

    * ``trace_model.zip``        - ``torch.jit.trace`` of the model
    * ``trace_quant_model.zip``  - trace of the dynamically quantized model
    * ``script_model.zip``       - ``torch.jit.script`` of the model
    * ``script_quant_model.zip`` - script of the dynamically quantized model

    The output directory is created if it does not exist.
    """
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    # Saving a TorchScript archive does not create missing parent directories.
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config()

    logger.info("export models on CPU")
    device = torch.device("cpu")

    logger.info("prepare vocabulary, model")
    vocabulary = Vocabulary.from_files(args.vocabulary_dir)

    model = WaveClassifierPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
        num_labels=vocabulary.get_vocab_size(namespace="labels"),
    )
    model.to(device)
    model.eval()

    # Example input for tracing: 16000 samples of Gaussian noise (std 25),
    # truncated to int16 and rescaled by 1 / 2**15, then given a leading
    # dimension of size 1, i.e. a float32 tensor of shape (1, 16000).
    # NOTE(review): presumably this mimics one second of 16 kHz PCM audio;
    # confirm against the model's expected input.
    waveform = 0 + 25 * np.random.randn(16000,)
    waveform = np.array(waveform, dtype=np.int16)
    waveform = waveform / (1 << 15)
    waveform = torch.tensor(waveform, dtype=torch.float32)
    waveform = torch.unsqueeze(waveform, dim=0)
    waveform = waveform.to(device)

    logger.info("export jit models")
    example_inputs = (waveform,)

    # Dynamic quantization of the Linear layers (does not work on GPU).
    # The same quantized copy is used for both the traced and the scripted
    # export; `model` itself is left untouched.
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

    # Paths are passed as str because older torch releases do not accept
    # pathlib.Path objects in ScriptModule.save().

    # trace model
    trace_model = torch.jit.trace(func=model, example_inputs=example_inputs, strict=False)
    trace_model.save(str(serialization_dir / "trace_model.zip"))

    # quantized trace model
    trace_quant_model = torch.jit.trace(func=quantized_model, example_inputs=example_inputs, strict=False)
    trace_quant_model.save(str(serialization_dir / "trace_quant_model.zip"))

    # script model
    script_model = torch.jit.script(obj=model)
    script_model.save(str(serialization_dir / "script_model.zip"))

    # quantized script model
    script_quant_model = torch.jit.script(quantized_model)
    script_quant_model.save(str(serialization_dir / "script_quant_model.zip"))
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()