From 1a46125175b9428bd4c481f79117598330bad535 Mon Sep 17 00:00:00 2001
From: huangyuxin
Date: Tue, 28 Sep 2021 05:46:29 +0000
Subject: [PATCH] add bin for hub

---
 deepspeech/exps/deepspeech2/bin/test_hub.py | 228 ++++++++++++++++++++
 examples/aishell/s0/local/test_hub.sh       |  36 ++++
 examples/aishell/s0/run.sh                  |   8 +
 3 files changed, 272 insertions(+)
 create mode 100644 deepspeech/exps/deepspeech2/bin/test_hub.py
 create mode 100755 examples/aishell/s0/local/test_hub.sh

diff --git a/deepspeech/exps/deepspeech2/bin/test_hub.py b/deepspeech/exps/deepspeech2/bin/test_hub.py
new file mode 100644
index 00000000..cbda3b4c
--- /dev/null
+++ b/deepspeech/exps/deepspeech2/bin/test_hub.py
@@ -0,0 +1,228 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation for DeepSpeech2 model."""
+import os
+import sys
+from pathlib import Path
+
+import paddle
+
+from deepspeech.exps.deepspeech2.config import get_cfg_defaults
+from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
+from deepspeech.io.collator import SpeechCollator
+from deepspeech.models.ds2 import DeepSpeech2Model
+from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
+from deepspeech.training.cli import default_argument_parser
+from deepspeech.utils import mp_tools
+from deepspeech.utils.checkpoint import Checkpoint
+from deepspeech.utils.log import Log
+from deepspeech.utils.utility import print_arguments
+from deepspeech.utils.utility import UpdateConfig
+
+logger = Log(__name__).getlog()
+
+
+class DeepSpeech2Tester_hub():
+    """Decode one audio file with a DeepSpeech2 model and log the result."""
+
+    def __init__(self, config, args):
+        """Store config/args and build the collator and text featurizer."""
+        self.args = args
+        self.config = config
+        self.audio_file = args.audio_file
+        self.collate_fn_test = SpeechCollator.from_config(config)
+        self._text_featurizer = TextFeaturizer(
+            unit_type=config.collator.unit_type, vocab_filepath=None)
+
+    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
+        """Decode `audio` with `self.model` and detokenize each sentence.
+
+        The fields of `cfg` (the decoding config) are forwarded to
+        `self.model.decode` as keyword arguments.
+        """
+        result_transcripts = self.model.decode(
+            audio,
+            audio_len,
+            vocab_list,
+            decoding_method=cfg.decoding_method,
+            lang_model_path=cfg.lang_model_path,
+            beam_alpha=cfg.alpha,
+            beam_beta=cfg.beta,
+            beam_size=cfg.beam_size,
+            cutoff_prob=cfg.cutoff_prob,
+            cutoff_top_n=cfg.cutoff_top_n,
+            num_processes=cfg.num_proc_bsearch)
+        # detokenize each decoded sentence
+        result_transcripts = [
+            self._text_featurizer.detokenize(sentence)
+            for sentence in result_transcripts
+        ]
+
+        return result_transcripts
+
+    @mp_tools.rank_zero_only
+    @paddle.no_grad()
+    def test(self):
+        """Decode `self.audio_file`, log the transcript and optionally save it.
+
+        The transcript is written to `args.result_file` when that is set.
+        """
+        self.model.eval()
+        cfg = self.config
+        audio_file = self.audio_file
+        collate_fn_test = self.collate_fn_test
+        audio, _ = collate_fn_test.process_utterance(
+            audio_file=audio_file, transcript=" ")
+        audio_len = audio.shape[0]
+        audio = paddle.to_tensor(audio, dtype='float32')
+        audio_len = paddle.to_tensor(audio_len)
+        audio = paddle.unsqueeze(audio, axis=0)
+        vocab_list = collate_fn_test.vocab_list
+        result_transcripts = self.compute_result_transcripts(
+            audio, audio_len, vocab_list, cfg.decoding)
+        logger.info("result_transcripts: " + result_transcripts[0])
+
+        # `--result_file` used to be parsed but never written; honor it here.
+        # getattr keeps callers whose args lack the attribute working.
+        result_file = getattr(self.args, 'result_file', None)
+        if result_file:
+            with open(result_file, 'w', encoding='utf-8') as fout:
+                fout.write(result_transcripts[0] + "\n")
+
+    def run_test(self):
+        """Load the checkpoint parameters, then run `test`.
+
+        Exits with status -1 on KeyboardInterrupt.
+        """
+        self.resume()
+        try:
+            self.test()
+        except KeyboardInterrupt:
+            # sys.exit, not the site-provided `exit` builtin, which is not
+            # guaranteed to exist (e.g. under `python -S`).
+            sys.exit(-1)
+
+    def setup(self):
+        """Setup the experiment.
+        """
+        paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
+
+        self.setup_output_dir()
+        self.setup_checkpointer()
+
+        self.setup_model()
+
+    def setup_output_dir(self):
+        """Create a directory used for output.
+
+        Uses `args.output` when given, otherwise the grandparent directory
+        of `args.checkpoint_path`.
+        """
+        # output dir
+        if self.args.output:
+            output_dir = Path(self.args.output).expanduser()
+            output_dir.mkdir(parents=True, exist_ok=True)
+        else:
+            output_dir = Path(
+                self.args.checkpoint_path).expanduser().parent.parent
+            output_dir.mkdir(parents=True, exist_ok=True)
+        self.output_dir = output_dir
+
+    def setup_model(self):
+        """Build the offline or online DeepSpeech2 model per `args.model_type`.
+
+        Feature and vocabulary sizes are taken from the collator.
+
+        Raises:
+            Exception: if `args.model_type` is neither 'offline' nor 'online'.
+        """
+        config = self.config.clone()
+        with UpdateConfig(config):
+            config.model.feat_size = self.collate_fn_test.feature_size
+            config.model.dict_size = self.collate_fn_test.vocab_size
+
+        if self.args.model_type == 'offline':
+            model = DeepSpeech2Model.from_config(config.model)
+        elif self.args.model_type == 'online':
+            model = DeepSpeech2ModelOnline.from_config(config.model)
+        else:
+            raise Exception("wrong model type")
+
+        self.model = model
+
+    def setup_checkpointer(self):
+        """Create a directory used to save checkpoints into.
+
+        It is "checkpoints" inside the output directory.
+        """
+        # checkpoint dir
+        checkpoint_dir = self.output_dir / "checkpoints"
+        checkpoint_dir.mkdir(exist_ok=True)
+
+        self.checkpoint_dir = checkpoint_dir
+
+        self.checkpoint = Checkpoint(
+            kbest_n=self.config.training.checkpoint.kbest_n,
+            latest_n=self.config.training.checkpoint.latest_n)
+
+    def resume(self):
+        """Load model parameters from `args.checkpoint_path` + '.pdparams'."""
+        params_path = self.args.checkpoint_path + ".pdparams"
+        model_dict = paddle.load(params_path)
+        self.model.set_state_dict(model_dict)
+
+
+def main_sp(config, args):
+    """Build the tester, set it up and decode the audio file."""
+    exp = DeepSpeech2Tester_hub(config, args)
+    exp.setup()
+    exp.run_test()
+
+
+def main(config, args):
+    """Entry point; runs `main_sp` in the current process."""
+    main_sp(config, args)
+
+
+if __name__ == "__main__":
+    parser = default_argument_parser()
+    parser.add_argument("--model_type")
+    parser.add_argument("--audio_file")
+    # save asr result to
+    parser.add_argument(
+        "--result_file", type=str, help="path of save the asr result")
+    args = parser.parse_args()
+    print_arguments(args, globals())
+    if args.model_type is None:
+        args.model_type = 'offline'
+    # `--audio_file` is None when omitted and os.path.isfile(None) raises
+    # TypeError, so check for a missing value before hitting the filesystem.
+    if not args.audio_file or not os.path.isfile(args.audio_file):
+        print("Please input the audio file path")
+        sys.exit(-1)
+    print("model_type:{}".format(args.model_type))
+
+    # https://yaml.org/type/float.html
+    config = get_cfg_defaults(args.model_type)
+    if args.config:
+        config.merge_from_file(args.config)
+    if args.opts:
+        config.merge_from_list(args.opts)
+    config.freeze()
+    print(config)
+    if args.dump_config:
+        with open(args.dump_config, 'w') as f:
+            print(config, file=f)
+
+    main(config, args)
diff --git a/examples/aishell/s0/local/test_hub.sh b/examples/aishell/s0/local/test_hub.sh
new file mode 100755
index 00000000..d01496c4
--- /dev/null
+++ b/examples/aishell/s0/local/test_hub.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+if [ $# != 4 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix model_type audio_file"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+config_path=$1
+ckpt_prefix=$2
+model_type=$3
+audio_file=$4
+
+# download language model
+bash local/download_lm_ch.sh
+if [ $? -ne 0 ]; then
+    exit 1
+fi
+
+python3 -u ${BIN_DIR}/test_hub.py \
+--nproc ${ngpu} \
+--config ${config_path} \
+--result_file ${ckpt_prefix}.rsl \
+--checkpoint_path ${ckpt_prefix} \
+--model_type ${model_type} \
+--audio_file ${audio_file}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in evaluation!"
+    exit 1
+fi
+
+
+exit 0
diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh
index 71191c3a..83846ada 100755
--- a/examples/aishell/s0/run.sh
+++ b/examples/aishell/s0/run.sh
@@ -15,6 +15,8 @@ avg_ckpt=avg_${avg_num}
 ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
 echo "checkpoint name ${ckpt}"
 
+audio_file="data/tmp.wav"
+
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     # prepare data
     bash ./local/data.sh || exit -1
@@ -44,3 +46,9 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # test export ckpt avg_n
     CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}|| exit -1
 fi
+
+# Optionally, you can add LM and test it with runtime.
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+    # test a single .wav file
+    CUDA_VISIBLE_DEVICES=0 ./local/test_hub.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} ${audio_file} || exit -1
+fi