fix wav2vec2 test_wav.sh run error.

pull/3023/head
zxcd 3 years ago
parent 00585f61f1
commit 175db0682d

@ -190,9 +190,9 @@ tar xzvf wav2vec2ASR-large-aishell1_ckpt_1.4.0.model.tar.gz
```
You can download the audio demo:
```bash
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
```
You need to prepare an audio file or use the audio demo above, please confirm the sample rate of the audio is 16K. You can get the result of the audio demo by running the script below.
```bash
CUDA_VISIBLE_DEVICES= ./local/test_wav.sh conf/wav2vec2ASR.yaml conf/tuning/decode.yaml exp/wav2vec2ASR/checkpoints/avg_1 data/demo_002_en.wav
CUDA_VISIBLE_DEVICES= ./local/test_wav.sh conf/wav2vec2ASR.yaml conf/tuning/decode.yaml exp/wav2vec2ASR/checkpoints/avg_1 data/demo_01_03.wav
```

@ -107,6 +107,7 @@ vocab_filepath: data/lang_char/vocab.txt
###########################################
unit_type: 'char'
tokenizer: bert-base-chinese
mean_std_filepath:
preprocess_config: conf/preprocess.yaml
sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
@ -163,3 +164,4 @@ log_interval: 1
checkpoint:
kbest_n: 50
latest_n: 5

@ -8,9 +8,7 @@ echo "using $ngpu gpus..."
expdir=exp
datadir=data
train_set=train_960
recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean"
train_set=train
config_path=$1
decode_config_path=$2
@ -75,7 +73,7 @@ for type in ctc_prefix_beam_search; do
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
python3 utils/compute-wer.py --char=1 --v=1 \
data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
data/manifest.test.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
echo "decoding ${type} done."
done

@ -14,7 +14,7 @@ ckpt_prefix=$3
audio_file=$4
mkdir -p data
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
if [ $? -ne 0 ]; then
exit 1
fi

@ -15,7 +15,7 @@ resume= # xx e.g. 30
export FLAGS_cudnn_deterministic=1
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
audio_file=data/demo_002_en.wav
audio_file=data/demo_01_03.wav
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')

@ -102,13 +102,11 @@ ssl_dynamic_pretrained_models = {
'params':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
"wav2vec2ASR_aishell1-zh-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2ASR-large-aishell1_ckpt_1.4.0.model.tar.gz',
'md5':
'9f0bc943adb822789bf61e674b229d17',
'150e51b8ea5d255ccce6b395de8d916a',
'cfg_path':
'model.yaml',
'ckpt_path':

@ -18,6 +18,7 @@ from pathlib import Path
import paddle
import soundfile
from paddlenlp.transformers import AutoTokenizer
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
@ -34,8 +35,13 @@ class Wav2vec2Infer():
self.config = config
self.audio_file = args.audio_file
self.text_feature = TextFeaturizer(
unit_type=config.unit_type, vocab=config.vocab_filepath)
if self.config.tokenizer:
self.text_feature = AutoTokenizer.from_pretrained(
self.config.tokenizer)
else:
self.text_feature = TextFeaturizer(
unit_type=config.unit_type, vocab=config.vocab_filepath)
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
# model
@ -59,14 +65,14 @@ class Wav2vec2Infer():
audio, _ = soundfile.read(
self.audio_file, dtype="int16", always_2d=True)
logger.info(f"audio shape: {audio.shape}")
xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode
result_transcripts, result_tokenids = self.model.decode(
xs,
text_feature=self.text_feature,
decoding_method=decode_config.decoding_method,
beam_size=decode_config.beam_size)
beam_size=decode_config.beam_size,
tokenizer=self.config.tokenizer, )
rsl = result_transcripts[0]
utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {rsl}")

Loading…
Cancel
Save