diff --git a/demos/speech_ssl/run.sh b/demos/speech_ssl/run.sh
index b71e0883a..ca94bc5cc 100644
--- a/demos/speech_ssl/run.sh
+++ b/demos/speech_ssl/run.sh
@@ -1,12 +1,10 @@
 #!/bin/bash
 
 # audio download
-wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
+wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
 
 # to recognize text
 paddlespeech ssl --task asr --lang en --input ./en.wav
 
 # to get acoustic representation
 paddlespeech ssl --task vector --lang en --input ./en.wav
-
-README_cn
\ No newline at end of file
diff --git a/examples/librispeech/asr3/conf/preprocess.yaml b/examples/librispeech/asr3/conf/preprocess.yaml
index 4a908a83b..724782ed6 100644
--- a/examples/librispeech/asr3/conf/preprocess.yaml
+++ b/examples/librispeech/asr3/conf/preprocess.yaml
@@ -1,4 +1,3 @@
 process:
   # use raw audio
   - type: wav_process
-    dither: 0.0
diff --git a/examples/librispeech/asr3/conf/wav2vec2ASR.yaml b/examples/librispeech/asr3/conf/wav2vec2ASR.yaml
index c45bd692a..9ca535217 100644
--- a/examples/librispeech/asr3/conf/wav2vec2ASR.yaml
+++ b/examples/librispeech/asr3/conf/wav2vec2ASR.yaml
@@ -4,16 +4,21 @@ freeze_wav2vec2: True
 normalize_wav: True
 output_norm: True
-dnn_blocks: 2
-dnn_neurons: 1024
-blank_id: 0
-ctc_dropout_rate: 0.0
+init_type: 'kaiming_uniform'  # Warning: needed for convergence
+enc:
+  input_shape: 1024
+  dnn_blocks: 2
+  dnn_neurons: 1024
+  activation: True
+ctc:
+  enc_n_units: 1024
+  blank_id: 0
+  dropout_rate: 0.0
 wav2vec2_params_path: "exp/wav2vec2/wav2vec2-large-960h-lv60-self.pdparams"
 
 ############################################
 #   Wav2Vec2.0                             #
 ############################################
-vocab_size: 32
 hidden_size: 1024
 num_hidden_layers: 24
 num_attention_heads: 16
@@ -54,9 +59,6 @@ diversity_loss_weight: 0.1
 ctc_loss_reduction: "sum"
 ctc_zero_infinity: False
 use_weighted_layer_sum: False
-pad_token_id: 0
-bos_token_id: 1
-eos_token_id: 2
 add_adapter: False
 adapter_kernel_size: 3
 adapter_stride: 2
@@ -78,7 +80,7 @@ unit_type: 'char'
 mean_std_filepath: ""
 preprocess_config: conf/preprocess.yaml
 sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for 'other' epochs
-batch_size: 10 # Different batch_size may cause large differences in results
+batch_size: 6 # Different batch_size may cause large differences in results
 maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced
 maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced
 minibatches: 0 # for debug
@@ -112,11 +114,20 @@ model_optim_conf:
   lr: 0.9
   epsilon: 1.0e-6
   rho: 0.95
-scheduler: constantlr
-scheduler_conf:
+model_scheduler: constantlr
+model_scheduler_conf:
+  warmup_steps: 25000
+  lr_decay: 1.0
+wav2vec2_optim: adadelta
+wav2vec2_optim_conf:
+  lr: 0.9
+  epsilon: 1.0e-6
+  rho: 0.95
+wav2vec2_scheduler: constantlr
+wav2vec2_scheduler_conf:
   warmup_steps: 25000
   lr_decay: 1.0
 log_interval: 1
 checkpoint:
   kbest_n: 50
-  latest_n: 5
+  latest_n: 5
\ No newline at end of file
diff --git a/examples/librispeech/asr3/local/train.sh b/examples/librispeech/asr3/local/train.sh
index 6913ed17e..24776fd17 100644
--- a/examples/librispeech/asr3/local/train.sh
+++ b/examples/librispeech/asr3/local/train.sh
@@ -10,7 +10,8 @@ echo "using $ngpu gpus..."
 
 config_path=$1
 ckpt_name=$2
-ips=$3
+resume=$3
+ips=$4
 
 if [ ! $ips ];then
   ips_config=
@@ -21,7 +22,7 @@ fi
 mkdir -p exp
 
 # seed may break model convergence
-seed=1998
+seed=1988
 if [ ${seed} != 0 ]; then
     export FLAGS_cudnn_deterministic=True
 fi
@@ -34,13 +35,15 @@ python3 -u ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
 --output exp/${ckpt_name} \
---seed ${seed}
+--seed ${seed} \
+--resume ${resume}
 else
 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
 --output exp/${ckpt_name} \
---seed ${seed}
+--seed ${seed} \
+--resume ${resume}
 fi
 
 if [ ${seed} != 0 ]; then
diff --git a/examples/librispeech/asr3/run.sh b/examples/librispeech/asr3/run.sh
index 3b1abb11b..05ad505c7 100644
--- a/examples/librispeech/asr3/run.sh
+++ b/examples/librispeech/asr3/run.sh
@@ -11,7 +11,7 @@ conf_path=conf/wav2vec2ASR.yaml
 ips=            #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 avg_num=1
-dict_path=data/lang_char/vocab.txt
+resume=         # epoch to resume from, e.g. 30
 
 . ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -28,7 +28,7 @@ fi
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
 fi
 
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -38,10 +38,10 @@ fi
 
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # greedy search decoder
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # test a single .wav file
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
 fi
diff --git a/paddlespeech/audio/transform/spectrogram.py b/paddlespeech/audio/transform/spectrogram.py
index cba60cfdb..84812a2cf 100644
--- a/paddlespeech/audio/transform/spectrogram.py
+++ b/paddlespeech/audio/transform/spectrogram.py
@@ -383,7 +383,7 @@ class LogMelSpectrogramKaldi():
 
 
 class WavProcess():
-    def __init__(self, dither=0.0):
+    def __init__(self):
         """
         Args:
             dither (float): Dithering constant
@@ -391,9 +391,7 @@ class WavProcess():
 
         Returns:
         """
-        self.dither = dither
-
-    def __call__(self, x, train):
+    def __call__(self, x):
         """
         Args:
             x (np.ndarray): shape (Ti,)
@@ -405,10 +403,10 @@ class WavProcess():
         Returns:
             np.ndarray: (T, D)
         """
-        dither = self.dither if train else 0.0
         if x.ndim != 1:
             raise ValueError("Not support x: [Time, Channel]")
-        waveform = np.expand_dims(x, -1)
+        waveform = x.astype("float32") / 32768.0
+        waveform = np.expand_dims(waveform, -1)
         return waveform
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/train.py b/paddlespeech/s2t/exps/wav2vec2/bin/train.py
index 3ae3a9e73..29e7ef552 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/train.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/train.py
@@ -34,9 +34,10 @@ def main(config, args):
 
 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument(
+        '--resume', type=str, default="", nargs="?", help='resume ckpt path.')
     args = parser.parse_args()
     print_arguments(args, globals())
-    # https://yaml.org/type/float.html
     config = CfgNode(new_allowed=True)
     if args.config:
diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py
index 368a694c8..6a3321e46 100644
--- a/paddlespeech/s2t/exps/wav2vec2/model.py
+++ b/paddlespeech/s2t/exps/wav2vec2/model.py
@@ -15,6 +15,7 @@
 import json
 import math
 import os
+import re
 import time
 from collections import defaultdict
 from collections import OrderedDict
@@ -62,6 +63,19 @@ class Wav2Vec2ASRTrainer(Trainer):
         self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
         self.avg_train_loss += loss / (batch_index + 1)
 
+    def before_train(self):
+        from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # scratch: save init model, i.e. 0 epoch
+            self.save(tag='init', infos=None)
+        else:
+            # resume: train next_epoch and next_iteration
+            self.epoch += 1
+            logger.info(
+                f"Resume train: epoch {self.epoch}, step {self.iteration}!")
+
+        self.maybe_batch_sampler_step()
+
     def train_batch(self, batch_index, batch, msg):
         train_conf = self.config
         start = time.time()
 
         # forward
         utt, wav, wavs_lens, target, target_lens = batch
         wavs_lens_rate = wavs_lens / wav.shape[1]
-        target_lens_rate = target_lens / target.shape[1]
+
         wav = wav[:, :, 0]
-        if hasattr(train_conf, 'speech_augment'):
+        if hasattr(train_conf, 'audio_augment'):
             wav = self.speech_augmentation(wav, wavs_lens_rate)
-        loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
+
+        loss = self.model(wav, wavs_lens_rate, target, target_lens)
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-
         # update self.avg_train_loss
         self.update_average(batch_index, float(loss))
@@ -98,11 +112,17 @@
         # optimizer step old
         if (batch_index + 1) % train_conf.accum_grad == 0:
-            self.optimizer.step()
-            self.optimizer.clear_grad()
-            self.lr_scheduler.step()
+            self.model_optimizer.step()
+            self.model_optimizer.clear_grad()
+            if not train_conf.freeze_wav2vec2:
+                self.wav2vec2_optimizer.step()
+                self.wav2vec2_optimizer.clear_grad()
+            if self.config.model_scheduler != 'newbobscheduler':
+                self.model_lr_scheduler.step()
+            if self.config.wav2vec2_scheduler != 'newbobscheduler':
+                if not train_conf.freeze_wav2vec2:
+                    self.wav2vec2_lr_scheduler.step()
             self.iteration += 1
-
         losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
         iteration_time = time.time() - start
         for k, v in losses_np.items():
@@ -114,7 +134,10 @@
         if (batch_index + 1) % train_conf.accum_grad == 0:
             if dist.get_rank() == 0 and self.visualizer:
                 losses_np_v = losses_np.copy()
-                losses_np_v.update({"lr": self.lr_scheduler()})
+                losses_np_v.update({
+                    "model_lr": self.model_lr_scheduler(),
+                    "wav2vec2_lr": self.wav2vec2_lr_scheduler()
+                })
                 for key, val in losses_np_v.items():
                     self.visualizer.add_scalar(
                         tag='train/' + key, value=val, step=self.iteration - 1)
@@ -131,11 +154,10 @@
             for i, batch in enumerate(self.valid_loader):
                 utt, wav, wavs_lens, target, target_lens = batch
                 wavs_lens_rate = wavs_lens / wav.shape[1]
-                target_lens_rate = target_lens / target.shape[1]
                 wav = wav[:, :, 0]
-                loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
+                loss = self.model(wav, wavs_lens_rate, target, target_lens)
 
-                if paddle.isfinite(loss):
+                if math.isfinite(float(loss)):
                     num_utts = batch[1].shape[0]
                     num_seen_utts += num_utts
                     total_loss += float(loss) * num_utts
@@ -160,6 +182,106 @@
                             dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
+    @mp_tools.rank_zero_only
+    def save(self, tag=None, infos: dict=None):
+        """Save checkpoint (model parameters and optimizer states).
+
+        Args:
+            tag (int or str, optional): None for step, else use tag, e.g. epoch. Defaults to None.
+            infos (dict, optional): meta data to save. Defaults to None.
+        """
+        infos = infos if infos else dict()
+        infos.update({
+            "epoch": self.epoch,
+            "model_lr": self.model_optimizer.get_lr(),
+            "wav2vec2_lr": self.wav2vec2_optimizer.get_lr()
+        })
+
+        checkpoint_path = os.path.join(
+            self.checkpoint_dir,
+            "{}".format(self.iteration if tag is None else tag))
+
+        model_dict = self.model.state_dict()
+        params_path = checkpoint_path + ".pdparams"
+        paddle.save(model_dict, params_path)
+        logger.info("Saved model to {}".format(params_path))
+
+        model_opt_dict = self.model_optimizer.state_dict()
+        wav2vec2_opt_dict = self.wav2vec2_optimizer.state_dict()
+
+        opt_dict = {'model': model_opt_dict, 'wav2vec2': wav2vec2_opt_dict}
+
+        optimizer_path = checkpoint_path + ".pdopt"
+        paddle.save(opt_dict, optimizer_path)
+        logger.info("Saved optimizer state to {}".format(optimizer_path))
+
+        scheduler_dict = {}
+
+        if self.config.model_scheduler == 'newbobscheduler':
+            scheduler_dict['model'] = self.model_lr_scheduler.save()
+        if self.config.wav2vec2_scheduler == 'newbobscheduler':
+            scheduler_dict['wav2vec2'] = self.wav2vec2_lr_scheduler.save()
+        if scheduler_dict:
+            scheduler_path = checkpoint_path + ".pdlrs"
+            paddle.save(scheduler_dict, scheduler_path)
+            logger.info("Saved scheduler state to {}".format(scheduler_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        with open(info_path, 'w') as fout:
+            data = json.dumps(infos)
+            fout.write(data)
+
+    def resume_or_scratch(self):
+        """Resume from the checkpoint named by ``args.resume`` in the
+        checkpoint directory, or initialize training from scratch.
+
+        If ``args.resume`` is set, restore the model, optimizer and
+        lr_scheduler states from that checkpoint, else train from scratch.
+ """ + scratch = None + if self.args.resume: + # just restore ckpt + # lr will resotre from optimizer ckpt + resume_json_path = os.path.join(self.checkpoint_dir, + self.args.resume + '.json') + with open(resume_json_path, 'r') as f: + resume_json = json.load(f) + self.iteration = 0 + self.epoch = resume_json["epoch"] + + # resotre model from *.pdparams + params_path = os.path.join(self.checkpoint_dir, + "{}".format(self.epoch)) + '.pdparams' + model_dict = paddle.load(params_path) + self.model.set_state_dict(model_dict) + + # resotre optimizer from *.pdopt + optimizer_path = os.path.join(self.checkpoint_dir, + "{}".format(self.epoch)) + '.pdopt' + optimizer_dict = paddle.load(optimizer_path) + self.model_optimizer.set_state_dict(optimizer_dict['model']) + self.wav2vec2_optimizer.set_state_dict(optimizer_dict['wav2vec2']) + + # resotre lr_scheduler from *.pdlrs + scheduler_path = os.path.join(self.checkpoint_dir, + "{}".format(self.epoch)) + '.pdlrs' + if os.path.isfile(os.path.join(scheduler_path)): + scheduler_dict = paddle.load(scheduler_path) + if self.config.model_scheduler == 'newbobscheduler': + self.model_lr_scheduler.load(scheduler_dict['model']) + if self.config.wav2vec2_scheduler == 'newbobscheduler': + self.wav2vec2_lr_scheduler.load(scheduler_dict['wav2vec2']) + logger.info( + f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!") + scratch = False + else: + self.iteration = 0 + self.epoch = 0 + scratch = True + logger.info("Init from scratch!") + return scratch + def do_train(self): """The training process control by step.""" # !!!IMPORTANT!!! @@ -170,7 +292,6 @@ class Wav2Vec2ASRTrainer(Trainer): # paddle.jit.save(script_model, script_model_path) self.before_train() - if not self.use_streamdata: logger.info( f"Train Total Examples: {len(self.train_loader.dataset)}") @@ -187,7 +308,9 @@ class Wav2Vec2ASRTrainer(Trainer): report("Rank", dist.get_rank()) report("epoch", self.epoch) report('step', self.iteration) - report("lr", self.lr_scheduler()) + report("model_lr", self.model_optimizer.get_lr()) + report("wav2vec2_lr", + self.wav2vec2_optimizer.get_lr()) self.train_batch(batch_index, batch, msg) self.after_train_batch() report('iter', batch_index + 1) @@ -225,15 +348,25 @@ class Wav2Vec2ASRTrainer(Trainer): cv_loss = float(cv_loss) else: cv_loss = total_loss / num_seen_utts - logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) if self.visualizer: self.visualizer.add_scalar( tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar( - tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) - + tag='eval/model_lr', + value=self.model_lr_scheduler(), + step=self.epoch) + self.visualizer.add_scalar( + tag='eval/wav2vec2_lr', + value=self.wav2vec2_lr_scheduler(), + step=self.epoch) + + if self.config.model_scheduler == 'newbobscheduler': + self.model_lr_scheduler.step(cv_loss) + if self.config.wav2vec2_scheduler == 'newbobscheduler': + if not self.config.freeze_wav2vec2: + self.wav2vec2_lr_scheduler.step(cv_loss) self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.new_epoch() @@ -268,14 +401,11 @@ class Wav2Vec2ASRTrainer(Trainer): model_conf.output_dim = self.test_loader.vocab_size model = Wav2vec2ASR.from_config(model_conf) - - # load pretrained wav2vec2 model params - wav2vec2_dict = paddle.load(config.wav2vec2_params_path) - model.wav2vec2.set_state_dict(wav2vec2_dict) + model_dict = paddle.load(config.wav2vec2_params_path) + model.wav2vec2.set_state_dict(model_dict) if self.parallel: model = 
paddle.DataParallel(model, find_unused_parameters=True) - logger.info(f"{model}") layer_tools.print_params(model, logger.info) self.model = model @@ -290,46 +420,74 @@ class Wav2Vec2ASRTrainer(Trainer): return train_config = config - optim_type = train_config.model_optim - optim_conf = train_config.model_optim_conf - scheduler_type = train_config.scheduler - scheduler_conf = train_config.scheduler_conf - - scheduler_args = { - "learning_rate": optim_conf.lr, - "verbose": False, - "warmup_steps": scheduler_conf.warmup_steps, - "gamma": scheduler_conf.lr_decay, - "d_model": model_conf.dnn_neurons, - } - lr_scheduler = LRSchedulerFactory.from_args(scheduler_type, - scheduler_args) + model_optim_type = train_config.model_optim + model_optim_conf = train_config.model_optim_conf + wav2vec2_optim_type = train_config.model_optim + wav2vec2_optim_conf = train_config.wav2vec2_optim_conf + + model_scheduler_type = train_config.model_scheduler + model_scheduler_conf = train_config.model_scheduler_conf + wav2vec2_scheduler_type = train_config.wav2vec2_scheduler + wav2vec2_scheduler_conf = train_config.wav2vec2_scheduler_conf + + model_scheduler_args = dict( + **{"learning_rate": model_optim_conf.lr, + "verbose": False}, **(dict(model_scheduler_conf))) + + wav2vec2_scheduler_args = dict( + **{"learning_rate": wav2vec2_optim_conf.lr, + "verbose": False}, **(dict(wav2vec2_scheduler_conf))) + + model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type, + model_scheduler_args) + wav2vec2_lr_scheduler = LRSchedulerFactory.from_args( + wav2vec2_scheduler_type, wav2vec2_scheduler_args) def optimizer_args( config, + optim_type, + optim_conf, parameters, lr_scheduler=None, ): train_config = config - optim_type = train_config.model_optim - optim_conf = train_config.model_optim_conf - scheduler_type = train_config.scheduler - scheduler_conf = train_config.scheduler_conf - return { - "grad_clip": train_config.global_grad_clip, - "learning_rate": lr_scheduler - if lr_scheduler else optim_conf.lr, - "epsilon": optim_conf.epsilon, - "rho": optim_conf.rho, - "parameters": parameters, - "beta1": 0.9 if optim_type == 'noam' else None, - "beat2": 0.98 if optim_type == 'noam' else None, - } - - optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler) - optimizer = OptimizerFactory.from_args(optim_type, optimzer_args) - - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler + optim_arg = dict(optim_conf) + optim_arg.update({ + "grad_clip": + train_config.global_grad_clip, + "learning_rate": + lr_scheduler if lr_scheduler else optim_conf.lr, + "parameters": + parameters + }) + return optim_arg + + model_optimizer_args = optimizer_args(config, model_optim_type, + model_optim_conf, [{ + 'params': + model._layers.enc.parameters() + }, { + 'params': + model._layers.ctc.parameters() + }] if self.parallel else [{ + 'params': + model.enc.parameters() + }, { + 'params': + model.ctc.parameters() + }], model_lr_scheduler) + wav2vec2_optimizer_args = optimizer_args( + config, wav2vec2_optim_type, wav2vec2_optim_conf, + model._layers.wav2vec2.parameters() if self.parallel else + model.wav2vec2.parameters(), wav2vec2_lr_scheduler) + model_optimizer = OptimizerFactory.from_args(model_optim_type, + model_optimizer_args) + wav2vec2_optimizer = OptimizerFactory.from_args(wav2vec2_optim_type, + wav2vec2_optimizer_args) + + self.model_optimizer = model_optimizer + self.wav2vec2_optimizer = wav2vec2_optimizer + self.model_lr_scheduler = model_lr_scheduler + self.wav2vec2_lr_scheduler = wav2vec2_lr_scheduler 
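
(The two optimizer/scheduler pairs built above are driven by the `model_optim*` and `wav2vec2_optim*` keys added to conf/wav2vec2ASR.yaml. A minimal standalone sketch of the dual-optimizer pattern — the two Linear layers are hypothetical stand-ins for the enc/ctc head and the wav2vec2 encoder:)

    import paddle

    # Stand-ins for the two parameter groups: the small task head (enc + ctc)
    # always trains; the large pretrained encoder may be frozen.
    head = paddle.nn.Linear(1024, 32)
    encoder = paddle.nn.Linear(512, 1024)

    # One Adadelta optimizer per group, mirroring model_optim_conf and
    # wav2vec2_optim_conf (lr: 0.9, epsilon: 1.0e-6, rho: 0.95).
    head_optimizer = paddle.optimizer.Adadelta(
        learning_rate=0.9, epsilon=1.0e-6, rho=0.95,
        parameters=head.parameters())
    encoder_optimizer = paddle.optimizer.Adadelta(
        learning_rate=0.9, epsilon=1.0e-6, rho=0.95,
        parameters=encoder.parameters())

    # As in train_batch above: the head optimizer always steps, while the
    # encoder optimizer only steps when freeze_wav2vec2 is False.
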
logger.info("Setup optimizer/lr_scheduler!") diff --git a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py index cfd8f507e..f47f2b0ab 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py @@ -1,23 +1,12 @@ -# Authors -# * Elena Rastorgueva 2020 -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/lobes/models/VanillaNN.py). +"""Vanilla Neural Network for simple tests. +Authors +* Elena Rastorgueva 2020 +""" import paddle from paddlespeech.s2t.models.wav2vec2.modules import containers from paddlespeech.s2t.models.wav2vec2.modules import linear +from paddlespeech.s2t.models.wav2vec2.modules.normalization import BatchNorm1d class VanillaNN(containers.Sequential): @@ -39,18 +28,34 @@ class VanillaNN(containers.Sequential): paddle.shape([10, 120, 512]) """ - def __init__( - self, - input_shape, - activation=paddle.nn.LeakyReLU, - dnn_blocks=2, - dnn_neurons=512, ): - super().__init__(input_shape=input_shape) + def __init__(self, + input_shape, + dnn_blocks=2, + dnn_neurons=512, + activation=True, + normalization=False, + dropout_rate=0.0): + super().__init__(input_shape=[None, None, input_shape]) + + if not isinstance(dropout_rate, list): + dropout_rate = [dropout_rate] * dnn_blocks + else: + assert len( + dropout_rate + ) == dnn_blocks, "len(dropout_rate) must equal to dnn_blocks" for block_index in range(dnn_blocks): self.append( linear.Linear, n_neurons=dnn_neurons, - bias=True, + bias_attr=None, layer_name="linear", ) - self.append(activation(), layer_name="act") + if normalization: + self.append( + BatchNorm1d, input_size=dnn_neurons, layer_name='bn') + if activation: + self.append(paddle.nn.LeakyReLU(), layer_name="act") + self.append( + paddle.nn.Dropout(), + p=dropout_rate[block_index], + layer_name='dropout') diff --git a/paddlespeech/s2t/models/wav2vec2/modules/containers.py b/paddlespeech/s2t/models/wav2vec2/modules/containers.py index 180d0bd32..6a6b94e95 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/containers.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/containers.py @@ -141,5 +141,4 @@ class Sequential(paddle.nn.LayerDict): x = layer(x) if isinstance(x, tuple): x = x[0] - return x diff --git a/paddlespeech/s2t/models/wav2vec2/modules/linear.py b/paddlespeech/s2t/models/wav2vec2/modules/linear.py index adae4514a..0b40f4362 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/linear.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/linear.py @@ -1,20 +1,8 @@ -# Authors -# * Mirco Ravanelli 2020 -# * Davide Borra 2021 -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/linear.py).
+"""Library implementing linear transformation.
+Authors
+ * Mirco Ravanelli 2020
+ * Davide Borra 2021
+"""
 import logging
 
 import paddle
@@ -53,7 +41,7 @@
             n_neurons,
             input_shape=None,
             input_size=None,
-            bias=True,
+            bias_attr=None,
             combine_dims=False, ):
         super().__init__()
         self.combine_dims = combine_dims
@@ -67,7 +55,7 @@
             input_size = input_shape[2] * input_shape[3]
 
         # Weights are initialized following paddle approach
-        self.w = align.Linear(input_size, n_neurons, bias_attr=bias)
+        self.w = align.Linear(input_size, n_neurons, bias_attr=bias_attr)
 
     def forward(self, x):
         """Returns the linear transformation of input tensor.
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
index e484fff68..5670cb531 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
@@ -1120,9 +1120,6 @@ class Wav2Vec2ConfigPure():
         self.output_hidden_states = False
         self.use_return_dict = True
 
-        self.pad_token_id = config.pad_token_id
-        self.bos_token_id = config.bos_token_id
-        self.eos_token_id = config.eos_token_id
         self.hidden_size = config.hidden_size
         self.feat_extract_norm = config.feat_extract_norm
         self.feat_extract_activation = config.feat_extract_activation
@@ -1145,7 +1142,6 @@ class Wav2Vec2ConfigPure():
         self.layerdrop = config.layerdrop
         self.layer_norm_eps = config.layer_norm_eps
         self.initializer_range = config.initializer_range
-        self.vocab_size = config.vocab_size
         self.do_stable_layer_norm = config.do_stable_layer_norm
         self.use_weighted_layer_sum = config.use_weighted_layer_sum
diff --git a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
index ac9bf45db..9224549a4 100644
--- a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
+++ b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
@@ -639,6 +639,170 @@ class DropChunk(nn.Layer):
         return dropped_waveform
 
 
+class SpecAugment(paddle.nn.Layer):
+    """An implementation of the SpecAugment algorithm.
+    Reference:
+        https://arxiv.org/abs/1904.08779
+    Arguments
+    ---------
+    time_warp : bool
+        Whether to apply time warping.
+    time_warp_window : int
+        Time warp window.
+    time_warp_mode : str
+        Interpolation mode for time warping (default "bicubic").
+    freq_mask : bool
+        Whether to apply frequency masking.
+    freq_mask_width : int or tuple
+        Frequency mask width range.
+    n_freq_mask : int
+        Number of frequency masks.
+    time_mask : bool
+        Whether to apply time masking.
+    time_mask_width : int or tuple
+        Time mask width range.
+    n_time_mask : int
+        Number of time masks.
+    replace_with_zero : bool
+        If True, replace masked values with 0, else replace them with the mean of the input tensor.
+    Example
+    -------
+    >>> aug = SpecAugment()
+    >>> a = paddle.rand([8, 120, 80])
+    >>> a = aug(a)
+    >>> print(a.shape)
+    [8, 120, 80]
+    """
+
+    def __init__(
+            self,
+            time_warp=True,
+            time_warp_window=5,
+            time_warp_mode="bicubic",
+            freq_mask=True,
+            freq_mask_width=(0, 20),
+            n_freq_mask=2,
+            time_mask=True,
+            time_mask_width=(0, 100),
+            n_time_mask=2,
+            replace_with_zero=True, ):
+        super().__init__()
+        assert (
+            time_warp or freq_mask or time_mask
+        ), "at least one of time_warp, time_mask, or freq_mask should be applied"
+
+        self.apply_time_warp = time_warp
+        self.time_warp_window = time_warp_window
+        self.time_warp_mode = time_warp_mode
+
+        self.freq_mask = freq_mask
+        if isinstance(freq_mask_width, int):
+            freq_mask_width = (0, freq_mask_width)
+        self.freq_mask_width = freq_mask_width
+        self.n_freq_mask = n_freq_mask
+
+        self.time_mask = time_mask
+        if isinstance(time_mask_width, int):
+            time_mask_width = (0, time_mask_width)
+        self.time_mask_width = time_mask_width
+        self.n_time_mask = n_time_mask
+
+        self.replace_with_zero = replace_with_zero
+
+    def forward(self, x):
+        """Takes in input a tensor and returns an augmented one."""
+        if self.apply_time_warp:
+            x = self.time_warp(x)
+        if self.freq_mask:
+            x = self.mask_along_axis(x, dim=2)
+        if self.time_mask:
+            x = self.mask_along_axis(x, dim=1)
+        return x
+
+    def time_warp(self, x):
+        """Time warping with paddle.nn.functional.interpolate"""
+        original_size = x.shape
+        window = self.time_warp_window
+
+        # 2d interpolation requires 4D or higher dimension tensors
+        # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
+        if x.dim() == 3:
+            x = x.unsqueeze(1)
+
+        time = x.shape[2]
+        if time - window <= window:
+            return x.view(*original_size)
+
+        # compute center and corresponding window
+        c = paddle.randint(window, time - window, (1, ))[0]
+        w = paddle.randint(c - window, c + window, (1, ))[0] + 1
+        left = paddle.nn.functional.interpolate(
+            x[:, :, :c],
+            (w, x.shape[3]),
+            mode=self.time_warp_mode,
+            align_corners=True, )
+        right = paddle.nn.functional.interpolate(
+            x[:, :, c:],
+            (time - w, x.shape[3]),
+            mode=self.time_warp_mode,
+            align_corners=True, )
+
+        x[:, :, :w] = left
+        x[:, :, w:] = right
+        return x.view(*original_size)
+
+    def mask_along_axis(self, x, dim):
+        """Mask along time or frequency axis.
+        Arguments
+        ---------
+        x : tensor
+            Input tensor.
+        dim : int
+            Corresponding dimension to mask.
+        """
+        original_size = x.shape
+        if x.dim() == 4:
+            x = x.view(-1, x.shape[2], x.shape[3])
+
+        batch, time, fea = x.shape
+
+        if dim == 1:
+            D = time
+            n_mask = self.n_time_mask
+            width_range = self.time_mask_width
+        else:
+            D = fea
+            n_mask = self.n_freq_mask
+            width_range = self.freq_mask_width
+
+        mask_len = paddle.randint(width_range[0], width_range[1],
+                                  (batch, n_mask)).unsqueeze(2)
+
+        mask_pos = paddle.randint(0, max(1, D - mask_len.max()),
+                                  (batch, n_mask)).unsqueeze(2)
+
+        # compute masks
+        arange = paddle.arange(end=D).view(1, 1, -1)
+        mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
+        mask = mask.any(axis=1)
+
+        if dim == 1:
+            mask = mask.unsqueeze(2)
+        else:
+            mask = mask.unsqueeze(1)
+
+        if self.replace_with_zero:
+            val = 0.0
+        else:
+            val = x.mean()
+        # same as x.masked_fill_(mask, val)
+        y = paddle.full(x.shape, val, x.dtype)
+        x = paddle.where(mask, y, x)
+        return x.view(*original_size)
+
+
 class TimeDomainSpecAugment(nn.Layer):
     """A time-domain approximation of the SpecAugment algorithm.
    This augmentation module implements three augmentations in
diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
index 09b254c51..eda188da5 100644
--- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
+++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
@@ -23,7 +23,9 @@
 import paddle.nn.functional as F
 
 from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure
 from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model
 from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
+from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
 from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
+from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
 from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
 from paddlespeech.s2t.utils.utility import log_add
@@ -31,42 +33,41 @@ from paddlespeech.s2t.utils.utility import log_add
 
 
 class Wav2vec2ASR(nn.Layer):
     def __init__(self, config: dict):
         super().__init__()
+        init_type = config.get("init_type", None)
+        with DefaultInitializerContext(init_type):
+            self.config = config
+            wav2vec2_config = Wav2Vec2ConfigPure(config)
+            wav2vec2 = Wav2Vec2Model(wav2vec2_config)
+            self.normalize_wav = config.normalize_wav
+            self.output_norm = config.output_norm
+            if hasattr(config, 'spec_augment'):
+                self.spec_augment = SpecAugment(**config.spec_augment)
 
-        wav2vec2_config = Wav2Vec2ConfigPure(config)
-        wav2vec2 = Wav2Vec2Model(wav2vec2_config)
-        self.normalize_wav = config.normalize_wav
-        self.output_norm = config.output_norm
-        if config.freeze_wav2vec2:
-            wav2vec2.eval()
-            for parm in wav2vec2.parameters():
-                parm.trainable = False
-        self.wav2vec2 = wav2vec2
-        self.enc = VanillaNN(
-            input_shape=[None, None, wav2vec2_config.hidden_size],
-            activation=nn.LeakyReLU,
-            dnn_blocks=config.dnn_blocks,
-            dnn_neurons=config.dnn_neurons)
-        self.ctc = CTC(odim=config.output_dim,
-                       enc_n_units=config.dnn_neurons,
-                       blank_id=config.blank_id,
-                       dropout_rate=config.ctc_dropout_rate,
-                       reduction='mean')
+            if config.freeze_wav2vec2:
+                wav2vec2.eval()
+                for parm in wav2vec2.parameters():
+                    parm.trainable = False
+            self.wav2vec2 = wav2vec2
+            self.enc = VanillaNN(**config.enc)
+            self.ctc = CTC(**config.ctc,
+                           odim=config.output_dim,
+                           batch_average=False,
+                           reduction='mean')
 
-    def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
+    def forward(self, wav, wavs_lens_rate, target, target_lens):
         if self.normalize_wav:
-            wav = F.layer_norm(wav, wav.shape[1:])
+            wav = F.layer_norm(wav, wav.shape)
 
         # Extract wav2vec output
         out = self.wav2vec2(wav)[0]
         # We normalize the output if required
         if self.output_norm:
-            out = F.layer_norm(out, out.shape[1:])
-        feats = out
-
+            out = F.layer_norm(out, out.shape)
+        if self.training and hasattr(self.config, 'spec_augment'):
+            feats = self.spec_augment(out)
+        else:
+            feats = out
         x = self.enc(feats)
         x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
-        target_lens = (target_lens_rate *
-                       target.shape[1]).round().astype(paddle.int64)
-
         ctc_loss = self.ctc(x, x_lens, target, target_lens)
         return ctc_loss
diff --git a/paddlespeech/s2t/training/scheduler.py b/paddlespeech/s2t/training/scheduler.py
index b22f7ef85..53c756ce3 100644
--- a/paddlespeech/s2t/training/scheduler.py
+++ b/paddlespeech/s2t/training/scheduler.py
@@ -17,6 +17,7 @@
 from typing import Dict
 from typing import Text
 from typing import Union
 
+import paddle
 from paddle.optimizer.lr import LRScheduler
 from typeguard import check_argument_types
@@ -107,6 +108,125 @@ class ConstantLR(LRScheduler):
         return self.base_lr
 
 
+@register_scheduler
+class NewBobScheduler(LRScheduler):
+    """Scheduler with new-bob technique, used for LR annealing.
+
+    The learning rate is annealed based on the validation performance.
+    In particular: if (past_loss - current_loss) / past_loss < impr_threshold:
+    lr = lr * annealing_factor.
+
+    Arguments
+    ---------
+    learning_rate : float
+        The initial learning rate.
+    annealing_factor : float
+        The annealing factor used in the new-bob strategy.
+    improvement_threshold : float
+        The improvement rate between losses used to perform learning
+        annealing in the new-bob strategy.
+    patient : int
+        When the annealing condition is violated ``patient`` times,
+        the learning rate is finally reduced.
+
+    Example
+    -------
+    >>> scheduler = NewBobScheduler(learning_rate=1.0)
+    >>> scheduler.get_lr(metric_value=10.0)
+    1.0
+    >>> scheduler.get_lr(metric_value=2.0)
+    1.0
+    >>> scheduler.get_lr(metric_value=2.5)
+    0.5
+    """
+
+    def __init__(
+            self,
+            learning_rate,
+            last_epoch=-1,
+            verbose=False,
+            annealing_factor=0.5,
+            improvement_threshold=0.0025,
+            patient=0, ):
+        self.hyperparam_value = learning_rate
+        self.annealing_factor = annealing_factor
+        self.improvement_threshold = improvement_threshold
+        self.patient = patient
+        self.metric_values = []
+        self.current_patient = self.patient
+        super().__init__(learning_rate, last_epoch, verbose)
+
+    def step(self, metric_value=None):
+        """``step`` should be called after ``optimizer.step``. It updates the
+        learning rate from the latest validation metric; the new learning
+        rate takes effect on the next ``optimizer.step``.
+
+        Args:
+            metric_value (float, None): the latest validation loss. If None,
+                the learning rate is left unchanged.
+
+        Returns:
+            None
+        """
+        if metric_value is None:
+            self.last_epoch += 1
+            self.last_lr = self.hyperparam_value
+        else:
+            self.last_epoch += 1
+            self.last_lr = self.get_lr(metric_value)
+
+        if self.verbose:
+            print('Epoch {}: {} set learning rate to {}.'.format(
+                self.last_epoch, self.__class__.__name__, self.last_lr))
+
+    def get_lr(self, metric_value):
+        """Returns the new value for the learning rate.
+
+        Arguments
+        ---------
+        metric_value : float
+            A number for determining whether to change the learning rate.
+ """ + new_value = self.hyperparam_value + if len(self.metric_values) > 0: + prev_metric = self.metric_values[-1] + # Update value if improvement too small and patience is 0 + if prev_metric == 0: # Prevent division by zero + improvement = 0 + else: + improvement = (prev_metric - metric_value) / prev_metric + if improvement < self.improvement_threshold: + if self.current_patient == 0: + new_value *= self.annealing_factor + self.current_patient = self.patient + else: + self.current_patient -= 1 + + # Store relevant info + self.metric_values.append(metric_value) + self.hyperparam_value = new_value + + return new_value + + def save(self): + """Saves the current metrics on the specified path.""" + data = { + "current_epoch_index": self.last_epoch, + "hyperparam_value": self.hyperparam_value, + "metric_values": self.metric_values, + "current_patient": self.current_patient + } + return data + + def load(self, data): + """Loads the needed information.""" + data = paddle.load(data) + self.last_epoch = data["current_epoch_index"] + self.hyperparam_value = data["hyperparam_value"] + self.metric_values = data["metric_values"] + self.current_patient = data["current_patient"] + + def dynamic_import_scheduler(module): """Import Scheduler class dynamically.