From 3ed24474d2ee85d3aee71de37c9b84c97094f5ef Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 8 Oct 2022 02:34:10 +0000 Subject: [PATCH] wenetspeech asr1 quant --- examples/wenetspeech/asr1/local/quant.sh | 59 ++++++ paddlespeech/s2t/exps/u2/bin/quant.py | 220 +++++++++++++++++++++++ 2 files changed, 279 insertions(+) create mode 100755 examples/wenetspeech/asr1/local/quant.sh create mode 100644 paddlespeech/s2t/exps/u2/bin/quant.py diff --git a/examples/wenetspeech/asr1/local/quant.sh b/examples/wenetspeech/asr1/local/quant.sh new file mode 100755 index 00000000..9dfea904 --- /dev/null +++ b/examples/wenetspeech/asr1/local/quant.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +if [ $# != 4 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +decode_config_path=$2 +ckpt_prefix=$3 +audio_file=$4 + +mkdir -p data +wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/ +if [ $? -ne 0 ]; then + exit 1 +fi + +if [ ! -f ${audio_file} ]; then + echo "Plase input the right audio_file path" + exit 1 +fi + + +chunk_mode=false +if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then + chunk_mode=true +fi + +# download language model +#bash local/download_lm_ch.sh +#if [ $? -ne 0 ]; then +# exit 1 +#fi + +for type in attention_rescoring; do + echo "decoding ${type}" + batch_size=1 + output_dir=${ckpt_prefix} + mkdir -p ${output_dir} + python3 -u ${BIN_DIR}/quant.py \ + --ngpu ${ngpu} \ + --config ${config_path} \ + --decode_cfg ${decode_config_path} \ + --result_file ${output_dir}/${type}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --opts decode.decoding_method ${type} \ + --opts decode.decode_batch_size ${batch_size} \ + --audio_file ${audio_file} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi +done +exit 0 diff --git a/paddlespeech/s2t/exps/u2/bin/quant.py b/paddlespeech/s2t/exps/u2/bin/quant.py new file mode 100644 index 00000000..de7c27e7 --- /dev/null +++ b/paddlespeech/s2t/exps/u2/bin/quant.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluation for U2 model.""" +import os +import sys +from pathlib import Path + +import paddle +import soundfile +from yacs.config import CfgNode + +from paddlespeech.audio.transform.transformation import Transformation +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.models.u2 import U2Model +from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.utils.utility import UpdateConfig +from paddleslim import PTQ +logger = Log(__name__).getlog() + + +class U2Infer(): + def __init__(self, config, args): + self.args = args + self.config = config + self.audio_file = args.audio_file + + self.preprocess_conf = config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0) + self.text_feature = TextFeaturizer( + unit_type=config.unit_type, + vocab=config.vocab_filepath, + spm_model_prefix=config.spm_model_prefix) + + paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') + + # model + model_conf = config + with UpdateConfig(model_conf): + model_conf.input_dim = config.feat_dim + model_conf.output_dim = self.text_feature.vocab_size + model = U2Model.from_config(model_conf) + self.model = model + self.model.eval() + self.ptq = PTQ() + self.model = self.ptq.quantize(model) + + # load model + params_path = self.args.checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + self.model.set_state_dict(model_dict) + logger.info(f"model_dict: {model_dict.keys()}") + + def run(self): + check(args.audio_file) + + with paddle.no_grad(): + # read + audio, sample_rate = soundfile.read( + self.audio_file, dtype="int16", always_2d=True) + audio = audio[:, 0] + logger.info(f"audio shape: {audio.shape}") + + # fbank + feat = self.preprocessing(audio, **self.preprocess_args) + logger.info(f"feat shape: {feat.shape}") + + ilen = paddle.to_tensor(feat.shape[0]) + xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) + decode_config = self.config.decode + logger.info(f"decode cfg: {decode_config}") + result_transcripts = self.model.decode( + xs, + ilen, + text_feature=self.text_feature, + decoding_method=decode_config.decoding_method, + beam_size=decode_config.beam_size, + ctc_weight=decode_config.ctc_weight, + decoding_chunk_size=decode_config.decoding_chunk_size, + num_decoding_left_chunks=decode_config.num_decoding_left_chunks, + simulate_streaming=decode_config.simulate_streaming, + reverse_weight=self.reverse_weight) + rsl = result_transcripts[0][0] + utt = Path(self.audio_file).name + logger.info(f"hyp: {utt} {result_transcripts[0][0]}") + # print(self.model) + # print(self.model.forward_encoder_chunk) + # return rsl + + logger.info("-------------start export ----------------------") + batch_size = 1 + feat_dim = 80 + model_size = 512 + num_left_chunks = -1 + logger.info( + f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}" + ) + + # ######################## self.model.forward_encoder_chunk ############ + # input_spec = [ + # # (T,), int16 + # paddle.static.InputSpec(shape=[None], dtype='int16'), + # ] + # self.model.forward_feature = paddle.jit.to_static( + # self.model.forward_feature, input_spec=input_spec) + + ######################### self.model.forward_encoder_chunk ############ + input_spec = [ + # xs, (B, T, D) + paddle.static.InputSpec( + shape=[batch_size, None, feat_dim], dtype='float32'), + # offset, int, but need be tensor + paddle.static.InputSpec(shape=[1], dtype='int32'), + # required_cache_size, int + num_left_chunks, + # att_cache + paddle.static.InputSpec( + shape=[None, None, None, None], dtype='float32'), + # cnn_cache + paddle.static.InputSpec( + shape=[None, None, None, None], dtype='float32') + ] + self.model.forward_encoder_chunk = paddle.jit.to_static( + self.model.forward_encoder_chunk, input_spec=input_spec) + + ######################### self.model.ctc_activation ######################## + input_spec = [ + # encoder_out, (B,T,D) + paddle.static.InputSpec( + shape=[batch_size, None, model_size], dtype='float32') + ] + self.model.ctc_activation = paddle.jit.to_static( + self.model.ctc_activation, input_spec=input_spec) + + ######################### self.model.forward_attention_decoder ######################## + reverse_weight = 0.3 + input_spec = [ + # hyps, (B, U) + paddle.static.InputSpec(shape=[None, None], dtype='int64'), + # hyps_lens, (B,) + paddle.static.InputSpec(shape=[None], dtype='int64'), + # encoder_out, (B,T,D) + paddle.static.InputSpec( + shape=[batch_size, None, model_size], dtype='float32'), + reverse_weight + ] + self.model.forward_attention_decoder = paddle.jit.to_static( + self.model.forward_attention_decoder, input_spec=input_spec) + ################################################################################ + + # jit save + logger.info(f"export save: {self.args.export_path}") + config = {'is_static': True, 'combine_params':True, 'skip_forward':True} + self.ptq.save_quantized_model(self.model, self.args.export_path) + # paddle.jit.save( + # self.model, + # self.args.export_path, + # combine_params=True, + # skip_forward=True) + + + +def check(audio_file): + if not os.path.isfile(audio_file): + print("Please input the right audio file path") + sys.exit(-1) + + logger.info("checking the audio file format......") + try: + sig, sample_rate = soundfile.read(audio_file) + except Exception as e: + logger.error(str(e)) + logger.error( + "can not open the wav file, please check the audio file format") + sys.exit(-1) + logger.info("The sample rate is %d" % sample_rate) + assert (sample_rate == 16000) + logger.info("The audio file format is right") + + +def main(config, args): + U2Infer(config, args).run() + + +if __name__ == "__main__": + parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") + parser.add_argument( + "--audio_file", type=str, help="path of the input audio file") + parser.add_argument( + "--export_path", type=str, default='export', help="path of the input audio file") + args = parser.parse_args() + + config = CfgNode(new_allowed=True) + + if args.config: + config.merge_from_file(args.config) + if args.decode_cfg: + decode_confs = CfgNode(new_allowed=True) + decode_confs.merge_from_file(args.decode_cfg) + config.decode = decode_confs + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + main(config, args)