You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
159 lines
4.7 KiB
159 lines
4.7 KiB
import argparse
|
|
import random
|
|
|
|
import jsonlines
|
|
import numpy as np
|
|
import paddle
|
|
from paddleslim.quant import quant_post_static
|
|
|
|
from paddlespeech.t2s.exps.syn_utils import get_dev_dataloader
|
|
from paddlespeech.t2s.utils import str2bool
|
|
|
|
|
|
def parse_args(argv=None):
    """Build and run the command-line parser for PTQ static quantization.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Paddle Slim Static with acoustic model & vocoder.")

    parser.add_argument(
        "--batch_size", type=int, default=1, help="Minibatch size.")
    # NOTE(review): --batch_num is parsed here but quantize() in this file
    # never reads it; confirm whether it should be forwarded.
    parser.add_argument("--batch_num", type=int, default=1, help="Batch number")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
    # model_path save_path
    # Required: quantize() concatenates this into the model/output paths and
    # would otherwise fail later with a TypeError on None.
    parser.add_argument(
        "--inference_dir",
        type=str,
        required=True,
        help="dir to save inference models")
    parser.add_argument(
        '--model_name',
        type=str,
        default='fastspeech2_csmsc',
        choices=[
            'speedyspeech_csmsc',
            'fastspeech2_csmsc',
            'fastspeech2_aishell3',
            'fastspeech2_ljspeech',
            'fastspeech2_vctk',
            'fastspeech2_mix',
            'pwgan_csmsc',
            'pwgan_aishell3',
            'pwgan_ljspeech',
            'pwgan_vctk',
            'mb_melgan_csmsc',
            'hifigan_csmsc',
            'hifigan_aishell3',
            'hifigan_ljspeech',
            'hifigan_vctk',
            'pwgan_opencpop',
            'hifigan_opencpop',
        ],
        help='Choose model type of tts task.')

    # Calibration / quantization options forwarded to quant_post_static.
    parser.add_argument(
        "--algo", type=str, default='avg', help="calibration algorithm.")
    parser.add_argument(
        "--round_type",
        type=str,
        default='round',
        help="The method of converting the quantized weights.")
    parser.add_argument(
        "--hist_percent",
        type=float,
        default=0.9999,
        help="The percentile of algo:hist.")
    parser.add_argument(
        "--is_full_quantize",
        type=str2bool,
        default=False,
        help="Whether is full quantization or not.")
    parser.add_argument(
        "--bias_correction",
        type=str2bool,
        default=False,
        help="Whether to use bias correction.")
    parser.add_argument(
        "--ce_test", type=str2bool, default=False, help="Whether to CE test.")
    parser.add_argument(
        "--onnx_format",
        type=str2bool,
        default=False,
        help="Whether to export the quantized model with format of ONNX.")

    # Data / vocabulary files. argparse maps the dashes to underscores in
    # the attribute names (args.phones_dict, args.speaker_dict, ...).
    parser.add_argument(
        "--phones-dict", type=str, default=None, help="phone vocabulary file.")
    parser.add_argument(
        "--speaker-dict",
        type=str,
        default=None,
        help="speaker id map file for multiple speaker model.")
    # Required: quantize() opens this file unconditionally.
    parser.add_argument(
        "--dev-metadata", type=str, required=True, help="dev data.")
    parser.add_argument(
        "--quantizable_op_type",
        # type=str keeps each op name intact; the former type=list split
        # every command-line value into a list of single characters.
        type=str,
        nargs='+',
        default=[
            "conv2d_transpose", "conv2d", "depthwise_conv2d", "mul", "matmul",
            "matmul_v2"
        ],
        help="The list of op types that will be quantized.")

    return parser.parse_args(argv)
|
|
|
|
|
|
def quantize(args):
    """Run PaddleSlim post-training static quantization for one exported model.

    Builds a calibration dataloader from the dev metadata, then calls
    ``quant_post_static`` with:

    * ``model_dir=args.inference_dir`` and the model/params files
      ``<model_name>.pdmodel`` / ``<model_name>.pdiparams``;
    * ``quantize_model_path=<inference_dir>/<model_name>_quant``, saving under
      the same file names.

    Args:
        args: Namespace from ``parse_args``. Fields read here: ``ce_test``,
            ``ngpu``, ``dev_metadata``, ``model_name``, ``batch_size``,
            ``speaker_dict``, ``inference_dir``, ``onnx_format``, ``algo``,
            ``round_type``, ``hist_percent``, ``is_full_quantize``,
            ``bias_correction`` and ``quantizable_op_type``.
    """
    import os  # only needed to build the output path

    shuffle = True
    if args.ce_test:
        # Pin every RNG and disable shuffling so CE runs are reproducible.
        seed = 111
        np.random.seed(seed)
        paddle.seed(seed)
        random.seed(seed)
        shuffle = False

    place = paddle.CUDAPlace(0) if args.ngpu > 0 else paddle.CPUPlace()

    with jsonlines.open(args.dev_metadata, 'r') as reader:
        dev_metadata = list(reader)

    dataloader = get_dev_dataloader(
        dev_metadata=dev_metadata,
        am=args.model_name,
        batch_size=args.batch_size,
        speaker_dict=args.speaker_dict,
        shuffle=shuffle)

    exe = paddle.static.Executor(place)
    # NOTE(review): called with no program argument; kept from the original,
    # confirm whether it is actually needed.
    exe.run()

    print("onnx_format:", args.onnx_format)

    # os.path.join avoids a doubled separator when inference_dir already ends
    # with one, and uses the platform's separator.
    quantize_model_path = os.path.join(args.inference_dir,
                                       args.model_name + "_quant")

    # NOTE(review): args.batch_num is not forwarded here (presumably
    # quant_post_static's batch_nums keeps its default); confirm intent.
    quant_post_static(
        executor=exe,
        model_dir=args.inference_dir,
        quantize_model_path=quantize_model_path,
        data_loader=dataloader,
        model_filename=args.model_name + ".pdmodel",
        params_filename=args.model_name + ".pdiparams",
        save_model_filename=args.model_name + ".pdmodel",
        save_params_filename=args.model_name + ".pdiparams",
        batch_size=args.batch_size,
        algo=args.algo,
        round_type=args.round_type,
        hist_percent=args.hist_percent,
        is_full_quantize=args.is_full_quantize,
        bias_correction=args.bias_correction,
        onnx_format=args.onnx_format,
        quantizable_op_type=args.quantizable_op_type)
|
|
|
|
|
|
def main():
    """Script entry point: parse options, enable static mode, then quantize."""
    args = parse_args()
    # Normalise each op type to a plain string. ''.join() is a no-op on a str
    # and re-assembles a list of single characters, so this works whether the
    # parser yields strings or character lists.
    args.quantizable_op_type = [
        ''.join(op_type) for op_type in args.quantizable_op_type
    ]
    # quantize() builds a paddle.static.Executor, so switch to static mode first.
    paddle.enable_static()
    quantize(args)


if __name__ == '__main__':
    main()
|