fix espnet kaldi libri s2 config

pull/852/head
Hui Zhang 3 years ago
parent 98b15eda05
commit 80eb6b7f01

@@ -12,7 +12,7 @@ collator:
   stride_ms: 10.0
   window_ms: 25.0
   sortagrad: 0 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
-  batch_size: 32
+  batch_size: 30
   maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
   maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
   minibatches: 0 # for debug
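
Side note: the maxlen_in / maxlen_out comments describe ESPnet-style dynamic batching. A minimal sketch of that rule, modeled on ESPnet's shortest-first batch maker (reduced_batch_size is a hypothetical helper; PaddleSpeech's exact heuristic may differ):

def reduced_batch_size(base, ilen, olen, maxlen_in=512, maxlen_out=150):
    # Long utterances shrink the batch: the reduction factor grows with
    # how far input/output lengths exceed the configured maxima.
    factor = max(int(ilen / maxlen_in), int(olen / maxlen_out))
    return max(1, base // (1 + factor))

print(reduced_batch_size(30, 1200, 90))  # a 1200-frame input cuts the batch from 30 to 10
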
@@ -59,7 +59,7 @@ model:
   model_conf:
     ctc_weight: 0.3
     ctc_dropoutrate: 0.0
-    ctc_grad_norm_type: instance
+    ctc_grad_norm_type: batch
     lsm_weight: 0.1 # label smoothing option
     length_normalized_loss: false
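
Side note: ctc_grad_norm_type controls how the summed CTC loss gradient is normalized. A framework-free illustration of the two settings, assuming per-utterance gradients and frame counts (a sketch only, not PaddleSpeech's actual loss code):

def normalize_ctc_grads(grads, frame_lens, norm_type):
    # 'batch' divides every utterance's gradient by the batch size;
    # 'instance' divides each utterance's gradient by its own frame count.
    if norm_type == "batch":
        return [g / len(grads) for g in grads]
    if norm_type == "instance":
        return [g / t for g, t in zip(grads, frame_lens)]
    return grads  # 'none': leave gradients unscaled

print(normalize_ctc_grads([6.0, 6.0], [100, 300], "batch"))     # [3.0, 3.0]
print(normalize_ctc_grads([6.0, 6.0], [100, 300], "instance"))  # [0.06, 0.02]
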
@@ -83,7 +83,7 @@ scheduler_conf:
   lr_decay: 1.0

 decoding:
-  batch_size: 64
+  batch_size: 1
   error_rate_type: wer
   decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
   lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm

@@ -36,7 +36,7 @@ for type in attention ctc_greedy_search; do
         # stream decoding only support batchsize=1
         batch_size=1
     else
-        batch_size=64
+        batch_size=1
     fi
     python3 -u ${BIN_DIR}/test.py \
         --model-name u2_kaldi \

@@ -6,7 +6,7 @@ stage=0
 stop_stage=100
 conf_path=conf/transformer.yaml
 dict_path=data/train_960_unigram5000_units.txt
-avg_num=5
+avg_num=10
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

 avg_ckpt=avg_${avg_num}
@@ -20,12 +20,12 @@ fi

 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=0,1,2,3 ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     # avg n best model
-    avg.sh best exp/${ckpt}/checkpoints ${avg_num}
+    avg.sh latest exp/${ckpt}/checkpoints ${avg_num}
 fi

 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
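
Side note: switching avg.sh from best to latest selects the avg_num most recent checkpoints instead of those with the lowest validation loss; the averaging itself is a plain parameter-wise mean. A minimal numpy sketch of that step (dicts stand in for the Paddle state dicts that utils/avg_model.py actually loads):

import numpy as np

def average_checkpoints(ckpts):
    # Sum every parameter across checkpoints in float64, then divide by
    # the checkpoint count and cast back to the original dtype.
    avg = {k: np.zeros_like(v, dtype=np.float64) for k, v in ckpts[0].items()}
    for ckpt in ckpts:
        for k, v in ckpt.items():
            avg[k] += v
    n = len(ckpts)
    return {k: (v / n).astype(ckpts[0][k].dtype) for k, v in avg.items()}

ckpts = [{"w": np.array([1.0, 2.0])}, {"w": np.array([3.0, 4.0])}]
print(average_checkpoints(ckpts))  # {'w': array([2., 3.])}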

@@ -80,8 +80,8 @@ def main(args):
         data = json.dumps({
             "avg_ckpt": args.dst_model,
             "ckpt": path_list,
-            "epoch": selected_epochs.tolist(),
-            "val_loss": beat_val_scores.tolist(),
+            "epoch": selected_epochs,
+            "val_loss": beat_val_scores,
         })
         f.write(data + "\n")
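
Side note: dropping the .tolist() calls is only safe if selected_epochs and beat_val_scores reach this point as plain Python lists, since json.dumps cannot serialize numpy arrays. A quick demonstration:

import json
import numpy as np

epochs = np.array([28, 29, 30])
try:
    json.dumps({"epoch": epochs})  # raises: ndarray is not JSON serializable
except TypeError:
    print(json.dumps({"epoch": epochs.tolist()}))  # {"epoch": [28, 29, 30]}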
