format wav2vec2 demo

pull/2518/head
tianhao zhang 2 years ago
parent 7bee9d807f
commit 3d994f5c23

@@ -1,4 +1,4 @@
 process:
-  # extract kaldi fbank from PCM
+  # use raw audio
   - type: wav_process
-    dither: 0.1
+    dither: 0.0

@@ -1,11 +1,4 @@
 decode_batch_size: 1
 error_rate_type: wer
-decoding_method: ctc_greedy_search  # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
+decoding_method: ctc_greedy_search  # 'ctc_greedy_search', 'ctc_prefix_beam_search'
 beam_size: 10
-ctc_weight: 0.5  # ctc weight for attention rescoring decode mode.
-decoding_chunk_size: -1  # decoding chunk size. Defaults to -1.
-                         # <0: for decoding, use full chunk.
-                         # >0: for decoding, use fixed chunk size as set.
-                         # 0: used for training, it's prohibited here.
-num_decoding_left_chunks: -1  # number of left chunks for decoding. Defaults to -1.
-simulate_streaming: False  # simulate streaming inference. Defaults to False.
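
The trimmed decode config leaves only the two CTC decoding modes the wav2vec2 recipe actually supports. For reference, ctc_greedy_search is the simpler of the two: take the argmax token per frame, collapse consecutive repeats, and drop blanks. A minimal illustrative sketch of that procedure (not PaddleSpeech's implementation; the function and argument names here are assumptions):

import numpy as np

def ctc_greedy_search(logits: np.ndarray, blank_id: int = 0) -> list:
    """Greedy CTC decoding: argmax per frame, merge repeats, remove blanks.

    logits: array of shape (num_frames, vocab_size) with per-frame scores.
    """
    best_path = logits.argmax(axis=-1)  # best token id for every frame
    tokens, prev = [], None
    for tok in best_path:
        if tok != prev and tok != blank_id:  # skip repeats and blank frames
            tokens.append(int(tok))
        prev = tok
    return tokens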

@@ -36,9 +36,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     avg.sh best exp/${ckpt}/checkpoints ${avg_num}
 fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-    # attetion resocre decoder
+    # greedy search decoder
     CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi

@@ -383,7 +383,7 @@ class LogMelSpectrogramKaldi():
 class WavProcess():
-    def __init__(self, dither=0.1):
+    def __init__(self, dither=0.0):
         """
         Args:
             dither (float): Dithering constant
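
With the default dither dropped to 0.0, the demo now feeds the raw waveform to wav2vec2 without added noise. When dithering is enabled, it conventionally just adds tiny Gaussian noise scaled by the constant; a rough sketch of that idea (an assumption about the behaviour, not the actual WavProcess code):

import numpy as np

def apply_dither(wav: np.ndarray, dither: float = 0.0) -> np.ndarray:
    """Add small Gaussian noise to a waveform; a no-op when dither is 0.0."""
    if dither == 0.0:
        return wav
    noise = dither * np.random.randn(*wav.shape)
    return (wav + noise).astype(wav.dtype)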

@@ -20,8 +20,6 @@ from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.utility import print_arguments
-# TODO(hui zhang): dynamic load
 def main_sp(config, args):
     exp = Tester(config, args)

@@ -25,9 +25,7 @@ import paddle
 from paddle import distributed as dist

 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
-from paddlespeech.s2t.io.dataloader import BatchDataLoader
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
-from paddlespeech.s2t.io.dataloader import StreamDataLoader
 from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
 from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
@@ -300,7 +298,6 @@ class Wav2Vec2ASRTrainer(Trainer):
             "epsilon": optim_conf.epsilon,
             "rho": optim_conf.rho,
             "parameters": parameters,
-            "epsilon": 1e-9 if optim_type == 'noam' else None,
             "beta1": 0.9 if optim_type == 'noam' else None,
             "beat2": 0.98 if optim_type == 'noam' else None,
         }
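
The removed line was a duplicate "epsilon" key: in a Python dict literal the last occurrence wins, so the noam-specific value silently overrode the epsilon taken from optim_conf. A quick illustration of the pitfall (the numeric values here are made up):

optim_type = 'noam'
optim_arg = {
    "epsilon": 1e-6,                                     # value from the config
    "epsilon": 1e-9 if optim_type == 'noam' else None,  # duplicate key: this one wins
}
print(optim_arg["epsilon"])  # prints 1e-09; the configured 1e-6 is silently discarded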

@@ -39,7 +39,7 @@ class Wav2vec2ASR(nn.Layer):
             enc_n_units=config.dnn_neurons,
             blank_id=config.blank_id,
             dropout_rate=config.ctc_dropout_rate,
-            reduction=True)
+            reduction='mean')

     def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
         if self.normalize_wav:

@@ -53,7 +53,7 @@ class CTCDecoderBase(nn.Layer):
                  enc_n_units,
                  blank_id=0,
                  dropout_rate: float=0.0,
-                 reduction: bool=True,
+                 reduction: Union[str, bool]=True,
                  batch_average: bool=True,
                  grad_norm_type: Union[str, None]=None):
         """CTC decoder
@@ -73,7 +73,10 @@ class CTCDecoderBase(nn.Layer):
         self.odim = odim
         self.dropout = nn.Dropout(dropout_rate)
         self.ctc_lo = Linear(enc_n_units, self.odim)
-        reduction_type = "sum" if reduction else "none"
+        if isinstance(reduction, bool):
+            reduction_type = "sum" if reduction else "none"
+        else:
+            reduction_type = reduction
         self.criterion = CTCLoss(
             blank=self.blank_id,
             reduction=reduction_type,
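
The constructor change above keeps backward compatibility: a bool still maps to "sum"/"none" as before, while a string such as the 'mean' now passed by Wav2vec2ASR is forwarded to CTCLoss unchanged. A standalone sketch of the same dispatch logic (illustration only, outside the real class):

from typing import Union

def resolve_reduction(reduction: Union[str, bool] = True) -> str:
    """Map the legacy bool flag or a string onto a CTCLoss reduction mode."""
    if isinstance(reduction, bool):
        return "sum" if reduction else "none"
    return reduction  # e.g. 'mean', passed through as-is

assert resolve_reduction(True) == "sum"
assert resolve_reduction(False) == "none"
assert resolve_reduction("mean") == "mean"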
