wav2vec2 demo update: support different optimizer and lr_schedular, align mdoel, update input type, test=asr

pull/2658/head
tianhao zhang 3 years ago
parent 4cdfa5ccfd
commit 4f6b076a0a

@ -1,12 +1,10 @@
#!/bin/bash
# audio download
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
# to recognize text
paddlespeech ssl --task asr --lang en --input ./en.wav
# to get acoustic representation
paddlespeech ssl --task vector --lang en --input ./en.wav
README_cn

@ -1,4 +1,3 @@
process:
# use raw audio
- type: wav_process
dither: 0.0

@ -4,16 +4,21 @@
freeze_wav2vec2: True
normalize_wav: True
output_norm: True
dnn_blocks: 2
dnn_neurons: 1024
blank_id: 0
ctc_dropout_rate: 0.0
init_type: 'kaiming_uniform' # !Warning: need to convergence
enc:
input_shape: 1024
dnn_blocks: 2
dnn_neurons: 1024
activation: True
ctc:
enc_n_units: 1024
blank_id: 0
dropout_rate: 0.0
wav2vec2_params_path: "exp/wav2vec2/wav2vec2-large-960h-lv60-self.pdparams"
############################################
# Wav2Vec2.0 #
############################################
vocab_size: 32
hidden_size: 1024
num_hidden_layers: 24
num_attention_heads: 16
@ -54,9 +59,6 @@ diversity_loss_weight: 0.1
ctc_loss_reduction: "sum"
ctc_zero_infinity: False
use_weighted_layer_sum: False
pad_token_id: 0
bos_token_id: 1
eos_token_id: 2
add_adapter: False
adapter_kernel_size: 3
adapter_stride: 2
@ -78,7 +80,7 @@ unit_type: 'char'
mean_std_filepath: ""
preprocess_config: conf/preprocess.yaml
sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for 'other' epochs
batch_size: 10 # Different batch_size may cause large differences in results
batch_size: 6 # Different batch_size may cause large differences in results
maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced
maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced
minibatches: 0 # for debug
@ -112,8 +114,17 @@ model_optim_conf:
lr: 0.9
epsilon: 1.0e-6
rho: 0.95
scheduler: constantlr
scheduler_conf:
model_scheduler: constantlr
model_scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
wav2vec2_optim: adadelta
wav2vec2_optim_conf:
lr: 0.9
epsilon: 1.0e-6
rho: 0.95
wav2vec2_scheduler: constantlr
wav2vec2_scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1

@ -10,7 +10,8 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
ips=$3
resume=$3
ips=$4
if [ ! $ips ];then
ips_config=
@ -21,7 +22,7 @@ fi
mkdir -p exp
# seed may break model convergence
seed=1998
seed=1988
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi
@ -34,13 +35,15 @@ python3 -u ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--seed ${seed}
--seed ${seed} \
--resume ${resume}
else
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--seed ${seed}
--seed ${seed} \
--resume ${resume}
fi
if [ ${seed} != 0 ]; then

@ -11,7 +11,7 @@ conf_path=conf/wav2vec2ASR.yaml
ips= #xx.xx.xx.xx,xx.xx.xx.xx
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
dict_path=data/lang_char/vocab.txt
resume= # xx e.g. 30
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@ -28,7 +28,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@ -38,10 +38,10 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# greedy search decoder
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# test a single .wav file
CUDA_VISIBLE_DEVICES=${gpus} ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
fi

@ -383,7 +383,7 @@ class LogMelSpectrogramKaldi():
class WavProcess():
def __init__(self, dither=0.0):
def __init__(self):
"""
Args:
dither (float): Dithering constant
@ -391,9 +391,7 @@ class WavProcess():
Returns:
"""
self.dither = dither
def __call__(self, x, train):
def __call__(self, x):
"""
Args:
x (np.ndarray): shape (Ti,)
@ -405,10 +403,10 @@ class WavProcess():
Returns:
np.ndarray: (T, D)
"""
dither = self.dither if train else 0.0
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
waveform = np.expand_dims(x, -1)
waveform = x.astype("float32") / 32768.0
waveform = np.expand_dims(waveform, -1)
return waveform

@ -34,9 +34,10 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument(
'--resume', type=str, default="", nargs="?", help='resume ckpt path.')
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:

@ -15,6 +15,7 @@
import json
import math
import os
import re
import time
from collections import defaultdict
from collections import OrderedDict
@ -62,6 +63,19 @@ class Wav2Vec2ASRTrainer(Trainer):
self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
self.avg_train_loss += loss / (batch_index + 1)
def before_train(self):
from_scratch = self.resume_or_scratch()
if from_scratch:
# scratch: save init model, i.e. 0 epoch
self.save(tag='init', infos=None)
else:
# resume: train next_epoch and next_iteration
self.epoch += 1
logger.info(
f"Resume train: epoch {self.epoch }, step {self.iteration}!")
self.maybe_batch_sampler_step()
def train_batch(self, batch_index, batch, msg):
train_conf = self.config
start = time.time()
@ -69,14 +83,14 @@ class Wav2Vec2ASRTrainer(Trainer):
# forward
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
target_lens_rate = target_lens / target.shape[1]
wav = wav[:, :, 0]
if hasattr(train_conf, 'speech_augment'):
if hasattr(train_conf, 'audio_augment'):
wav = self.speech_augmentation(wav, wavs_lens_rate)
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
loss = self.model(wav, wavs_lens_rate, target, target_lens)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
# update self.avg_train_loss
self.update_average(batch_index, float(loss))
@ -98,11 +112,17 @@ class Wav2Vec2ASRTrainer(Trainer):
# optimizer step old
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.lr_scheduler.step()
self.model_optimizer.step()
self.model_optimizer.clear_grad()
if not train_conf.freeze_wav2vec2:
self.wav2vec2_optimizer.step()
self.wav2vec2_optimizer.clear_grad()
if self.config.model_scheduler != 'newbobscheduler':
self.model_lr_scheduler.step()
if self.config.wav2vec2_scheduler != 'newbobscheduler':
if not train_conf.freeze_wav2vec2:
self.wav2vec2_lr_scheduler.step()
self.iteration += 1
losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
iteration_time = time.time() - start
for k, v in losses_np.items():
@ -114,7 +134,10 @@ class Wav2Vec2ASRTrainer(Trainer):
if (batch_index + 1) % train_conf.accum_grad == 0:
if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
losses_np_v.update({
"model_lr": self.model_lr_scheduler(),
"wav2vec2_lr": self.wav2vec2_lr_scheduler()
})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)
@ -131,11 +154,10 @@ class Wav2Vec2ASRTrainer(Trainer):
for i, batch in enumerate(self.valid_loader):
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
target_lens_rate = target_lens / target.shape[1]
wav = wav[:, :, 0]
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
loss = self.model(wav, wavs_lens_rate, target, target_lens)
if paddle.isfinite(loss):
if math.isfinite(float(loss)):
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
@ -160,6 +182,106 @@ class Wav2Vec2ASRTrainer(Trainer):
dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts
@mp_tools.rank_zero_only
def save(self, tag=None, infos: dict=None):
"""Save checkpoint (model parameters and optimizer states).
Args:
tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None.
infos (dict, optional): meta data to save. Defaults to None.
"""
infos = infos if infos else dict()
infos.update({
"epoch": self.epoch,
"model_lr": self.model_optimizer.get_lr(),
"wav2vec2_lr": self.wav2vec2_optimizer.get_lr()
})
checkpoint_path = os.path.join(
self.checkpoint_dir,
"{}".format(self.iteration if tag is None else tag))
model_dict = self.model.state_dict()
params_path = checkpoint_path + ".pdparams"
paddle.save(model_dict, params_path)
logger.info("Saved model to {}".format(params_path))
model_opt_dict = self.model_optimizer.state_dict()
wav2vec2_opt_dict = self.wav2vec2_optimizer.state_dict()
opt_dict = {'model': model_opt_dict, 'wav2vec2': wav2vec2_opt_dict}
optimizer_path = checkpoint_path + ".pdopt"
paddle.save(opt_dict, optimizer_path)
logger.info("Saved optimzier state to {}".format(optimizer_path))
scheduler_dict = {}
if self.config.model_scheduler == 'newbobscheduler':
scheduler_dict['model'] = self.model_lr_scheduler.save()
if self.config.wav2vec2_scheduler == 'newbobscheduler':
scheduler_dict['wav2vec2'] = self.wav2vec2_lr_scheduler.save()
if scheduler_dict:
scheduler_path = checkpoint_path + ".pdlrs"
paddle.save(scheduler_dict, scheduler_path)
logger.info("Saved scheduler state to {}".format(scheduler_path))
info_path = re.sub('.pdparams$', '.json', params_path)
infos = {} if infos is None else infos
with open(info_path, 'w') as fout:
data = json.dumps(infos)
fout.write(data)
def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output
directory or load a specified checkpoint.
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
scratch = None
if self.args.resume:
# just restore ckpt
# lr will resotre from optimizer ckpt
resume_json_path = os.path.join(self.checkpoint_dir,
self.args.resume + '.json')
with open(resume_json_path, 'r') as f:
resume_json = json.load(f)
self.iteration = 0
self.epoch = resume_json["epoch"]
# resotre model from *.pdparams
params_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.pdparams'
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict)
# resotre optimizer from *.pdopt
optimizer_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.pdopt'
optimizer_dict = paddle.load(optimizer_path)
self.model_optimizer.set_state_dict(optimizer_dict['model'])
self.wav2vec2_optimizer.set_state_dict(optimizer_dict['wav2vec2'])
# resotre lr_scheduler from *.pdlrs
scheduler_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.pdlrs'
if os.path.isfile(os.path.join(scheduler_path)):
scheduler_dict = paddle.load(scheduler_path)
if self.config.model_scheduler == 'newbobscheduler':
self.model_lr_scheduler.load(scheduler_dict['model'])
if self.config.wav2vec2_scheduler == 'newbobscheduler':
self.wav2vec2_lr_scheduler.load(scheduler_dict['wav2vec2'])
logger.info(
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
scratch = False
else:
self.iteration = 0
self.epoch = 0
scratch = True
logger.info("Init from scratch!")
return scratch
def do_train(self):
"""The training process control by step."""
# !!!IMPORTANT!!!
@ -170,7 +292,6 @@ class Wav2Vec2ASRTrainer(Trainer):
# paddle.jit.save(script_model, script_model_path)
self.before_train()
if not self.use_streamdata:
logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
@ -187,7 +308,9 @@ class Wav2Vec2ASRTrainer(Trainer):
report("Rank", dist.get_rank())
report("epoch", self.epoch)
report('step', self.iteration)
report("lr", self.lr_scheduler())
report("model_lr", self.model_optimizer.get_lr())
report("wav2vec2_lr",
self.wav2vec2_optimizer.get_lr())
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('iter', batch_index + 1)
@ -225,15 +348,25 @@ class Wav2Vec2ASRTrainer(Trainer):
cv_loss = float(cv_loss)
else:
cv_loss = total_loss / num_seen_utts
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
tag='eval/model_lr',
value=self.model_lr_scheduler(),
step=self.epoch)
self.visualizer.add_scalar(
tag='eval/wav2vec2_lr',
value=self.wav2vec2_lr_scheduler(),
step=self.epoch)
if self.config.model_scheduler == 'newbobscheduler':
self.model_lr_scheduler.step(cv_loss)
if self.config.wav2vec2_scheduler == 'newbobscheduler':
if not self.config.freeze_wav2vec2:
self.wav2vec2_lr_scheduler.step(cv_loss)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()
@ -268,14 +401,11 @@ class Wav2Vec2ASRTrainer(Trainer):
model_conf.output_dim = self.test_loader.vocab_size
model = Wav2vec2ASR.from_config(model_conf)
# load pretrained wav2vec2 model params
wav2vec2_dict = paddle.load(config.wav2vec2_params_path)
model.wav2vec2.set_state_dict(wav2vec2_dict)
model_dict = paddle.load(config.wav2vec2_params_path)
model.wav2vec2.set_state_dict(model_dict)
if self.parallel:
model = paddle.DataParallel(model, find_unused_parameters=True)
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
self.model = model
@ -290,46 +420,74 @@ class Wav2Vec2ASRTrainer(Trainer):
return
train_config = config
optim_type = train_config.model_optim
optim_conf = train_config.model_optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
scheduler_args = {
"learning_rate": optim_conf.lr,
"verbose": False,
"warmup_steps": scheduler_conf.warmup_steps,
"gamma": scheduler_conf.lr_decay,
"d_model": model_conf.dnn_neurons,
}
lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
scheduler_args)
model_optim_type = train_config.model_optim
model_optim_conf = train_config.model_optim_conf
wav2vec2_optim_type = train_config.model_optim
wav2vec2_optim_conf = train_config.wav2vec2_optim_conf
model_scheduler_type = train_config.model_scheduler
model_scheduler_conf = train_config.model_scheduler_conf
wav2vec2_scheduler_type = train_config.wav2vec2_scheduler
wav2vec2_scheduler_conf = train_config.wav2vec2_scheduler_conf
model_scheduler_args = dict(
**{"learning_rate": model_optim_conf.lr,
"verbose": False}, **(dict(model_scheduler_conf)))
wav2vec2_scheduler_args = dict(
**{"learning_rate": wav2vec2_optim_conf.lr,
"verbose": False}, **(dict(wav2vec2_scheduler_conf)))
model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
model_scheduler_args)
wav2vec2_lr_scheduler = LRSchedulerFactory.from_args(
wav2vec2_scheduler_type, wav2vec2_scheduler_args)
def optimizer_args(
config,
optim_type,
optim_conf,
parameters,
lr_scheduler=None, ):
train_config = config
optim_type = train_config.model_optim
optim_conf = train_config.model_optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
return {
"grad_clip": train_config.global_grad_clip,
"learning_rate": lr_scheduler
if lr_scheduler else optim_conf.lr,
"epsilon": optim_conf.epsilon,
"rho": optim_conf.rho,
"parameters": parameters,
"beta1": 0.9 if optim_type == 'noam' else None,
"beat2": 0.98 if optim_type == 'noam' else None,
}
optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
optim_arg = dict(optim_conf)
optim_arg.update({
"grad_clip":
train_config.global_grad_clip,
"learning_rate":
lr_scheduler if lr_scheduler else optim_conf.lr,
"parameters":
parameters
})
return optim_arg
model_optimizer_args = optimizer_args(config, model_optim_type,
model_optim_conf, [{
'params':
model._layers.enc.parameters()
}, {
'params':
model._layers.ctc.parameters()
}] if self.parallel else [{
'params':
model.enc.parameters()
}, {
'params':
model.ctc.parameters()
}], model_lr_scheduler)
wav2vec2_optimizer_args = optimizer_args(
config, wav2vec2_optim_type, wav2vec2_optim_conf,
model._layers.wav2vec2.parameters() if self.parallel else
model.wav2vec2.parameters(), wav2vec2_lr_scheduler)
model_optimizer = OptimizerFactory.from_args(model_optim_type,
model_optimizer_args)
wav2vec2_optimizer = OptimizerFactory.from_args(wav2vec2_optim_type,
wav2vec2_optimizer_args)
self.model_optimizer = model_optimizer
self.wav2vec2_optimizer = wav2vec2_optimizer
self.model_lr_scheduler = model_lr_scheduler
self.wav2vec2_lr_scheduler = wav2vec2_lr_scheduler
logger.info("Setup optimizer/lr_scheduler!")

@ -1,23 +1,12 @@
# Authors
# * Elena Rastorgueva 2020
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/lobes/models/VanillaNN.py).
"""Vanilla Neural Network for simple tests.
Authors
* Elena Rastorgueva 2020
"""
import paddle
from paddlespeech.s2t.models.wav2vec2.modules import containers
from paddlespeech.s2t.models.wav2vec2.modules import linear
from paddlespeech.s2t.models.wav2vec2.modules.normalization import BatchNorm1d
class VanillaNN(containers.Sequential):
@ -39,18 +28,34 @@ class VanillaNN(containers.Sequential):
paddle.shape([10, 120, 512])
"""
def __init__(
self,
def __init__(self,
input_shape,
activation=paddle.nn.LeakyReLU,
dnn_blocks=2,
dnn_neurons=512, ):
super().__init__(input_shape=input_shape)
dnn_neurons=512,
activation=True,
normalization=False,
dropout_rate=0.0):
super().__init__(input_shape=[None, None, input_shape])
if not isinstance(dropout_rate, list):
dropout_rate = [dropout_rate] * dnn_blocks
else:
assert len(
dropout_rate
) == dnn_blocks, "len(dropout_rate) must equal to dnn_blocks"
for block_index in range(dnn_blocks):
self.append(
linear.Linear,
n_neurons=dnn_neurons,
bias=True,
bias_attr=None,
layer_name="linear", )
self.append(activation(), layer_name="act")
if normalization:
self.append(
BatchNorm1d, input_size=dnn_neurons, layer_name='bn')
if activation:
self.append(paddle.nn.LeakyReLU(), layer_name="act")
self.append(
paddle.nn.Dropout(),
p=dropout_rate[block_index],
layer_name='dropout')

@ -141,5 +141,4 @@ class Sequential(paddle.nn.LayerDict):
x = layer(x)
if isinstance(x, tuple):
x = x[0]
return x

@ -1,20 +1,8 @@
# Authors
# * Mirco Ravanelli 2020
# * Davide Borra 2021
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/linear.py).
"""Library implementing linear transformation.
Authors
* Mirco Ravanelli 2020
* Davide Borra 2021
"""
import logging
import paddle
@ -53,7 +41,7 @@ class Linear(paddle.nn.Layer):
n_neurons,
input_shape=None,
input_size=None,
bias=True,
bias_attr=None,
combine_dims=False, ):
super().__init__()
self.combine_dims = combine_dims
@ -67,7 +55,7 @@ class Linear(paddle.nn.Layer):
input_size = input_shape[2] * input_shape[3]
# Weights are initialized following paddle approach
self.w = align.Linear(input_size, n_neurons, bias_attr=bias)
self.w = align.Linear(input_size, n_neurons, bias_attr=bias_attr)
def forward(self, x):
"""Returns the linear transformation of input tensor.

@ -1120,9 +1120,6 @@ class Wav2Vec2ConfigPure():
self.output_hidden_states = False
self.use_return_dict = True
self.pad_token_id = config.pad_token_id
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
self.hidden_size = config.hidden_size
self.feat_extract_norm = config.feat_extract_norm
self.feat_extract_activation = config.feat_extract_activation
@ -1145,7 +1142,6 @@ class Wav2Vec2ConfigPure():
self.layerdrop = config.layerdrop
self.layer_norm_eps = config.layer_norm_eps
self.initializer_range = config.initializer_range
self.vocab_size = config.vocab_size
self.do_stable_layer_norm = config.do_stable_layer_norm
self.use_weighted_layer_sum = config.use_weighted_layer_sum

@ -639,6 +639,170 @@ class DropChunk(nn.Layer):
return dropped_waveform
class SpecAugment(paddle.nn.Layer):
"""An implementation of the SpecAugment algorithm.
Reference:
https://arxiv.org/abs/1904.08779
Arguments
---------
time_warp : bool
Whether applying time warping.
time_warp_window : int
Time warp window.
time_warp_mode : str
Interpolation mode for time warping (default "bicubic").
freq_mask : bool
Whether applying freq mask.
freq_mask_width : int or tuple
Freq mask width range.
n_freq_mask : int
Number of freq mask.
time_mask : bool
Whether applying time mask.
time_mask_width : int or tuple
Time mask width range.
n_time_mask : int
Number of time mask.
replace_with_zero : bool
If True, replace masked value with 0, else replace masked value with mean of the input tensor.
Example
-------
>>> aug = SpecAugment()
>>> a = paddle.rand([8, 120, 80])
>>> a = aug(a)
>>> print(a.shape)
paddle.Size([8, 120, 80])
"""
def __init__(
self,
time_warp=True,
time_warp_window=5,
time_warp_mode="bicubic",
freq_mask=True,
freq_mask_width=(0, 20),
n_freq_mask=2,
time_mask=True,
time_mask_width=(0, 100),
n_time_mask=2,
replace_with_zero=True, ):
super().__init__()
assert (
time_warp or freq_mask or time_mask
), "at least one of time_warp, time_mask, or freq_mask should be applied"
self.apply_time_warp = time_warp
self.time_warp_window = time_warp_window
self.time_warp_mode = time_warp_mode
self.freq_mask = freq_mask
if isinstance(freq_mask_width, int):
freq_mask_width = (0, freq_mask_width)
self.freq_mask_width = freq_mask_width
self.n_freq_mask = n_freq_mask
self.time_mask = time_mask
if isinstance(time_mask_width, int):
time_mask_width = (0, time_mask_width)
self.time_mask_width = time_mask_width
self.n_time_mask = n_time_mask
self.replace_with_zero = replace_with_zero
def forward(self, x):
"""Takes in input a tensors and returns an augmented one."""
if self.apply_time_warp:
x = self.time_warp(x)
if self.freq_mask:
x = self.mask_along_axis(x, dim=2)
if self.time_mask:
x = self.mask_along_axis(x, dim=1)
return x
def time_warp(self, x):
"""Time warping with paddle.nn.functional.interpolate"""
original_size = x.shape
window = self.time_warp_window
# 2d interpolation requires 4D or higher dimension tensors
# x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
if x.dim() == 3:
x = x.unsqueeze(1)
time = x.shape[2]
if time - window <= window:
return x.view(*original_size)
# compute center and corresponding window
c = paddle.randint(window, time - window, (1, ))[0]
w = paddle.randint(c - window, c + window, (1, ))[0] + 1
# c = 5
# w = 10
left = paddle.nn.functional.interpolate(
x[:, :, :c],
(w, x.shape[3]),
mode=self.time_warp_mode,
align_corners=True, )
right = paddle.nn.functional.interpolate(
x[:, :, c:],
(time - w, x.shape[3]),
mode=self.time_warp_mode,
align_corners=True, )
x[:, :, :w] = left
x[:, :, w:] = right
return x.view(*original_size)
def mask_along_axis(self, x, dim):
"""Mask along time or frequency axis.
Arguments
---------
x : tensor
Input tensor.
dim : int
Corresponding dimension to mask.
"""
original_size = x.shape
if x.dim() == 4:
x = x.view(-1, x.shape[2], x.shape[3])
batch, time, fea = x.shape
if dim == 1:
D = time
n_mask = self.n_time_mask
width_range = self.time_mask_width
else:
D = fea
n_mask = self.n_freq_mask
width_range = self.freq_mask_width
mask_len = paddle.randint(width_range[0], width_range[1],
(batch, n_mask)).unsqueeze(2)
mask_pos = paddle.randint(0, max(1, D - mask_len.max()),
(batch, n_mask)).unsqueeze(2)
# compute masks
arange = paddle.arange(end=D).view(1, 1, -1)
mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
mask = mask.any(axis=1)
if dim == 1:
mask = mask.unsqueeze(2)
else:
mask = mask.unsqueeze(1)
if self.replace_with_zero:
val = 0.0
else:
val = x.mean()
# same to x.masked_fill_(mask, val)
y = paddle.full(x.shape, val, x.dtype)
x = paddle.where(mask, y, x)
return x.view(*original_size)
class TimeDomainSpecAugment(nn.Layer):
"""A time-domain approximation of the SpecAugment algorithm.
This augmentation module implements three augmentations in

@ -23,7 +23,9 @@ import paddle.nn.functional as F
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.utility import log_add
@ -31,42 +33,41 @@ from paddlespeech.s2t.utils.utility import log_add
class Wav2vec2ASR(nn.Layer):
def __init__(self, config: dict):
super().__init__()
init_type = config.get("init_type", None)
with DefaultInitializerContext(init_type):
self.config = config
wav2vec2_config = Wav2Vec2ConfigPure(config)
wav2vec2 = Wav2Vec2Model(wav2vec2_config)
self.normalize_wav = config.normalize_wav
self.output_norm = config.output_norm
if hasattr(config, 'spec_augment'):
self.spec_augment = SpecAugment(**config.spec_augment)
if config.freeze_wav2vec2:
wav2vec2.eval()
for parm in wav2vec2.parameters():
parm.trainable = False
self.wav2vec2 = wav2vec2
self.enc = VanillaNN(
input_shape=[None, None, wav2vec2_config.hidden_size],
activation=nn.LeakyReLU,
dnn_blocks=config.dnn_blocks,
dnn_neurons=config.dnn_neurons)
self.ctc = CTC(odim=config.output_dim,
enc_n_units=config.dnn_neurons,
blank_id=config.blank_id,
dropout_rate=config.ctc_dropout_rate,
self.enc = VanillaNN(**config.enc)
self.ctc = CTC(**config.ctc,
odim=config.output_dim,
batch_average=False,
reduction='mean')
def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
def forward(self, wav, wavs_lens_rate, target, target_lens):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
wav = F.layer_norm(wav, wav.shape)
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
out = F.layer_norm(out, out.shape)
if self.train and hasattr(self.config, 'spec_augment'):
feats = self.spec_augment(out)
else:
feats = out
x = self.enc(feats)
x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
target_lens = (target_lens_rate *
target.shape[1]).round().astype(paddle.int64)
ctc_loss = self.ctc(x, x_lens, target, target_lens)
return ctc_loss

@ -17,6 +17,7 @@ from typing import Dict
from typing import Text
from typing import Union
import paddle
from paddle.optimizer.lr import LRScheduler
from typeguard import check_argument_types
@ -107,6 +108,125 @@ class ConstantLR(LRScheduler):
return self.base_lr
@register_scheduler
class NewBobScheduler(LRScheduler):
"""Scheduler with new-bob technique, used for LR annealing.
The learning rate is annealed based on the validation performance.
In particular: if (past_loss-current_loss)/past_loss< impr_threshold:
lr=lr * annealing_factor.
Arguments
---------
initial_value : float
The initial hyperparameter value.
annealing_factor : float
It is annealing factor used in new_bob strategy.
improvement_threshold : float
It is the improvement rate between losses used to perform learning
annealing in new_bob strategy.
patient : int
When the annealing condition is violated patient times,
the learning rate is finally reduced.
Example
-------
>>> scheduler = NewBobScheduler(initial_value=1.0)
>>> scheduler(metric_value=10.0)
(1.0, 1.0)
>>> scheduler(metric_value=2.0)
(1.0, 1.0)
>>> scheduler(metric_value=2.5)
(1.0, 0.5)
"""
def __init__(
self,
learning_rate,
last_epoch=-1,
verbose=False,
annealing_factor=0.5,
improvement_threshold=0.0025,
patient=0, ):
self.hyperparam_value = learning_rate
self.annealing_factor = annealing_factor
self.improvement_threshold = improvement_threshold
self.patient = patient
self.metric_values = []
self.current_patient = self.patient
super().__init__(learning_rate, last_epoch, verbose)
def step(self, metric_value=None):
"""
``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` .
The new learning rate will take effect on next ``optimizer.step`` .
Args:
epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
Returns:
None
"""
if metric_value is None:
self.last_epoch += 1
self.last_lr = self.hyperparam_value
else:
self.last_epoch += 1
self.last_lr = self.get_lr(metric_value)
if self.verbose:
print('Epoch {}: {} set learning rate to {}.'.format(
self.last_epoch, self.__class__.__name__, self.last_lr))
def get_lr(self, metric_value):
"""Returns the current and new value for the hyperparameter.
Arguments
---------
metric_value : int
A number for determining whether to change the hyperparameter value.
"""
new_value = self.hyperparam_value
if len(self.metric_values) > 0:
prev_metric = self.metric_values[-1]
# Update value if improvement too small and patience is 0
if prev_metric == 0: # Prevent division by zero
improvement = 0
else:
improvement = (prev_metric - metric_value) / prev_metric
if improvement < self.improvement_threshold:
if self.current_patient == 0:
new_value *= self.annealing_factor
self.current_patient = self.patient
else:
self.current_patient -= 1
# Store relevant info
self.metric_values.append(metric_value)
self.hyperparam_value = new_value
return new_value
def save(self):
"""Saves the current metrics on the specified path."""
data = {
"current_epoch_index": self.last_epoch,
"hyperparam_value": self.hyperparam_value,
"metric_values": self.metric_values,
"current_patient": self.current_patient
}
return data
def load(self, data):
"""Loads the needed information."""
data = paddle.load(data)
self.last_epoch = data["current_epoch_index"]
self.hyperparam_value = data["hyperparam_value"]
self.metric_values = data["metric_values"]
self.current_patient = data["current_patient"]
def dynamic_import_scheduler(module):
"""Import Scheduler class dynamically.

Loading…
Cancel
Save