fix timit scripts; reader filetype case;

pull/870/head
Hui Zhang 3 years ago
parent 84f77ecdf5
commit 251d32a609

@ -322,7 +322,7 @@ class LoadInputsAndTargets():
"Not supported: loader_type={}".format(filetype)) "Not supported: loader_type={}".format(filetype))
def file_type(self, filepath): def file_type(self, filepath):
suffix = filepath.split(":")[0].split('.')[-1] suffix = filepath.split(":")[0].split('.')[-1].lower()
if suffix == 'ark': if suffix == 'ark':
return 'mat' return 'mat'
elif suffix == 'scp': elif suffix == 'scp':

@ -14,6 +14,7 @@
"""This module provides functions to calculate error rate in different level. """This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level. e.g. wer for word-level, cer for char-level.
""" """
import editdistance
import numpy as np import numpy as np
__all__ = ['word_errors', 'char_errors', 'wer', 'cer'] __all__ = ['word_errors', 'char_errors', 'wer', 'cer']
@ -89,6 +90,7 @@ def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
hyp_words = list(filter(None, hypothesis.split(delimiter))) hyp_words = list(filter(None, hypothesis.split(delimiter)))
edit_distance = _levenshtein_distance(ref_words, hyp_words) edit_distance = _levenshtein_distance(ref_words, hyp_words)
# edit_distance = editdistance.eval(ref_words, hyp_words)
return float(edit_distance), len(ref_words) return float(edit_distance), len(ref_words)
@ -119,6 +121,7 @@ def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
hypothesis = join_char.join(list(filter(None, hypothesis.split(' ')))) hypothesis = join_char.join(list(filter(None, hypothesis.split(' '))))
edit_distance = _levenshtein_distance(reference, hypothesis) edit_distance = _levenshtein_distance(reference, hypothesis)
# edit_distance = editdistance.eval(reference, hypothesis)
return float(edit_distance), len(reference) return float(edit_distance), len(reference)

@ -93,20 +93,25 @@ def pad_sequence(sequences: List[paddle.Tensor],
for i, tensor in enumerate(sequences): for i, tensor in enumerate(sequences):
length = tensor.shape[0] length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor # use index notation to prevent duplicate references to the tensor
logger.info(
f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}"
)
if batch_first: if batch_first:
# TODO (Hui Zhang): set_value op not support `end==start` # TODO (Hui Zhang): set_value op not support `end==start`
# TODO (Hui Zhang): set_value op not support int16
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
# out_tensor[i, :length, ...] = tensor # out_tensor[i, :length, ...] = tensor
if length != 0: if length != 0:
out_tensor[i, :length, ...] = tensor out_tensor[i, :length] = tensor
else: else:
out_tensor[i, length, ...] = tensor out_tensor[i, length] = tensor
else: else:
# TODO (Hui Zhang): set_value op not support `end==start` # TODO (Hui Zhang): set_value op not support `end==start`
# out_tensor[:length, i, ...] = tensor # out_tensor[:length, i, ...] = tensor
if length != 0: if length != 0:
out_tensor[:length, i, ...] = tensor out_tensor[:length, i] = tensor
else: else:
out_tensor[length, i, ...] = tensor out_tensor[length, i] = tensor
return out_tensor return out_tensor

@ -12,4 +12,4 @@
## Transformer ## Transformer
| Model | Params | Config | Char-BLEU | | Model | Params | Config | Char-BLEU |
| --- | --- | --- | --- | | --- | --- | --- | --- |
| Transformer+ASR MTL | 50.26M | conf/transformer_joint_noam.yaml | 17.38 | | Transformer+ASR MTL | 50.26M | conf/transformer_joint_noam.yaml | 17.38 |

@ -1,5 +1,7 @@
#!/bin/bash #!/bin/bash
set -e
stage=-1 stage=-1
stop_stage=100 stop_stage=100

@ -0,0 +1,3 @@
data
exp
test.profile

@ -1,11 +1,9 @@
# TIMIT # TIMIT
### Transformer ### Transformer
| Model | Params | Config | Decode method | PER | | Model | Params | Config | Decode method | Loss | PER |
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- |
| transformer | 5.17M | conf/transformer.yaml | attention | 0.5531 | | transformer | 5.17M | conf/transformer.yaml | attention | 49.25688171386719 | 0.510742 |
| transformer | 5.17M | conf/transformer.yaml | ctc_greedy_search | 0.3922 | | transformer | 5.17M | conf/transformer.yaml | ctc_greedy_search | 49.25688171386719 | 0.382398 |
| transformer | 5.17M | conf/transformer.yaml | ctc_prefix_beam_search | 0.3768 | | transformer | 5.17M | conf/transformer.yaml | ctc_prefix_beam_search | 49.25688171386719 | 0.367429 |
| transformer | 5.17M | conf/transformer.yaml | attention_rescoring | 49.25688171386719 | 0.357173 |

@ -1,10 +1,18 @@
#!/bin/bash #!/bin/bash
set -e
stage=0
stop_stage=50
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
if [ $# != 2 ];then if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix" echo "usage: ${0} config_path ckpt_path_prefix"
exit -1 exit -1
fi fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
@ -23,44 +31,67 @@ fi
# exit 1 # exit 1
#fi #fi
for type in attention ctc_greedy_search; do if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "decoding ${type}" for type in attention ctc_greedy_search; do
if [ ${chunk_mode} == true ];then echo "decoding ${type}"
# stream decoding only support batchsize=1 if [ ${chunk_mode} == true ];then
batch_size=1 # stream decoding only support batchsize=1
else batch_size=1
batch_size=64 else
fi batch_size=64
python3 -u ${BIN_DIR}/test.py \ fi
--nproc ${ngpu} \ python3 -u ${BIN_DIR}/test.py \
--config ${config_path} \ --nproc ${ngpu} \
--result_file ${ckpt_prefix}.${type}.rsl \ --config ${config_path} \
--checkpoint_path ${ckpt_prefix} \ --result_file ${ckpt_prefix}.${type}.rsl \
--opts decoding.decoding_method ${type} \ --checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size} --opts decoding.decoding_method ${type} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!" if [ $? -ne 0 ]; then
exit 1 echo "Failed in evaluation!"
fi exit 1
done fi
done
fi
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Failed in evaluation!" for type in ctc_prefix_beam_search; do
exit 1 echo "decoding ${type}"
fi batch_size=1
done python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for type in attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--nproc ${ngpu} \
--config ${config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.decoding_method ${type} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
fi
exit 0 exit 0

@ -1,13 +1,15 @@
#!/bin/bash #!/bin/bash
set -e set -e
source path.sh
. path.sh || exit 1;
stage=0 stage=0
stop_stage=50 stop_stage=50
conf_path=conf/transformer.yaml conf_path=conf/transformer.yaml
avg_num=10 avg_num=10
TIMIT_path= #path of TIMIT (Required, e.g. /export/corpora5/LDC/LDC93S1/timit/TIMIT) TIMIT_path=/workspace/zhanghui/dataset/data/lisa/data/timit/raw/TIMIT
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num} avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')

@ -19,3 +19,4 @@ tqdm
typeguard typeguard
visualdl==2.2.0 visualdl==2.2.0
yacs yacs
editdistance
Loading…
Cancel
Save