diff --git a/examples/aishell/asr1/conf/conformer.yaml b/examples/aishell/asr1/conf/conformer.yaml index 907e3a94..2ba96e76 100644 --- a/examples/aishell/asr1/conf/conformer.yaml +++ b/examples/aishell/asr1/conf/conformer.yaml @@ -1,97 +1,93 @@ -# network architecture -model: - cmvn_file: - cmvn_file_type: "json" - # encoder related - encoder: conformer - encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 - normalize_before: True - cnn_module_kernel: 15 - use_cnn_module: True - activation_type: 'swish' - pos_enc_layer_type: 'rel_pos' - selfattention_layer_type: 'rel_selfattn' +############################################ +# Network Architecture # +############################################ +#model: +cmvn_file: +cmvn_file_type: "json" +# encoder related +encoder: conformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: True + cnn_module_kernel: 15 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' - # decoder related - decoder: transformer - decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 - # hybrid CTC/attention - model_conf: - ctc_weight: 0.3 - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false -data: - train_manifest: data/manifest.train - dev_manifest: data/manifest.dev - test_manifest: data/manifest.test +########################################### +# Data # +########################################### +#data: +train_manifest: data/manifest.train +dev_manifest: data/manifest.dev +test_manifest: data/manifest.test +########################################### +# Dataloader # +########################################### +#collator: +vocab_filepath: data/lang_char/vocab.txt +unit_type: 'char' +augmentation_config: conf/preprocess.yaml +spm_model_prefix: '' +feat_dim: 80 +stride_ms: 10.0 +window_ms: 25.0 +sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs +batch_size: 64 +maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced +minibatches: 0 # for debug +batch_count: auto +batch_bins: 0 +batch_frames_in: 0 +batch_frames_out: 0 +batch_frames_inout: 0 +num_workers: 0 +subsampling_factor: 1 +num_encs: 1 -collator: - vocab_filepath: data/lang_char/vocab.txt - unit_type: 'char' - augmentation_config: conf/preprocess.yaml - feat_dim: 80 - stride_ms: 10.0 - window_ms: 25.0 - sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs - batch_size: 64 - maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced - maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced - minibatches: 0 # for debug - batch_count: auto - batch_bins: 0 - batch_frames_in: 0 - batch_frames_out: 0 - batch_frames_inout: 0 - num_workers: 0 - subsampling_factor: 1 - num_encs: 1 - - -training: - n_epoch: 240 - accum_grad: 2 - global_grad_clip: 5.0 - optim: adam - optim_conf: - lr: 0.002 - weight_decay: 1e-6 - scheduler: warmuplr - scheduler_conf: - warmup_steps: 25000 - lr_decay: 1.0 - log_interval: 100 - checkpoint: - kbest_n: 50 - latest_n: 5 - - -decoding: - beam_size: 10 - batch_size: 128 - error_rate_type: cer - decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. - decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. - # <0: for decoding, use full chunk. - # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. - num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: False # simulate streaming inference. Defaults to False. +########################################### +# training # +########################################### +#training: +n_epoch: 240 +accum_grad: 2 +global_grad_clip: 5.0 +optim: adam +optim_conf: + lr: 0.002 + weight_decay: 1.0e-6 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 +log_interval: 100 +checkpoint: + kbest_n: 50 + latest_n: 5 diff --git a/examples/aishell/asr1/conf/decode.yaml b/examples/aishell/asr1/conf/decode.yaml new file mode 100644 index 00000000..49364f5d --- /dev/null +++ b/examples/aishell/asr1/conf/decode.yaml @@ -0,0 +1,12 @@ +#decoding: +beam_size: 10 +decode_batch_size: 128 +error_rate_type: cer +decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' +ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. +decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. +num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. +simulate_streaming: False # simulate streaming inference. Defaults to False. diff --git a/examples/aishell/asr1/local/align.sh b/examples/aishell/asr1/local/align.sh index c65d611c..f526c8a4 100755 --- a/examples/aishell/asr1/local/align.sh +++ b/examples/aishell/asr1/local/align.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" exit -1 fi @@ -9,7 +9,8 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." config_path=$1 -ckpt_prefix=$2 +decode_config_path=$2 +ckpt_prefix=$3 batch_size=1 output_dir=${ckpt_prefix} @@ -20,9 +21,10 @@ mkdir -p ${output_dir} python3 -u ${BIN_DIR}/alignment.py \ --ngpu ${ngpu} \ --config ${config_path} \ +--decode_config ${decode_config_path} \ --result_file ${output_dir}/${type}.align \ --checkpoint_path ${ckpt_prefix} \ ---opts decoding.batch_size ${batch_size} +--opts decoding.decode_batch_size ${batch_size} if [ $? -ne 0 ]; then echo "Failed in ctc alignment!" diff --git a/examples/aishell/asr1/local/test.sh b/examples/aishell/asr1/local/test.sh index da159de7..2c092127 100755 --- a/examples/aishell/asr1/local/test.sh +++ b/examples/aishell/asr1/local/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" exit -1 fi @@ -9,7 +9,8 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." config_path=$1 -ckpt_prefix=$2 +decode_config_path=$2 +ckpt_prefix=$3 chunk_mode=false if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then @@ -36,10 +37,11 @@ for type in attention ctc_greedy_search; do python3 -u ${BIN_DIR}/test.py \ --ngpu ${ngpu} \ --config ${config_path} \ + --decode_config ${decode_config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.decoding_method ${type} \ - --opts decoding.batch_size ${batch_size} + --opts decoding.decode_batch_size ${batch_size} if [ $? -ne 0 ]; then echo "Failed in evaluation!" @@ -55,6 +57,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do python3 -u ${BIN_DIR}/test.py \ --ngpu ${ngpu} \ --config ${config_path} \ + --decode_config ${decode_config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.decoding_method ${type} \ diff --git a/examples/aishell/asr1/local/test_wav.sh b/examples/aishell/asr1/local/test_wav.sh index f85c1a47..4866e642 100755 --- a/examples/aishell/asr1/local/test_wav.sh +++ b/examples/aishell/asr1/local/test_wav.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: ${0} config_path ckpt_path_prefix audio_file" +if [ $# != 4 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file" exit -1 fi @@ -9,8 +9,9 @@ ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." config_path=$1 -ckpt_prefix=$2 -audio_file=$3 +decode_config_path=$2 +ckpt_prefix=$3 +audio_file=$4 mkdir -p data wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/ @@ -42,10 +43,11 @@ for type in attention_rescoring; do python3 -u ${BIN_DIR}/test_wav.py \ --ngpu ${ngpu} \ --config ${config_path} \ + --decode_config ${decode_config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.decoding_method ${type} \ - --opts decoding.batch_size ${batch_size} \ + --opts decoding.decode_batch_size ${batch_size} \ --audio_file ${audio_file} if [ $? -ne 0 ]; then diff --git a/examples/aishell/asr1/run.sh b/examples/aishell/asr1/run.sh index d07a4ed5..11aff9c7 100644 --- a/examples/aishell/asr1/run.sh +++ b/examples/aishell/asr1/run.sh @@ -6,6 +6,7 @@ gpus=0,1,2,3 stage=0 stop_stage=50 conf_path=conf/conformer.yaml +decode_conf_path=conf/decode.yaml avg_num=20 audio_file=data/demo_01_03.wav @@ -32,18 +33,18 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # ctc alignment of test data - CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi # Optionally, you can add LM and test it with runtime. if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # test a single .wav file - CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1 fi # Not supported at now!!! diff --git a/paddlespeech/s2t/exps/u2/bin/alignment.py b/paddlespeech/s2t/exps/u2/bin/alignment.py index df95baeb..f8397ed0 100644 --- a/paddlespeech/s2t/exps/u2/bin/alignment.py +++ b/paddlespeech/s2t/exps/u2/bin/alignment.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Alignment for U2 model.""" +from yacs.config import CfgNode + from paddlespeech.s2t.exps.u2.config import get_cfg_defaults from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser @@ -41,6 +43,10 @@ if __name__ == "__main__": config = get_cfg_defaults() if args.config: config.merge_from_file(args.config) + if args.decode_config: + decode_confs = CfgNode(new_allowed=True) + decode_confs.merge_from_file(args.decode_config) + config.decoding = decode_confs if args.opts: config.merge_from_list(args.opts) config.freeze() diff --git a/paddlespeech/s2t/exps/u2/bin/test.py b/paddlespeech/s2t/exps/u2/bin/test.py index 48b0670d..f179ea81 100644 --- a/paddlespeech/s2t/exps/u2/bin/test.py +++ b/paddlespeech/s2t/exps/u2/bin/test.py @@ -14,12 +14,14 @@ """Evaluation for U2 model.""" import cProfile +from yacs.config import CfgNode + from paddlespeech.s2t.exps.u2.config import get_cfg_defaults from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.utility import print_arguments -# TODO(hui zhang): dynamic load +# TODO(hui zhang): dynamic load def main_sp(config, args): @@ -35,7 +37,7 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - # save asr result to + # save asr result to parser.add_argument( "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() @@ -45,6 +47,10 @@ if __name__ == "__main__": config = get_cfg_defaults() if args.config: config.merge_from_file(args.config) + if args.decode_config: + decode_confs = CfgNode(new_allowed=True) + decode_confs.merge_from_file(args.decode_config) + config.decoding = decode_confs if args.opts: config.merge_from_list(args.opts) config.freeze() diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py index 556316ec..e5671a43 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -18,6 +18,7 @@ from pathlib import Path import paddle import soundfile +from yacs.config import CfgNode from paddlespeech.s2t.exps.u2.config import get_cfg_defaults from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer @@ -36,23 +37,22 @@ class U2Infer(): self.args = args self.config = config self.audio_file = args.audio_file - self.sr = config.collator.target_sample_rate - self.preprocess_conf = config.collator.augmentation_config + self.preprocess_conf = config.augmentation_config self.preprocess_args = {"train": False} self.preprocessing = Transformation(self.preprocess_conf) self.text_feature = TextFeaturizer( - unit_type=config.collator.unit_type, - vocab=config.collator.vocab_filepath, - spm_model_prefix=config.collator.spm_model_prefix) + unit_type=config.unit_type, + vocab=config.vocab_filepath, + spm_model_prefix=config.spm_model_prefix) paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') # model - model_conf = config.model + model_conf = config with UpdateConfig(model_conf): - model_conf.input_dim = config.collator.feat_dim + model_conf.input_dim = config.feat_dim model_conf.output_dim = self.text_feature.vocab_size model = U2Model.from_config(model_conf) self.model = model @@ -70,10 +70,6 @@ class U2Infer(): # read audio, sample_rate = soundfile.read( self.audio_file, dtype="int16", always_2d=True) - if sample_rate != self.sr: - logger.error( - f"sample rate error: {sample_rate}, need {self.sr} ") - sys.exit(-1) audio = audio[:, 0] logger.info(f"audio shape: {audio.shape}") @@ -85,17 +81,17 @@ class U2Infer(): ilen = paddle.to_tensor(feat.shape[0]) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) - cfg = self.config.decoding + decode_config = self.config.decoding result_transcripts = self.model.decode( xs, ilen, text_feature=self.text_feature, - decoding_method=cfg.decoding_method, - beam_size=cfg.beam_size, - ctc_weight=cfg.ctc_weight, - decoding_chunk_size=cfg.decoding_chunk_size, - num_decoding_left_chunks=cfg.num_decoding_left_chunks, - simulate_streaming=cfg.simulate_streaming) + decoding_method=decode_config.decoding_method, + beam_size=decode_config.beam_size, + ctc_weight=decode_config.ctc_weight, + decoding_chunk_size=decode_config.decoding_chunk_size, + num_decoding_left_chunks=decode_config.num_decoding_left_chunks, + simulate_streaming=decode_config.simulate_streaming) rsl = result_transcripts[0][0] utt = Path(self.audio_file).name logger.info(f"hyp: {utt} {result_transcripts[0][0]}") @@ -136,6 +132,10 @@ if __name__ == "__main__": config = get_cfg_defaults() if args.config: config.merge_from_file(args.config) + if args.decode_config: + decode_confs = CfgNode(new_allowed=True) + decode_confs.merge_from_file(args.decode_config) + config.decoding = decode_confs if args.opts: config.merge_from_list(args.opts) config.freeze() diff --git a/paddlespeech/s2t/exps/u2/config.py b/paddlespeech/s2t/exps/u2/config.py index 898b0bb2..59376e95 100644 --- a/paddlespeech/s2t/exps/u2/config.py +++ b/paddlespeech/s2t/exps/u2/config.py @@ -19,19 +19,18 @@ from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.dataset import ManifestDataset from paddlespeech.s2t.models.u2 import U2Model -_C = CfgNode() +_C = CfgNode(new_allowed=True) -_C.data = ManifestDataset.params() +ManifestDataset.params(_C) -_C.collator = SpeechCollator.params() +SpeechCollator.params(_C) -_C.model = U2Model.params() +U2Model.params(_C) -_C.training = U2Trainer.params() +U2Trainer.params(_C) _C.decoding = U2Tester.params() - def get_cfg_defaults(): """Get a yacs CfgNode object with default values for my_project.""" # Return a clone so that the defaults will not be altered diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 9fb7067f..1de9541d 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -77,7 +77,7 @@ class U2Trainer(Trainer): super().__init__(config, args) def train_batch(self, batch_index, batch_data, msg): - train_conf = self.config.training + train_conf = self.config start = time.time() # forward @@ -120,7 +120,7 @@ class U2Trainer(Trainer): for k, v in losses_np.items(): report(k, v) - report("batch_size", self.config.collator.batch_size) + report("batch_size", self.config.batch_size) report("accum", train_conf.accum_grad) report("step_cost", iteration_time) @@ -153,7 +153,7 @@ class U2Trainer(Trainer): if ctc_loss: valid_losses['val_ctc_loss'].append(float(ctc_loss)) - if (i + 1) % self.config.training.log_interval == 0: + if (i + 1) % self.config.log_interval == 0: valid_dump = {k: np.mean(v) for k, v in valid_losses.items()} valid_dump['val_history_loss'] = total_loss / num_seen_utts @@ -182,7 +182,7 @@ class U2Trainer(Trainer): self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") - while self.epoch < self.config.training.n_epoch: + while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() try: @@ -214,8 +214,7 @@ class U2Trainer(Trainer): k.split(',')) == 2 else "" msg += "," msg = msg[:-1] # remove the last "," - if (batch_index + 1 - ) % self.config.training.log_interval == 0: + if (batch_index + 1) % self.config.log_interval == 0: logger.info(msg) data_start_time = time.time() except Exception as e: @@ -252,29 +251,29 @@ class U2Trainer(Trainer): if self.train: # train/valid dataset, return token ids self.train_loader = BatchDataLoader( - json_file=config.data.train_manifest, + json_file=config.train_manifest, train_mode=True, - sortagrad=config.collator.sortagrad, - batch_size=config.collator.batch_size, - maxlen_in=config.collator.maxlen_in, - maxlen_out=config.collator.maxlen_out, - minibatches=config.collator.minibatches, + sortagrad=config.sortagrad, + batch_size=config.batch_size, + maxlen_in=config.maxlen_in, + maxlen_out=config.maxlen_out, + minibatches=config.minibatches, mini_batch_size=self.args.ngpu, - batch_count=config.collator.batch_count, - batch_bins=config.collator.batch_bins, - batch_frames_in=config.collator.batch_frames_in, - batch_frames_out=config.collator.batch_frames_out, - batch_frames_inout=config.collator.batch_frames_inout, - preprocess_conf=config.collator.augmentation_config, - n_iter_processes=config.collator.num_workers, + batch_count=config.batch_count, + batch_bins=config.batch_bins, + batch_frames_in=config.batch_frames_in, + batch_frames_out=config.batch_frames_out, + batch_frames_inout=config.batch_frames_inout, + preprocess_conf=config.augmentation_config, + n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1) self.valid_loader = BatchDataLoader( - json_file=config.data.dev_manifest, + json_file=config.dev_manifest, train_mode=False, sortagrad=False, - batch_size=config.collator.batch_size, + batch_size=config.batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0, @@ -284,18 +283,18 @@ class U2Trainer(Trainer): batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, - preprocess_conf=config.collator.augmentation_config, - n_iter_processes=config.collator.num_workers, + preprocess_conf=config.augmentation_config, + n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1) logger.info("Setup train/valid Dataloader!") else: # test dataset, return raw text self.test_loader = BatchDataLoader( - json_file=config.data.test_manifest, + json_file=config.test_manifest, train_mode=False, sortagrad=False, - batch_size=config.decoding.batch_size, + batch_size=config.decoding.decode_batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0, @@ -305,16 +304,16 @@ class U2Trainer(Trainer): batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, - preprocess_conf=config.collator.augmentation_config, + preprocess_conf=config.augmentation_config, n_iter_processes=1, subsampling_factor=1, num_encs=1) self.align_loader = BatchDataLoader( - json_file=config.data.test_manifest, + json_file=config.test_manifest, train_mode=False, sortagrad=False, - batch_size=config.decoding.batch_size, + batch_size=config.decoding.decode_batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0, @@ -324,7 +323,7 @@ class U2Trainer(Trainer): batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, - preprocess_conf=config.collator.augmentation_config, + preprocess_conf=config.augmentation_config, n_iter_processes=1, subsampling_factor=1, num_encs=1) @@ -332,7 +331,7 @@ class U2Trainer(Trainer): def setup_model(self): config = self.config - model_conf = config.model + model_conf = config with UpdateConfig(model_conf): if self.train: @@ -355,7 +354,7 @@ class U2Trainer(Trainer): if not self.train: return - train_config = config.training + train_config = config optim_type = train_config.optim optim_conf = train_config.optim_conf scheduler_type = train_config.scheduler @@ -375,7 +374,7 @@ class U2Trainer(Trainer): config, parameters, lr_scheduler=None, ): - train_config = config.training + train_config = config optim_type = train_config.optim optim_conf = train_config.optim_conf scheduler_type = train_config.scheduler @@ -415,7 +414,7 @@ class U2Tester(U2Trainer): error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' num_proc_bsearch=8, # # of CPUs for beam search. beam_size=10, # Beam search width. - batch_size=16, # decoding batch size + decode_batch_size=16, # decoding batch size ctc_weight=0.0, # ctc weight for attention rescoring decode mode. decoding_chunk_size=-1, # decoding chunk size. Defaults to -1. # <0: for decoding, use full chunk. @@ -432,9 +431,9 @@ class U2Tester(U2Trainer): def __init__(self, config, args): super().__init__(config, args) self.text_feature = TextFeaturizer( - unit_type=self.config.collator.unit_type, - vocab=self.config.collator.vocab_filepath, - spm_model_prefix=self.config.collator.spm_model_prefix) + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) self.vocab_list = self.text_feature.vocab_list def id2token(self, texts, texts_len, text_feature): @@ -453,10 +452,10 @@ class U2Tester(U2Trainer): texts, texts_len, fout=None): - cfg = self.config.decoding + decode_config = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 - errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors - error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer + errors_func = error_rate.char_errors if decode_config.error_rate_type == 'cer' else error_rate.word_errors + error_rate_func = error_rate.cer if decode_config.error_rate_type == 'cer' else error_rate.wer start_time = time.time() target_transcripts = self.id2token(texts, texts_len, self.text_feature) @@ -464,12 +463,12 @@ class U2Tester(U2Trainer): audio, audio_len, text_feature=self.text_feature, - decoding_method=cfg.decoding_method, - beam_size=cfg.beam_size, - ctc_weight=cfg.ctc_weight, - decoding_chunk_size=cfg.decoding_chunk_size, - num_decoding_left_chunks=cfg.num_decoding_left_chunks, - simulate_streaming=cfg.simulate_streaming) + decoding_method=decode_config.decoding_method, + beam_size=decode_config.beam_size, + ctc_weight=decode_config.ctc_weight, + decoding_chunk_size=decode_config.decoding_chunk_size, + num_decoding_left_chunks=decode_config.num_decoding_left_chunks, + simulate_streaming=decode_config.simulate_streaming) decode_time = time.time() - start_time for utt, target, result, rec_tids in zip( @@ -488,15 +487,15 @@ class U2Tester(U2Trainer): logger.info(f"Utt: {utt}") logger.info(f"Ref: {target}") logger.info(f"Hyp: {result}") - logger.info("One example error rate [%s] = %f" % - (cfg.error_rate_type, error_rate_func(target, result))) + logger.info("One example error rate [%s] = %f" % ( + decode_config.error_rate_type, error_rate_func(target, result))) return dict( errors_sum=errors_sum, len_refs=len_refs, num_ins=num_ins, # num examples error_rate=errors_sum / len_refs, - error_rate_type=cfg.error_rate_type, + error_rate_type=decode_config.error_rate_type, num_frames=audio_len.sum().numpy().item(), decode_time=decode_time) @@ -507,7 +506,7 @@ class U2Tester(U2Trainer): self.model.eval() logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") - stride_ms = self.config.collator.stride_ms + stride_ms = self.config.stride_ms error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 @@ -558,15 +557,15 @@ class U2Tester(U2Trainer): "ref_len": len_refs, "decode_method": - self.config.decoding.decoding_method, + self.config.decoding_method, }) f.write(data + '\n') @paddle.no_grad() def align(self): ctc_utils.ctc_align(self.config, self.model, self.align_loader, - self.config.decoding.batch_size, - self.config.collator.stride_ms, self.vocab_list, + self.config.decoding.decode_batch_size, + self.config.stride_ms, self.vocab_list, self.args.result_file) def load_inferspec(self): @@ -577,10 +576,10 @@ class U2Tester(U2Trainer): List[paddle.static.InputSpec]: input spec. """ from paddlespeech.s2t.models.u2 import U2InferModel - infer_model = U2InferModel.from_pretrained(self.test_loader, - self.config.model.clone(), + infer_model = U2InferModel.from_pretrained(self.train_loader, + self.config.clone(), self.args.checkpoint_path) - feat_dim = self.test_loader.feat_dim + feat_dim = self.train_loader.feat_dim input_spec = [ paddle.static.InputSpec(shape=[1, None, feat_dim], dtype='float32'), # audio, [B,T,D] diff --git a/paddlespeech/s2t/training/cli.py b/paddlespeech/s2t/training/cli.py index 3ef871c5..d4299ea3 100644 --- a/paddlespeech/s2t/training/cli.py +++ b/paddlespeech/s2t/training/cli.py @@ -97,6 +97,14 @@ def default_argument_parser(parser=None): train_group.add_argument( "--dump-config", metavar="FILE", help="dump config to `this` file.") + test_group = parser.add_argument_group( + title='Test Options', description=None) + + test_group.add_argument( + "--decode_config", + metavar="DECODE_CONFIG_FILE", + help="decode config file.") + profile_group = parser.add_argument_group( title='Benchmark Options', description=None) profile_group.add_argument( diff --git a/paddlespeech/s2t/training/trainer.py b/paddlespeech/s2t/training/trainer.py index 9bf1ca4d..4b2011ec 100644 --- a/paddlespeech/s2t/training/trainer.py +++ b/paddlespeech/s2t/training/trainer.py @@ -117,8 +117,8 @@ class Trainer(): self.init_parallel() self.checkpoint = Checkpoint( - kbest_n=self.config.training.checkpoint.kbest_n, - latest_n=self.config.training.checkpoint.latest_n) + kbest_n=self.config.checkpoint.kbest_n, + latest_n=self.config.checkpoint.latest_n) # set random seed if needed if args.seed: @@ -129,8 +129,8 @@ class Trainer(): if hasattr(self.args, "benchmark_batch_size") and self.args.benchmark_batch_size: with UpdateConfig(self.config): - self.config.collator.batch_size = self.args.benchmark_batch_size - self.config.training.log_interval = 1 + self.config.batch_size = self.args.benchmark_batch_size + self.config.log_interval = 1 logger.info( f"Benchmark reset batch-size: {self.args.benchmark_batch_size}") @@ -260,7 +260,7 @@ class Trainer(): self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") - while self.epoch < self.config.training.n_epoch: + while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() try: diff --git a/paddlespeech/s2t/utils/utility.py b/paddlespeech/s2t/utils/utility.py index 73c79816..dc1be815 100644 --- a/paddlespeech/s2t/utils/utility.py +++ b/paddlespeech/s2t/utils/utility.py @@ -130,7 +130,7 @@ def get_subsample(config): Returns: int: subsample rate. """ - input_layer = config["model"]["encoder_conf"]["input_layer"] + input_layer = config["encoder_conf"]["input_layer"] assert input_layer in ["conv2d", "conv2d6", "conv2d8"] if input_layer == "conv2d": return 4