format wav2vec2 demo

pull/2518/head
tianhao zhang 2 years ago
parent 7bee9d807f
commit 3d994f5c23

@ -1,4 +1,4 @@
process:
# extract kaldi fbank from PCM
# use raw audio
- type: wav_process
dither: 0.1
dither: 0.0

@ -1,11 +1,4 @@
decode_batch_size: 1
error_rate_type: wer
decoding_method: ctc_greedy_search # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
decoding_method: ctc_greedy_search # 'ctc_greedy_search', 'ctc_prefix_beam_search'
beam_size: 10
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.

@ -36,9 +36,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# attetion resocre decoder
# greedy search decoder
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi

@ -383,7 +383,7 @@ class LogMelSpectrogramKaldi():
class WavProcess():
def __init__(self, dither=0.1):
def __init__(self, dither=0.0):
"""
Args:
dither (float): Dithering constant

@ -20,8 +20,6 @@ from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
# TODO(hui zhang): dynamic load
def main_sp(config, args):
exp = Tester(config, args)

@ -25,9 +25,7 @@ import paddle
from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.training.optimizer import OptimizerFactory
@ -300,7 +298,6 @@ class Wav2Vec2ASRTrainer(Trainer):
"epsilon": optim_conf.epsilon,
"rho": optim_conf.rho,
"parameters": parameters,
"epsilon": 1e-9 if optim_type == 'noam' else None,
"beta1": 0.9 if optim_type == 'noam' else None,
"beat2": 0.98 if optim_type == 'noam' else None,
}

@ -39,7 +39,7 @@ class Wav2vec2ASR(nn.Layer):
enc_n_units=config.dnn_neurons,
blank_id=config.blank_id,
dropout_rate=config.ctc_dropout_rate,
reduction=True)
reduction='mean')
def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
if self.normalize_wav:

@ -53,7 +53,7 @@ class CTCDecoderBase(nn.Layer):
enc_n_units,
blank_id=0,
dropout_rate: float=0.0,
reduction: bool=True,
reduction: Union[str, bool]=True,
batch_average: bool=True,
grad_norm_type: Union[str, None]=None):
"""CTC decoder
@ -73,7 +73,10 @@ class CTCDecoderBase(nn.Layer):
self.odim = odim
self.dropout = nn.Dropout(dropout_rate)
self.ctc_lo = Linear(enc_n_units, self.odim)
if isinstance(reduction, bool):
reduction_type = "sum" if reduction else "none"
else:
reduction_type = reduction
self.criterion = CTCLoss(
blank=self.blank_id,
reduction=reduction_type,

Loading…
Cancel
Save