From 80eb6b7f01851c42811866d2678007497e684be2 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Wed, 22 Sep 2021 09:32:00 +0000
Subject: [PATCH] fix espnet kaldi libri s2 config

---
 examples/librispeech/s2/conf/transformer.yaml | 6 +++---
 examples/librispeech/s2/local/test.sh         | 2 +-
 examples/librispeech/s2/run.sh                | 6 +++---
 utils/avg_model.py                            | 4 ++--
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml
index edf5b81d..b86224ff 100644
--- a/examples/librispeech/s2/conf/transformer.yaml
+++ b/examples/librispeech/s2/conf/transformer.yaml
@@ -12,7 +12,7 @@ collator:
   stride_ms: 10.0
   window_ms: 25.0
   sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
-  batch_size: 32
+  batch_size: 30
   maxlen_in: 512  # if input length > maxlen-in, batchsize is automatically reduced
   maxlen_out: 150  # if output length > maxlen-out, batchsize is automatically reduced
   minibatches: 0 # for debug
@@ -59,7 +59,7 @@ model:
   model_conf:
     ctc_weight: 0.3
     ctc_dropoutrate: 0.0
-    ctc_grad_norm_type: instance
+    ctc_grad_norm_type: batch
    lsm_weight: 0.1 # label smoothing option
    length_normalized_loss: false
@@ -83,7 +83,7 @@ scheduler_conf:
   lr_decay: 1.0

 decoding:
-  batch_size: 64
+  batch_size: 1
   error_rate_type: wer
   decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
   lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm

diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh
index efd06f35..893d67b5 100755
--- a/examples/librispeech/s2/local/test.sh
+++ b/examples/librispeech/s2/local/test.sh
@@ -36,7 +36,7 @@ for type in attention ctc_greedy_search; do
         # stream decoding only support batchsize=1
         batch_size=1
     else
-        batch_size=64
+        batch_size=1
     fi
     python3 -u ${BIN_DIR}/test.py \
     --model-name u2_kaldi \

diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh
index 46c8ea5d..8dd93736 100755
--- a/examples/librispeech/s2/run.sh
+++ b/examples/librispeech/s2/run.sh
@@ -6,7 +6,7 @@ stage=0
 stop_stage=100
 conf_path=conf/transformer.yaml
 dict_path=data/train_960_unigram5000_units.txt
-avg_num=5
+avg_num=10
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

 avg_ckpt=avg_${avg_num}
@@ -20,12 +20,12 @@ fi

 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     # avg n best model
-    avg.sh best exp/${ckpt}/checkpoints ${avg_num}
+    avg.sh latest exp/${ckpt}/checkpoints ${avg_num}
 fi

 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then

diff --git a/utils/avg_model.py b/utils/avg_model.py
index 8ec792f5..3a0739c9 100755
--- a/utils/avg_model.py
+++ b/utils/avg_model.py
@@ -80,8 +80,8 @@ def main(args):
         data = json.dumps({
             "avg_ckpt": args.dst_model,
             "ckpt": path_list,
-            "epoch": selected_epochs.tolist(),
-            "val_loss": beat_val_scores.tolist(),
+            "epoch": selected_epochs,
+            "val_loss": beat_val_scores,
         })
         f.write(data + "\n")
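Note on the utils/avg_model.py hunk: dropping the `.tolist()` calls is only safe if `selected_epochs` and `beat_val_scores` reach this point as plain Python lists rather than numpy arrays, since `.tolist()` exists only on arrays and `json.dumps` cannot serialize arrays directly. Presumably the `latest`-mode checkpoint selection enabled by the run.sh change builds these as plain lists. The following is a minimal sketch of the serialization behavior this change relies on, not code from the repository; the sample values are invented for illustration:

    # Sketch only: illustrates the json.dumps() behavior behind the
    # avg_model.py change. Sample values are invented.
    import json

    import numpy as np

    selected_epochs = [96, 97, 98, 99, 100]           # plain Python list
    beat_val_scores = [2.95, 2.94, 2.93, 2.92, 2.91]  # plain Python list

    # Plain lists serialize directly; no .tolist() is needed, and calling
    # .tolist() on a list would raise AttributeError.
    print(json.dumps({"epoch": selected_epochs, "val_loss": beat_val_scores}))

    # A numpy array, by contrast, must be converted first:
    scores = np.array(beat_val_scores)
    try:
        json.dumps({"val_loss": scores})
    except TypeError as err:
        print("ndarray is not JSON serializable:", err)
    print(json.dumps({"val_loss": scores.tolist()}))  # works after .tolist()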