From 789471bfca51bb7fde80c7ba02cc460828f10ade Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Thu, 25 Nov 2021 07:27:44 +0000
Subject: [PATCH] test wav for u2

---
 examples/wenetspeech/asr1/local/test_wav.sh |  45 +++++
 paddlespeech/s2t/exps/u2/bin/test_hub.py    | 187 --------------------
 paddlespeech/s2t/exps/u2/bin/test_wav.py    | 148 ++++++++++++++++
 3 files changed, 193 insertions(+), 187 deletions(-)
 create mode 100755 examples/wenetspeech/asr1/local/test_wav.sh
 delete mode 100644 paddlespeech/s2t/exps/u2/bin/test_hub.py
 create mode 100644 paddlespeech/s2t/exps/u2/bin/test_wav.py

diff --git a/examples/wenetspeech/asr1/local/test_wav.sh b/examples/wenetspeech/asr1/local/test_wav.sh
new file mode 100755
index 00000000..13296af2
--- /dev/null
+++ b/examples/wenetspeech/asr1/local/test_wav.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+if [ $# != 3 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix audio_file"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+config_path=$1
+ckpt_prefix=$2
+audio_file=$3
+
+chunk_mode=false
+if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
+    chunk_mode=true
+fi
+
+# download language model
+#bash local/download_lm_ch.sh
+#if [ $? -ne 0 ]; then
+#    exit 1
+#fi
+
+for type in attention_rescoring; do
+    echo "decoding ${type}"
+    batch_size=1
+    output_dir=${ckpt_prefix}
+    mkdir -p ${output_dir}
+    python3 -u ${BIN_DIR}/test_wav.py \
+    --nproc ${ngpu} \
+    --config ${config_path} \
+    --result_file ${output_dir}/${type}.rsl \
+    --checkpoint_path ${ckpt_prefix} \
+    --opts decoding.decoding_method ${type} \
+    --opts decoding.batch_size ${batch_size} \
+    --audio_file ${audio_file}
+
+    if [ $? -ne 0 ]; then
+        echo "Failed in evaluation!"
+        exit 1
+    fi
+done
+exit 0
diff --git a/paddlespeech/s2t/exps/u2/bin/test_hub.py b/paddlespeech/s2t/exps/u2/bin/test_hub.py
deleted file mode 100644
index 55a61d5c..00000000
--- a/paddlespeech/s2t/exps/u2/bin/test_hub.py
+++ /dev/null
@@ -1,187 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Evaluation for U2 model.""" -import cProfile -import os -import sys - -import paddle -import soundfile - -from paddlespeech.s2t.exps.u2.config import get_cfg_defaults -from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.io.collator import SpeechCollator -from paddlespeech.s2t.models.u2 import U2Model -from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.training.trainer import Trainer -from paddlespeech.s2t.utils import layer_tools -from paddlespeech.s2t.utils import mp_tools -from paddlespeech.s2t.utils.log import Log -from paddlespeech.s2t.utils.utility import print_arguments -from paddlespeech.s2t.utils.utility import UpdateConfig -logger = Log(__name__).getlog() - -# TODO(hui zhang): dynamic load - - -class U2Tester_Hub(Trainer): - def __init__(self, config, args): - # super().__init__(config, args) - self.args = args - self.config = config - self.audio_file = args.audio_file - self.collate_fn_test = SpeechCollator.from_config(config) - self._text_featurizer = TextFeaturizer( - unit_type=config.collator.unit_type, - vocab_filepath=None, - spm_model_prefix=config.collator.spm_model_prefix) - - def setup_model(self): - config = self.config - model_conf = config.model - - with UpdateConfig(model_conf): - model_conf.input_dim = self.collate_fn_test.feature_size - model_conf.output_dim = self.collate_fn_test.vocab_size - - model = U2Model.from_config(model_conf) - - if self.parallel: - model = paddle.DataParallel(model) - - logger.info(f"{model}") - layer_tools.print_params(model, logger.info) - - self.model = model - logger.info("Setup model") - - @mp_tools.rank_zero_only - @paddle.no_grad() - def test(self): - self.model.eval() - cfg = self.config.decoding - audio_file = self.audio_file - collate_fn_test = self.collate_fn_test - audio, _ = collate_fn_test.process_utterance( - audio_file=audio_file, transcript="Hello") - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) - vocab_list = collate_fn_test.vocab_list - - text_feature = self.collate_fn_test.text_feature - result_transcripts = self.model.decode( - audio, - audio_len, - text_feature=text_feature, - decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, - beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch, - ctc_weight=cfg.ctc_weight, - decoding_chunk_size=cfg.decoding_chunk_size, - num_decoding_left_chunks=cfg.num_decoding_left_chunks, - simulate_streaming=cfg.simulate_streaming) - logger.info("The result_transcripts: " + result_transcripts[0][0]) - - def run_test(self): - self.resume() - try: - self.test() - except KeyboardInterrupt: - sys.exit(-1) - - def setup(self): - """Setup the experiment. - """ - paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') - - #self.setup_output_dir() - #self.setup_checkpointer() - - #self.setup_dataloader() - self.setup_model() - - self.iteration = 0 - self.epoch = 0 - - def resume(self): - """Resume from the checkpoint at checkpoints in the output - directory or load a specified checkpoint. 
-        """
-        params_path = self.args.checkpoint_path + ".pdparams"
-        model_dict = paddle.load(params_path)
-        self.model.set_state_dict(model_dict)
-
-
-def check(audio_file):
-    logger.info("checking the audio file format......")
-    try:
-        sig, sample_rate = soundfile.read(audio_file)
-    except Exception as e:
-        logger.error(str(e))
-        logger.error(
-            "can not open the wav file, please check the audio file format")
-        sys.exit(-1)
-    logger.info("The sample rate is %d" % sample_rate)
-    assert (sample_rate == 16000)
-    logger.info("The audio file format is right")
-
-
-def main_sp(config, args):
-    exp = U2Tester_Hub(config, args)
-    with exp.eval():
-        exp.setup()
-        exp.run_test()
-
-
-def main(config, args):
-    main_sp(config, args)
-
-
-if __name__ == "__main__":
-    parser = default_argument_parser()
-    # save asr result to
-    parser.add_argument(
-        "--result_file", type=str, help="path of save the asr result")
-    parser.add_argument(
-        "--audio_file", type=str, help="path of the input audio file")
-    args = parser.parse_args()
-    print_arguments(args, globals())
-
-    if not os.path.isfile(args.audio_file):
-        print("Please input the right audio file path")
-        sys.exit(-1)
-    check(args.audio_file)
-    # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
-    if args.config:
-        config.merge_from_file(args.config)
-    if args.opts:
-        config.merge_from_list(args.opts)
-    config.freeze()
-    print(config)
-    if args.dump_config:
-        with open(args.dump_config, 'w') as f:
-            print(config, file=f)
-
-    # Setting for profiling
-    pr = cProfile.Profile()
-    pr.runcall(main, config, args)
-    pr.dump_stats('test.profile')
diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py
new file mode 100644
index 00000000..e118b481
--- /dev/null
+++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation for U2 model.""" +import os +import sys +from pathlib import Path + +import paddle +import soundfile + +from paddlespeech.s2t.exps.u2.config import get_cfg_defaults +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.models.u2 import U2Model +from paddlespeech.s2t.training.cli import default_argument_parser +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.log import Log +from paddlespeech.s2t.utils.utility import UpdateConfig +logger = Log(__name__).getlog() + +# TODO(hui zhang): dynamic load + + +class U2Infer(): + def __init__(self, config, args): + self.args = args + self.config = config + self.audio_file = args.audio_file + self.sr = config.collator.target_sample_rate + + self.preprocess_conf = config.collator.augmentation_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + + self.text_feature = TextFeaturizer( + unit_type=config.collator.unit_type, + vocab_filepath=config.collator.vocab_filepath, + spm_model_prefix=config.collator.spm_model_prefix) + + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') + + # model + model_conf = config.model + with UpdateConfig(model_conf): + model_conf.input_dim = config.collator.feat_dim + model_conf.output_dim = self.text_feature.vocab_size + model = U2Model.from_config(model_conf) + self.model = model + self.model.eval() + + # load model + params_path = self.args.checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + self.model.set_state_dict(model_dict) + + def run(self): + check(args.audio_file) + + with paddle.no_grad(): + # read + audio, sample_rate = soundfile.read( + self.audio_file, dtype="int16", always_2d=True) + if sample_rate != self.sr: + logger.error( + f"sample rate error: {sample_rate}, need {self.sr} ") + sys.exit(-1) + + audio = audio[:, 0] + logger.info(f"audio shape: {audio.shape}") + + # fbank + feat = self.preprocessing(audio, **self.preprocess_args) + logger.info(f"feat shape: {feat.shape}") + + ilen = paddle.to_tensor(feat.shape[0]) + xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) + + cfg = self.config.decoding + result_transcripts = self.model.decode( + xs, + ilen, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + rsl = result_transcripts[0][0] + utt = Path(self.audio_file).name + logger.info(f"hyp: {utt} {result_transcripts[0][0]}") + return rsl + + +def check(audio_file): + if not os.path.isfile(audio_file): + print("Please input the right audio file path") + sys.exit(-1) + + logger.info("checking the audio file format......") + try: + sig, sample_rate = soundfile.read(audio_file) + except Exception as e: + logger.error(str(e)) + logger.error( + "can not open the wav file, please check the audio file format") + sys.exit(-1) + logger.info("The sample rate is %d" % sample_rate) + assert (sample_rate == 16000) + logger.info("The audio file format is right") + + +def main(config, args): + U2Infer(config, args).run() + + +if __name__ == "__main__": + parser = default_argument_parser() + # save asr result 
+    parser.add_argument(
+        "--result_file", type=str, help="path to save the asr result")
+    parser.add_argument(
+        "--audio_file", type=str, help="path of the input audio file")
+    args = parser.parse_args()
+
+    config = get_cfg_defaults()
+    if args.config:
+        config.merge_from_file(args.config)
+    if args.opts:
+        config.merge_from_list(args.opts)
+    config.freeze()
+    main(config, args)
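
Usage note (not part of the patch): a minimal sketch of how the new script is meant to be invoked, based on its usage string and the checks in test_wav.py. It assumes the recipe's path.sh exports the BIN_DIR variable the script relies on; the config path, checkpoint prefix, and wav path below are placeholders, not files added by this patch.

    # run from the recipe directory; the three argument paths are placeholders
    cd examples/wenetspeech/asr1
    source path.sh                      # assumed to set BIN_DIR used by local/test_wav.sh
    CUDA_VISIBLE_DEVICES=0 bash local/test_wav.sh \
        conf/conformer.yaml \
        exp/conformer/checkpoints/avg_10 \
        data/demo.wav                   # must be a 16 kHz wav (asserted in test_wav.py)

The script forwards the three arguments to ${BIN_DIR}/test_wav.py, decodes the single wav with the attention_rescoring method, and logs the recognized text.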