You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/s2t/frontend/augmentor/augmentation.py

231 lines
8.6 KiB

4 years ago
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the data augmentation pipeline."""
import json
import os
from collections.abc import Sequence
from inspect import signature
from pprint import pformat
E2E/Streaming Transformer/Conformer ASR (#578) * add cmvn and label smoothing loss layer * add layer for transformer * add glu and conformer conv * add torch compatiable hack, mask funcs * not hack size since it exists * add test; attention * add attention, common utils, hack paddle * add audio utils * conformer batch padding mask bug fix #223 * fix typo, python infer fix rnn mem opt name error and batchnorm1d, will be available at 2.0.2 * fix ci * fix ci * add encoder * refactor egs * add decoder * refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils * refactor docs * add fix * fix readme * fix bugs, refactor collator, add pad_sequence, fix ckpt bugs * fix docstring * refactor data feed order * add u2 model * refactor cmvn, test * add utils * add u2 config * fix bugs * fix bugs * fix autograd maybe has problem when using inplace operation * refactor data, build vocab; add format data * fix text featurizer * refactor build vocab * add fbank, refactor feature of speech * refactor audio feat * refactor data preprare * refactor data * model init from config * add u2 bins * flake8 * can train * fix bugs, add coverage, add scripts * test can run * fix data * speed perturb with sox * add spec aug * fix for train * fix train logitc * fix logger * log valid loss, time dataset process * using np for speed perturb, remove some debug log of grad clip * fix logger * fix build vocab * fix logger name * using module logger as default * fix * fix install * reorder imports * fix board logger * fix logger * kaldi fbank and mfcc * fix cmvn and print prarams * fix add_eos_sos and cmvn * fix cmvn compute * fix logger and cmvn * fix subsampling, label smoothing loss, remove useless * add notebook test * fix log * fix tb logger * multi gpu valid * fix log * fix log * fix config * fix compute cmvn, need paddle 2.1 * add cmvn notebook * fix layer tools * fix compute cmvn * add rtf * fix decoding * fix layer tools * fix log, add avg script * more avg and test 
info * fix dataset pickle problem; using 2.1 paddle; num_workers can > 0; ckpt save in exp dir;fix setup.sh; * add vimrc * refactor tiny script, add transformer and stream conf * spm demo; librisppech scripts and confs * fix log * add librispeech scripts * refactor data pipe; fix conf; fix u2 default params * fix bugs * refactor aishell scripts * fix test * fix cmvn * fix s0 scripts * fix ds2 scripts and bugs * fix dev & test dataset filter * fix dataset filter * filter dev * fix ckpt path * filter test, since librispeech will cause OOM, but all test wer will be worse, since mismatch train with test * add comment * add syllable doc * fix ds2 configs * add doc * add pypinyin tools * fix decoder using blank_id=0 * mmseg with pybind11 * format code
3 years ago
import numpy as np
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.log import Log
# Logger used by every function and class in this module.
logger = Log(__name__).getlog()

# Names exported by ``from ... import *``.
__all__ = ['AugmentationPipeline']
# Short names usable as the "type" field of an augmentation config entry,
# mapped to "module.path:ClassName" strings that are resolved by
# ``dynamic_import`` in AugmentationPipeline._get_augmentor.
import_alias = dict(
    volume="paddlespeech.s2t.frontend.augmentor.volume_perturb:VolumePerturbAugmentor",
    shift="paddlespeech.s2t.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor",
    speed="paddlespeech.s2t.frontend.augmentor.speed_perturb:SpeedPerturbAugmentor",
    resample="paddlespeech.s2t.frontend.augmentor.resample:ResampleAugmentor",
    bayesian_normal="paddlespeech.s2t.frontend.augmentor.online_bayesian_normalization:OnlineBayesianNormalizationAugmentor",
    noise="paddlespeech.s2t.frontend.augmentor.noise_perturb:NoisePerturbAugmentor",
    impulse="paddlespeech.s2t.frontend.augmentor.impulse_response:ImpulseResponseAugmentor",
    specaug="paddlespeech.s2t.frontend.augmentor.spec_augment:SpecAugmentor", )
Support paddle 2.x (#538) * 2.x model * model test pass * fix data * fix soundfile with flac support * one thread dataloader test pass * export feasture size add trainer and utils add setup model and dataloader update travis using Bionic dist * add venv; test under venv * fix unittest; train and valid * add train and config * add config and train script * fix ctc cuda memcopy error * fix imports * fix train valid log * fix dataset batch shuffle shift start from 1 fix rank_zero_only decreator error close tensorboard when train over add decoding config and code * test process can run * test with decoding * test and infer with decoding * fix infer * fix ctc loss lr schedule sortagrad logger * aishell egs * refactor train add aishell egs * fix dataset batch shuffle and add batch sampler log print model parameter * fix model and ctc * sequence_mask make all inputs zeros, which cause grad be zero, this is a bug of LessThanOp add grad clip by global norm add model train test notebook * ctc loss remove run prefix using ord value as text id * using unk when training compute_loss need text ids ord id using in test mode, which compute wer/cer * fix tester * add lr_deacy refactor code * fix tools * fix ci add tune fix gru model bugs add dataset and model test * fix decoding * refactor repo fix decoding * fix musan and rir dataset * refactor io, loss, conv, rnn, gradclip, model, utils * fix ci and import * refactor model add export jit model * add deploy bin and test it * rm uselss egs * add layer tools * refactor socket server new model from pretrain * remve useless * fix instability loss and grad nan or inf for librispeech training * fix sampler * fix libri train.sh * fix doc * add license on cpp * fix doc * fix libri script * fix install * clip 5 wer 7.39, clip 400 wer 7.54, 1.8 clip 400 baseline 7.49
4 years ago
class AugmentationPipeline():
    """Build a pre-processing pipeline with various augmentation models.

    Such a data augmentation pipeline is often leveraged to augment the
    training samples to make the model invariant to certain types of
    perturbations in the real world, improving the model's generalization
    ability.

    The pipeline is built according to the augmentation configuration, given
    either as a path to a json file or as a json string, e.g.

    .. code-block::

        [ {
                "type": "noise",
                "params": {"min_snr_dB": 10,
                           "max_snr_dB": 20,
                           "noise_manifest_path": "datasets/manifest.noise"},
                "prob": 0.0
            },
            {
                "type": "speed",
                "params": {"min_speed_rate": 0.9,
                           "max_speed_rate": 1.1},
                "prob": 1.0
            },
            {
                "type": "shift",
                "params": {"min_shift_ms": -5,
                           "max_shift_ms": 5},
                "prob": 1.0
            },
            {
                "type": "volume",
                "params": {"min_gain_dBFS": -10,
                           "max_gain_dBFS": 10},
                "prob": 0.0
            },
            {
                "type": "bayesian_normal",
                "params": {"target_db": -20,
                           "prior_db": -20,
                           "prior_samples": 100},
                "prob": 0.0
            }
        ]

    This configuration inserts five augmentation models into the pipeline
    (noise, speed, shift, volume and bayesian_normal). "prob" is the
    probability that the corresponding augmentor takes effect; an augmentor
    whose "prob" is zero never takes effect, so in the example above only
    the speed and shift augmentors are active.

    Augmentors whose "type" is listed in ``SPEC_TYPES`` operate on features
    (``transform_feature``); all other types operate on audio
    (``transform_audio``). ``__call__`` runs every configured augmentor in
    config order.

    Params:
        preprocess_conf(str): Augmentation configuration as a path to a
            `json file` or as a `json string`. An empty value means no
            augmentation.
        random_seed(int): Random seed.

    Raises:
        ValueError: If the augmentation json config is in an incorrect format.
    """

    # Augmentor types that transform features rather than raw audio.
    SPEC_TYPES = {'specaug'}

    def __init__(self, preprocess_conf: str, random_seed: int=0):
        # One RNG is shared by the pipeline and handed to every augmentor it
        # builds (see _get_augmentor), so a single seed drives all draws.
        self._rng = np.random.RandomState(random_seed)
        self.conf = {'mode': 'sequential', 'process': []}
        if preprocess_conf:
            if os.path.isfile(preprocess_conf):
                # json file
                with open(preprocess_conf, 'r', encoding='utf-8') as fin:
                    json_string = fin.read()
            else:
                # json string
                json_string = preprocess_conf
            process = json.loads(json_string)
            # The config must be a list of augmentor entries; extending the
            # list with a dict would silently add its keys instead.
            if not isinstance(process, list):
                raise ValueError(
                    "Augmentation config must be a json list of augmentor "
                    f"entries, but got {type(process).__name__}.")
            self.conf['process'] += process

        # Each _parse_pipeline_from call builds its own augmentor instances,
        # so the 'all' list does not share objects with the audio/feature
        # lists.
        self._augmentors, self._rates = self._parse_pipeline_from('all')
        self._audio_augmentors, self._audio_rates = self._parse_pipeline_from(
            'audio')
        self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
            'feature')
        logger.info(
            f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}")

    def __call__(self, xs, uttid_list=None, **kwargs):
        """Apply every configured augmentor, in config order, to ``xs``.

        Args:
            xs: A single sample, or a ``Sequence`` of samples (a batch).
            uttid_list: Optional utterance ids, one per sample; a single
                ``str`` is reused for every sample. The ids are passed as the
                second positional argument to augmentors whose signature has
                an ``uttid`` parameter.
            **kwargs: Extra keyword arguments; each augmentor receives only
                the ones whose names appear in its signature.

        Returns:
            A list of augmented samples if ``xs`` was a ``Sequence``,
            otherwise the single augmented sample.

        Raises:
            NotImplementedError: If ``conf["mode"]`` is not "sequential".
        """
        is_batch = isinstance(xs, Sequence)
        if not is_batch:
            xs = [xs]
        if isinstance(uttid_list, str):
            uttid_list = [uttid_list] * len(xs)

        if self.conf.get("mode", "sequential") != "sequential":
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

        for idx, (func, rate) in enumerate(zip(self._augmentors, self._rates)):
            # One draw per augmentor per call: it is applied to the whole
            # batch with probability ``rate``.
            if self._rng.uniform(0., 1.) >= rate:
                continue
            # Derive only the args which the func has
            try:
                param = signature(func).parameters
            except ValueError:
                # Some callables, e.g. built-in functions, expose no signature
                param = {}
            _kwargs = {k: v for k, v in kwargs.items() if k in param}
            try:
                if uttid_list is not None and "uttid" in param:
                    xs = [
                        func(x, u, **_kwargs)
                        for x, u in zip(xs, uttid_list)
                    ]
                else:
                    xs = [func(x, **_kwargs) for x in xs]
            except Exception:
                logger.fatal("Catch a exception from {}th func: {}".format(
                    idx, func))
                raise

        return xs if is_batch else xs[0]

    def transform_audio(self, audio_segment):
        """Run the audio part of the pipeline for data augmentation.

        Note that this is an in-place transformation: each selected
        augmentor's ``transform_audio`` is called on ``audio_segment`` and
        nothing is returned.

        :param audio_segment: Audio segment to process.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        for augmentor, rate in zip(self._audio_augmentors, self._audio_rates):
            if self._rng.uniform(0., 1.) < rate:
                augmentor.transform_audio(audio_segment)

    def transform_feature(self, spec_segment):
        """Run the feature (spectrogram) part of the pipeline.

        Args:
            spec_segment (np.ndarray): audio feature, (D, T).

        Returns:
            The feature after every selected augmentor has been applied.
        """
        for augmentor, rate in zip(self._spec_augmentors, self._spec_rates):
            if self._rng.uniform(0., 1.) < rate:
                spec_segment = augmentor.transform_feature(spec_segment)
        return spec_segment

    def _parse_pipeline_from(self, aug_type='all'):
        """Parse the config json to build an augmentation pipeline.

        Args:
            aug_type (str): Which augmentors to build: 'audio' (types not in
                ``SPEC_TYPES``), 'feature' (types in ``SPEC_TYPES``) or 'all'.

        Returns:
            tuple: ``(augmentors, rates)``, two lists in config order where
            ``rates[i]`` is the "prob" of ``augmentors[i]``.

        Raises:
            ValueError: If ``aug_type`` is not supported.
        """
        assert aug_type in ('audio', 'feature', 'all'), aug_type
        audio_confs = []
        feature_confs = []
        all_confs = []
        for config in self.conf['process']:
            all_confs.append(config)
            if config["type"] in self.SPEC_TYPES:
                feature_confs.append(config)
            else:
                audio_confs.append(config)

        if aug_type == 'audio':
            aug_confs = audio_confs
        elif aug_type == 'feature':
            aug_confs = feature_confs
        elif aug_type == 'all':
            aug_confs = all_confs
        else:
            # Reachable only when asserts are disabled (python -O).
            raise ValueError(f"Not support: {aug_type}")

        augmentors = [
            self._get_augmentor(config["type"], config["params"])
            for config in aug_confs
        ]
        rates = [config["prob"] for config in aug_confs]
        return augmentors, rates

    def _get_augmentor(self, augmentor_type, params):
        """Return an augmentation model by the type name, and pass in params.

        The type is resolved through ``dynamic_import`` with ``import_alias``
        (e.g. "speed"); the resulting class is constructed with the shared
        RNG followed by ``params`` as keyword arguments.

        Raises:
            ValueError: If the augmentor cannot be constructed with
                ``params``; the original error is chained as the cause.
        """
        class_obj = dynamic_import(augmentor_type, import_alias)
        assert issubclass(class_obj, AugmentorBase)
        try:
            obj = class_obj(self._rng, **params)
        except Exception as e:
            # The type was already resolved above, so a failure here comes
            # from the constructor / params, not from an unknown type.
            raise ValueError(
                f"Failed to construct augmentor [{augmentor_type}] with "
                f"params {params}: {e}") from e
        return obj