fix wav2vec2 test_wav.sh run error.

pull/3023/head
zxcd 3 years ago
parent 00585f61f1
commit 175db0682d

@@ -190,9 +190,9 @@ tar xzvf wav2vec2ASR-large-aishell1_ckpt_1.4.0.model.tar.gz
 ```
 You can download the audio demo:
 ```bash
-wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
+wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
 ```
 You need to prepare an audio file or use the audio demo above, please confirm the sample rate of the audio is 16K. You can get the result of the audio demo by running the script below.
 ```bash
-CUDA_VISIBLE_DEVICES= ./local/test_wav.sh conf/wav2vec2ASR.yaml conf/tuning/decode.yaml exp/wav2vec2ASR/checkpoints/avg_1 data/demo_002_en.wav
+CUDA_VISIBLE_DEVICES= ./local/test_wav.sh conf/wav2vec2ASR.yaml conf/tuning/decode.yaml exp/wav2vec2ASR/checkpoints/avg_1 data/demo_01_03.wav
 ```
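
The README context above asks you to confirm the input audio is 16 kHz before running `local/test_wav.sh`. A minimal sketch of that check, assuming `soundfile` is installed (the inference script below already depends on it) and using the demo path from the command above:

```python
import soundfile

# Inspect only the header of the demo clip, without loading the samples.
info = soundfile.info("data/demo_01_03.wav")
print(f"sample rate: {info.samplerate} Hz, duration: {info.duration:.2f} s")

# wav2vec2ASR expects 16 kHz input; resample the file first if this fails.
assert info.samplerate == 16000
```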

@@ -107,6 +107,7 @@ vocab_filepath: data/lang_char/vocab.txt
 ###########################################
 unit_type: 'char'
+tokenizer: bert-base-chinese
 mean_std_filepath:
 preprocess_config: conf/preprocess.yaml
 sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
@@ -163,3 +164,4 @@ log_interval: 1
 checkpoint:
   kbest_n: 50
   latest_n: 5
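
The new `tokenizer: bert-base-chinese` field is the value the modified `test_wav.py` below hands to `AutoTokenizer.from_pretrained`. A rough sketch of what it resolves to at runtime, assuming `paddlenlp` is installed and can fetch the `bert-base-chinese` vocabulary (the sample sentence is made up):

```python
from paddlenlp.transformers import AutoTokenizer

# Load the tokenizer named in the config instead of the char-level TextFeaturizer.
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

# Encode a throwaway Mandarin sentence and map the ids back to tokens.
ids = tokenizer("今天天气真好")["input_ids"]
print(ids)
print(tokenizer.convert_ids_to_tokens(ids))
```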

@@ -8,9 +8,7 @@ echo "using $ngpu gpus..."
 expdir=exp
 datadir=data
-train_set=train_960
-recog_set="test-clean test-other dev-clean dev-other"
-recog_set="test-clean"
+train_set=train
 config_path=$1
 decode_config_path=$2
@@ -75,7 +73,7 @@ for type in ctc_prefix_beam_search; do
     --trans_hyp ${ckpt_prefix}.${type}.rsl.text
 python3 utils/compute-wer.py --char=1 --v=1 \
-    data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
+    data/manifest.test.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
 echo "decoding ${type} done."
 done
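
`utils/compute-wer.py --char=1` scores the hypothesis file against the new `data/manifest.test.text` reference at the character level. As a toy illustration of that kind of scoring (not the project's scorer, just a plain Levenshtein-based CER over two hypothetical strings):

```python
def edit_distance(ref: str, hyp: str) -> int:
    """Levenshtein distance between two character sequences."""
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
    return dp[-1]

ref = "今天天气真好"  # reference transcript (made up)
hyp = "今天天汽真好"  # hypothesis with one substituted character
cer = edit_distance(ref, hyp) / len(ref)
print(f"CER = {cer:.3f}")  # 1/6 ≈ 0.167
```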

@@ -14,7 +14,7 @@ ckpt_prefix=$3
 audio_file=$4
 mkdir -p data
-wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
+wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
 if [ $? -ne 0 ]; then
     exit 1
 fi

@@ -15,7 +15,7 @@ resume= # xx e.g. 30
 export FLAGS_cudnn_deterministic=1
 . ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
-audio_file=data/demo_002_en.wav
+audio_file=data/demo_01_03.wav
 avg_ckpt=avg_${avg_num}
 ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')

@@ -102,13 +102,11 @@ ssl_dynamic_pretrained_models = {
             'params':
             'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
         },
-    },
-    "wav2vec2ASR_aishell1-zh-16k": {
         '1.4': {
             'url':
             'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2ASR-large-aishell1_ckpt_1.4.0.model.tar.gz',
             'md5':
-            '9f0bc943adb822789bf61e674b229d17',
+            '150e51b8ea5d255ccce6b395de8d916a',
             'cfg_path':
             'model.yaml',
             'ckpt_path':
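
The registry entry pins the `1.4.0` archive to the new md5 above. A minimal sketch of the kind of integrity check the downloader performs, assuming the tarball from the README has already been fetched into the current directory:

```python
import hashlib

# md5 recorded for wav2vec2ASR-large-aishell1_ckpt_1.4.0.model.tar.gz in the diff above.
EXPECTED_MD5 = "150e51b8ea5d255ccce6b395de8d916a"

def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file through md5 in 1 MiB chunks to avoid loading it whole."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

actual = file_md5("wav2vec2ASR-large-aishell1_ckpt_1.4.0.model.tar.gz")
assert actual == EXPECTED_MD5, f"checksum mismatch: {actual}"
```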

@@ -18,6 +18,7 @@ from pathlib import Path
 import paddle
 import soundfile
+from paddlenlp.transformers import AutoTokenizer
 from yacs.config import CfgNode
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
@@ -34,8 +35,13 @@ class Wav2vec2Infer():
         self.config = config
         self.audio_file = args.audio_file
-        self.text_feature = TextFeaturizer(
-            unit_type=config.unit_type, vocab=config.vocab_filepath)
+        if self.config.tokenizer:
+            self.text_feature = AutoTokenizer.from_pretrained(
+                self.config.tokenizer)
+        else:
+            self.text_feature = TextFeaturizer(
+                unit_type=config.unit_type, vocab=config.vocab_filepath)
         paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
         # model
@@ -59,14 +65,14 @@ class Wav2vec2Infer():
             audio, _ = soundfile.read(
                 self.audio_file, dtype="int16", always_2d=True)
             logger.info(f"audio shape: {audio.shape}")
             xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
             decode_config = self.config.decode
             result_transcripts, result_tokenids = self.model.decode(
                 xs,
                 text_feature=self.text_feature,
                 decoding_method=decode_config.decoding_method,
-                beam_size=decode_config.beam_size)
+                beam_size=decode_config.beam_size,
+                tokenizer=self.config.tokenizer, )
             rsl = result_transcripts[0]
             utt = Path(self.audio_file).name
             logger.info(f"hyp: {utt} {rsl}")
