[TTS]Add slim for TTS (#2729)
parent
6f927d55db
commit
3f6afc4834
@ -0,0 +1 @@
|
||||
../../tts3/local/PTQ_static.sh
|
@ -0,0 +1,8 @@
|
||||
#!/bin/bash
# Post-training dynamic quantization of an exported TTS inference model.
# Usage: PTQ_dynamic.sh <train_output_path> <model_name> <weight_bits>
# Expects BIN_DIR to be exported by the calling run.sh.

train_output_path=$1
model_name=$2
weight_bits=$3

# Quote all expansions so paths containing spaces do not word-split.
python3 "${BIN_DIR}/../PTQ_dynamic.py" \
    --inference_dir "${train_output_path}/inference" \
    --model_name "${model_name}" \
    --weight_bits "${weight_bits}"
|
@ -0,0 +1,8 @@
|
||||
#!/bin/bash
# Post-training static quantization of an exported TTS inference model,
# calibrated on the dev set.
# Usage: PTQ_static.sh <train_output_path> <model_name>
# Expects BIN_DIR to be exported by the calling run.sh.

train_output_path=$1
model_name=$2

# BUGFIX: the flag was misspelled "--onnx_forma"; it only worked through
# argparse prefix matching and would break if another --onnx_* option were
# ever added. Expansions are quoted to survive paths with spaces.
python3 "${BIN_DIR}/../PTQ_static.py" \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --inference_dir "${train_output_path}/inference" \
    --model_name "${model_name}" \
    --onnx_format=True
|
@ -0,0 +1,8 @@
|
||||
#!/bin/bash
# Post-training static quantization of an exported vocoder inference model,
# calibrated on the dev set.
# Usage: PTQ_static.sh <train_output_path> <model_name>
# Expects BIN_DIR to be exported by the calling run.sh.

train_output_path=$1
model_name=$2

# Quote all expansions so paths containing spaces do not word-split.
python3 "${BIN_DIR}/../../PTQ_static.py" \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --inference_dir "${train_output_path}/inference" \
    --model_name "${model_name}" \
    --onnx_format=True
|
@ -0,0 +1 @@
|
||||
../../voc1/local/PTQ_static.sh
|
@ -0,0 +1 @@
|
||||
../../voc1/local/PTQ_static.sh
|
@ -0,0 +1,80 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
|
||||
import paddle
|
||||
from paddleslim.quant import quant_post_dynamic
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for dynamic post-training quantization.

    Returns:
        argparse.Namespace: parsed arguments. Unknown flags are tolerated
        because the options are read with ``parse_known_args``.
    """
    # Acoustic models and vocoders that have an exported static graph
    # available for quantization.
    supported_models = [
        'speedyspeech_csmsc',
        'fastspeech2_csmsc',
        'fastspeech2_aishell3',
        'fastspeech2_ljspeech',
        'fastspeech2_vctk',
        'tacotron2_csmsc',
        'fastspeech2_mix',
        'pwgan_csmsc',
        'pwgan_aishell3',
        'pwgan_ljspeech',
        'pwgan_vctk',
        'mb_melgan_csmsc',
        'hifigan_csmsc',
        'hifigan_aishell3',
        'hifigan_ljspeech',
        'hifigan_vctk',
        'wavernn_csmsc',
    ]

    parser = argparse.ArgumentParser(
        description="Paddle Slim Dynamic with acoustic model & vocoder.")
    parser.add_argument(
        '--model_name',
        type=str,
        default='fastspeech2_csmsc',
        choices=supported_models,
        help='Choose model type of tts task.')
    parser.add_argument(
        "--inference_dir", type=str, help="dir to save inference models")
    parser.add_argument(
        "--weight_bits",
        type=int,
        default=8,
        choices=[8, 16],
        help="The bits for the quantized weight, and it should be 8 or 16. Default is 8.",
    )

    known_args, _ = parser.parse_known_args()
    return known_args
|
||||
|
||||
|
||||
# only inference for models trained with csmsc now
|
||||
def main():
    """Dynamically quantize the weights of an exported inference model.

    Reads ``<model_name>.pdmodel`` / ``.pdiparams`` from ``--inference_dir``
    and writes the quantized copies next to them with the bit width in the
    file name (e.g. ``fastspeech2_csmsc_8bits.pdmodel``).
    """
    args = parse_args()
    # PaddleSlim's post-training quantization APIs operate on static graphs.
    paddle.enable_static()

    src_prefix = args.model_name
    dst_prefix = "{}_{}bits".format(args.model_name, args.weight_bits)
    quant_post_dynamic(
        model_dir=args.inference_dir,
        save_model_dir=args.inference_dir,
        model_filename=src_prefix + ".pdmodel",
        params_filename=src_prefix + ".pdiparams",
        save_model_filename=dst_prefix + ".pdmodel",
        save_params_filename=dst_prefix + ".pdiparams",
        weight_bits=args.weight_bits, )


if __name__ == "__main__":
    main()
|
@ -0,0 +1,156 @@
|
||||
import argparse
|
||||
import random
|
||||
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddleslim.quant import quant_post_static
|
||||
|
||||
from paddlespeech.t2s.exps.syn_utils import get_dev_dataloader
|
||||
from paddlespeech.t2s.utils import str2bool
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for static post-training quantization.

    Returns:
        argparse.Namespace: parsed arguments. Unlike the dynamic script this
        uses strict ``parse_args``, so unknown flags are rejected.
    """
    parser = argparse.ArgumentParser(
        description="Paddle Slim Static with acoustic model & vocoder.")

    parser.add_argument(
        "--batch_size", type=int, default=1, help="Minibatch size.")
    parser.add_argument("--batch_num", type=int, default=1, help="Batch number")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
    # model_path save_path
    parser.add_argument(
        "--inference_dir", type=str, help="dir to save inference models")
    parser.add_argument(
        '--model_name',
        type=str,
        default='fastspeech2_csmsc',
        choices=[
            'speedyspeech_csmsc',
            'fastspeech2_csmsc',
            'fastspeech2_aishell3',
            'fastspeech2_ljspeech',
            'fastspeech2_vctk',
            'fastspeech2_mix',
            'pwgan_csmsc',
            'pwgan_aishell3',
            'pwgan_ljspeech',
            'pwgan_vctk',
            'mb_melgan_csmsc',
            'hifigan_csmsc',
            'hifigan_aishell3',
            'hifigan_ljspeech',
            'hifigan_vctk',
        ],
        help='Choose model type of tts task.')

    parser.add_argument(
        "--algo", type=str, default='avg', help="calibration algorithm.")
    parser.add_argument(
        "--round_type",
        type=str,
        default='round',
        help="The method of converting the quantized weights.")
    parser.add_argument(
        "--hist_percent",
        type=float,
        default=0.9999,
        help="The percentile of algo:hist.")
    parser.add_argument(
        "--is_full_quantize",
        type=str2bool,
        default=False,
        help="Whether is full quantization or not.")
    parser.add_argument(
        "--bias_correction",
        type=str2bool,
        default=False,
        help="Whether to use bias correction.")
    parser.add_argument(
        "--ce_test", type=str2bool, default=False, help="Whether to CE test.")
    parser.add_argument(
        "--onnx_format",
        type=str2bool,
        default=False,
        help="Whether to export the quantized model with format of ONNX.")
    parser.add_argument(
        "--phones-dict", type=str, default=None, help="phone vocabulary file.")
    parser.add_argument(
        "--speaker-dict",
        type=str,
        default=None,
        help="speaker id map file for multiple speaker model.")
    parser.add_argument("--dev-metadata", type=str, help="dev data.")
    # BUGFIX: this option used ``type=list``, which converts each CLI token
    # into a list of single characters (e.g. "mul" -> ['m','u','l']) and
    # forced a ''.join(...) repair loop in main(). With ``nargs='+'`` the
    # correct element type is ``str``; the repair loop remains a no-op.
    parser.add_argument(
        "--quantizable_op_type",
        type=str,
        nargs='+',
        default=[
            "conv2d_transpose", "conv2d", "depthwise_conv2d", "mul", "matmul",
            "matmul_v2"
        ],
        help="The list of op types that will be quantized.")

    args = parser.parse_args()
    return args
|
||||
|
||||
|
||||
def quantize(args):
    """Statically quantize an exported TTS model with PaddleSlim.

    Dev-set batches (listed in ``--dev-metadata``) serve as calibration
    data; the quantized model is written to
    ``<inference_dir>/<model_name>_quant`` keeping the original model and
    parameter file names.
    """
    do_shuffle = True
    if args.ce_test:
        # CE (continuous evaluation) runs must be reproducible: pin every
        # RNG and keep the calibration batches in a fixed order.
        rng_seed = 111
        np.random.seed(rng_seed)
        paddle.seed(rng_seed)
        random.seed(rng_seed)
        do_shuffle = False

    # Calibrate on GPU when requested, otherwise on CPU.
    if args.ngpu > 0:
        device = paddle.CUDAPlace(0)
    else:
        device = paddle.CPUPlace()

    with jsonlines.open(args.dev_metadata, 'r') as reader:
        records = list(reader)

    loader = get_dev_dataloader(
        dev_metadata=records,
        am=args.model_name,
        batch_size=args.batch_size,
        speaker_dict=args.speaker_dict,
        shuffle=do_shuffle)

    executor = paddle.static.Executor(device)
    executor.run()

    print("onnx_format:", args.onnx_format)

    save_dir = "{}/{}_quant".format(args.inference_dir, args.model_name)
    quant_post_static(
        executor=executor,
        model_dir=args.inference_dir,
        quantize_model_path=save_dir,
        data_loader=loader,
        model_filename=args.model_name + ".pdmodel",
        params_filename=args.model_name + ".pdiparams",
        save_model_filename=args.model_name + ".pdmodel",
        save_params_filename=args.model_name + ".pdiparams",
        batch_size=args.batch_size,
        algo=args.algo,
        round_type=args.round_type,
        hist_percent=args.hist_percent,
        is_full_quantize=args.is_full_quantize,
        bias_correction=args.bias_correction,
        onnx_format=args.onnx_format,
        quantizable_op_type=args.quantizable_op_type)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: normalize op-type arguments, then run quantization."""
    args = parse_args()
    # ``--quantizable_op_type`` may arrive with each op name exploded into a
    # list of characters (argparse ``type=list``); joining restores plain
    # strings and is a no-op when the items are already strings.
    args.quantizable_op_type = [
        ''.join(op_type) for op_type in args.quantizable_op_type
    ]
    # PaddleSlim's post-training quantization APIs operate on static graphs.
    paddle.enable_static()
    quantize(args)


if __name__ == '__main__':
    main()
|
Loading…
Reference in new issue