Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
from collections import defaultdict | |
import json | |
import logging | |
from logging.handlers import TimedRotatingFileHandler | |
import os | |
import platform | |
from pathlib import Path | |
import sys | |
import shutil | |
from typing import List | |
pwd = os.path.abspath(os.path.dirname(__file__)) | |
sys.path.append(os.path.join(pwd, "../../")) | |
import numpy as np | |
import torch | |
from toolbox.torch.utils.data.vocabulary import Vocabulary | |
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel | |
def get_args():
    """Parse the command-line options of the export script.

    Returns:
        argparse.Namespace with three string attributes:
        ``vocabulary_dir``, ``model_dir`` and ``serialization_dir``.
    """
    # (flag, default) pairs; every option is a plain string.
    option_defaults = (
        ("--vocabulary_dir", "vocabulary"),
        ("--model_dir", "best"),
        ("--serialization_dir", "serialization_dir"),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, default_value in option_defaults:
        arg_parser.add_argument(flag, default=default_value, type=str)

    return arg_parser.parse_args()
def logging_config():
    """Configure root logging and return this module's logger.

    Calls ``logging.basicConfig`` with a DEBUG level and a format that
    includes timestamp, logger name, level, file name and line number.

    The previous version also built a separate ``StreamHandler`` (level INFO)
    but never attached it to any logger, so it had no effect; that dead code
    has been removed and the effective behaviour is unchanged.

    Returns:
        logging.Logger: the logger named after this module (``__name__``).
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(
        format=fmt,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.DEBUG,
    )

    return logging.getLogger(__name__)
def main():
    """Export the pretrained wave classifier as TorchScript archives on CPU.

    Loads the vocabulary and the pretrained model named by the command-line
    arguments, builds a random example waveform, and writes four files into
    ``--serialization_dir``:

    * ``trace_model.zip``        - traced float model
    * ``trace_quant_model.zip``  - traced, dynamically quantized model
    * ``script_model.zip``       - scripted float model
    * ``script_quant_model.zip`` - scripted, dynamically quantized model
    """
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    # The save() calls below do not create missing directories; make sure the
    # output location exists so a fresh run does not fail on the first export.
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config()

    logger.info("export models on CPU")
    device = torch.device("cpu")

    logger.info("prepare vocabulary, model")
    vocabulary = Vocabulary.from_files(args.vocabulary_dir)

    model = WaveClassifierPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
        num_labels=vocabulary.get_vocab_size(namespace="labels"),
    )
    model.to(device)
    model.eval()

    # Example input for tracing: 16000 samples of low-amplitude Gaussian noise
    # (std 25), truncated to int16 and then scaled by 1 / 2**15 so the values
    # lie in [-1, 1). Shape after unsqueeze is (1, 16000).
    # NOTE(review): 16000 presumably corresponds to 1 s at 16 kHz - confirm.
    waveform = 0 + 25 * np.random.randn(16000,)
    waveform = np.array(waveform, dtype=np.int16)
    waveform = waveform / (1 << 15)
    waveform = torch.tensor(waveform, dtype=torch.float32)
    waveform = torch.unsqueeze(waveform, dim=0)
    waveform = waveform.to(device)

    logger.info("export jit models")
    example_inputs = (waveform,)

    # Dynamic quantization of the Linear layers (not supported on GPU).
    # quantize_dynamic returns a converted copy by default, so one quantized
    # model serves both the trace and the script export, exactly as the float
    # model is shared by both exports.
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

    # trace model
    trace_model = torch.jit.trace(func=model, example_inputs=example_inputs, strict=False)
    trace_model.save(serialization_dir / "trace_model.zip")

    # quantization trace model (not work on GPU)
    trace_quant_model = torch.jit.trace(func=quantized_model, example_inputs=example_inputs, strict=False)
    trace_quant_model.save(serialization_dir / "trace_quant_model.zip")

    # script model
    script_model = torch.jit.script(obj=model)
    script_model.save(serialization_dir / "script_model.zip")

    # quantization script model (not work on GPU)
    script_quant_model = torch.jit.script(quantized_model)
    script_quant_model.save(serialization_dir / "script_quant_model.zip")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()