add spec aug

pull/578/head
Hui Zhang 5 years ago
parent 9ad688c6aa
commit 3e449d6500

@ -637,7 +637,7 @@
{
"cell_type": "code",
"execution_count": 59,
"id": "threaded-grove",
"id": "legitimate-overhead",
"metadata": {},
"outputs": [
{
@ -660,7 +660,7 @@
{
"cell_type": "code",
"execution_count": 35,
"id": "equal-vanilla",
"id": "genuine-feeding",
"metadata": {},
"outputs": [
{
@ -705,7 +705,7 @@
{
"cell_type": "code",
"execution_count": 36,
"id": "gorgeous-stanford",
"id": "bizarre-story",
"metadata": {},
"outputs": [
{
@ -728,7 +728,7 @@
{
"cell_type": "code",
"execution_count": 37,
"id": "geological-actor",
"id": "appointed-brooklyn",
"metadata": {},
"outputs": [
{
@ -748,7 +748,7 @@
{
"cell_type": "code",
"execution_count": 38,
"id": "miniature-ethnic",
"id": "occasional-utilization",
"metadata": {},
"outputs": [],
"source": [
@ -758,7 +758,7 @@
{
"cell_type": "code",
"execution_count": 40,
"id": "honest-clarity",
"id": "trained-indonesian",
"metadata": {},
"outputs": [
{
@ -776,7 +776,7 @@
{
"cell_type": "code",
"execution_count": 54,
"id": "environmental-stewart",
"id": "following-brave",
"metadata": {},
"outputs": [
{
@ -809,7 +809,7 @@
{
"cell_type": "code",
"execution_count": 42,
"id": "trying-brazil",
"id": "prospective-blind",
"metadata": {},
"outputs": [
{
@ -839,7 +839,7 @@
{
"cell_type": "code",
"execution_count": 43,
"id": "chronic-interval",
"id": "minus-ethernet",
"metadata": {},
"outputs": [
{
@ -868,7 +868,7 @@
{
"cell_type": "code",
"execution_count": 44,
"id": "widespread-basin",
"id": "ordinary-closer",
"metadata": {},
"outputs": [
{
@ -900,7 +900,7 @@
{
"cell_type": "code",
"execution_count": 45,
"id": "clinical-lighting",
"id": "demographic-mumbai",
"metadata": {},
"outputs": [],
"source": [
@ -911,7 +911,7 @@
{
"cell_type": "code",
"execution_count": 46,
"id": "federal-supervision",
"id": "conscious-stuff",
"metadata": {},
"outputs": [
{
@ -931,7 +931,7 @@
{
"cell_type": "code",
"execution_count": 30,
"id": "parallel-trademark",
"id": "virgin-dublin",
"metadata": {},
"outputs": [],
"source": [
@ -941,7 +941,7 @@
{
"cell_type": "code",
"execution_count": 31,
"id": "extended-fishing",
"id": "sized-homework",
"metadata": {},
"outputs": [],
"source": [
@ -951,7 +951,7 @@
{
"cell_type": "code",
"execution_count": 47,
"id": "baking-auckland",
"id": "disciplinary-headquarters",
"metadata": {},
"outputs": [
{
@ -969,10 +969,103 @@
"np.allclose(x, samples)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "persistent-synthetic",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "hydraulic-reach",
"metadata": {},
"outputs": [],
"source": [
"np.random.uniform?"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "suitable-house",
"metadata": {},
"outputs": [],
"source": [
"random.uniform?"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "printable-carter",
"metadata": {},
"outputs": [],
"source": [
"np.random.RandomState?"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "considered-interval",
"metadata": {},
"outputs": [],
"source": [
"random.sample?"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "ideal-hurricane",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['3', '5'], dtype='<U1')"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.choice(['5','4', '3'], 2, replace=False)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "skilled-cooler",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['3', '4']"
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"random.sample(['5','4', '3'], 2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "changed-storage",
"id": "adopted-hardware",
"metadata": {},
"outputs": [],
"source": []

@ -18,6 +18,9 @@
- id: check-yaml
- id: check-json
- id: pretty-format-json
args:
- --no-sort-keys
- --autofix
- id: check-merge-conflict
- id: flake8
args:
@ -49,3 +52,7 @@
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
#exclude: (?=decoders/swig).*(\.cpp|\.h)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:
- id: reorder-python-imports

@ -14,7 +14,9 @@
"""Contains the data augmentation pipeline."""
import json
import random
import numpy as np
# audio augment
from deepspeech.frontend.augmentor.volume_perturb import VolumePerturbAugmentor
from deepspeech.frontend.augmentor.shift_perturb import ShiftPerturbAugmentor
from deepspeech.frontend.augmentor.speed_perturb import SpeedPerturbAugmentor
@ -23,6 +25,8 @@ from deepspeech.frontend.augmentor.impulse_response import ImpulseResponseAugmen
from deepspeech.frontend.augmentor.resample import ResampleAugmentor
from deepspeech.frontend.augmentor.online_bayesian_normalization import \
OnlineBayesianNormalizationAugmentor
# feature augment
from deepspeech.frontend.augmentor.spec_augment import SpecAugmentor
class AugmentationPipeline():
@ -84,9 +88,12 @@ class AugmentationPipeline():
"""
def __init__(self, augmentation_config: str, random_seed=0):
self._rng = random.Random(random_seed)
self._rng = np.random.RandomState(random_seed)
self._spec_types = ('specaug')
self._augmentors, self._rates = self._parse_pipeline_from(
augmentation_config)
augmentation_config, 'audio')
self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
augmentation_config, 'feature')
def transform_audio(self, audio_segment):
"""Run the pre-processing pipeline for data augmentation.
@ -100,15 +107,41 @@ class AugmentationPipeline():
if self._rng.uniform(0., 1.) < rate:
augmentor.transform_audio(audio_segment)
def _parse_pipeline_from(self, config_json):
def transform_feature(self, spec_segment):
    """Run the feature-level (spectrogram) augmentation pipeline.

    Every feature augmentor parsed from the augmentation config is applied
    in order, each one only when a fresh uniform draw from the pipeline's
    random generator falls below that augmentor's configured probability.

    Args:
        spec_segment (np.ndarray): audio feature, (D, T).

    Returns:
        The feature returned by the last applied augmentor, or the input
        unchanged when no augmentor fires.
    """
    # Iterate the *feature* augmentors built in __init__; the audio
    # augmentors in self._augmentors work on audio segments and must not be
    # fed spectrogram features.
    for augmentor, rate in zip(self._spec_augmentors, self._spec_rates):
        if self._rng.uniform(0., 1.) < rate:
            spec_segment = augmentor.transform_feature(spec_segment)
    return spec_segment
def _parse_pipeline_from(self, config_json, aug_type='audio'):
"""Parse the config json to build an augmentation pipeline."""
assert aug_type in ('audio', 'feature'), aug_type
try:
configs = json.loads(config_json)
audio_confs = []
feature_confs = []
for config in configs:
if config["type"] in self._spec_types:
feature_confs.append(config)
else:
audio_confs.append(config)
if aug_type == 'audio':
aug_confs = audio_confs
elif aug_type == 'feature':
aug_confs = feature_confs
augmentors = [
self._get_augmentor(config["type"], config["params"])
for config in configs
for config in aug_confs
]
rates = [config["prob"] for config in configs]
rates = [config["prob"] for config in aug_confs]
except Exception as e:
raise ValueError("Failed to parse the augmentation config json: "
"%s" % str(e))
@ -130,5 +163,7 @@ class AugmentationPipeline():
return NoisePerturbAugmentor(self._rng, **params)
elif augmentor_type == "impulse":
return ImpulseResponseAugmentor(self._rng, **params)
elif augmentor_type == "specaug":
return SpecAugmentor(self._rng, **params)
else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type)

@ -43,15 +43,13 @@ class AugmentorBase():
pass
@abstractmethod
def transform_spectrogram(self, spec_segment):
"""Adds various effects to the input spectrogram segment. Such effects
def transform_feature(self, spec_segment):
"""Adds various effects to the input audio feature segment. Such effects
will augment the training data to make the model invariant to certain
types of time_mask or freq_mask in the real world, improving model's
generalization ability.
Note that this is an in-place transformation.
:param spec_segment: Spectrogram segment to add effects to.
:type spec_segment: Spectrogram
Args:
spec_segment (Spectrogram): Spectrogram segment to add effects to.
"""
pass

@ -39,6 +39,7 @@ class ImpulseResponseAugmentor(AugmentorBase):
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
impulse_json = self._rng.sample(self._impulse_manifest, 1)[0]
impulse_json = self._rng.choice(
self._impulse_manifest, 1, replace=False)[0]
impulse_segment = AudioSegment.from_file(impulse_json['audio_filepath'])
audio_segment.convolve(impulse_segment, allow_resample=True)

@ -45,7 +45,7 @@ class NoisePerturbAugmentor(AugmentorBase):
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
noise_json = self._rng.sample(self._noise_manifest, 1)[0]
noise_json = self._rng.choice(self._noise_manifest, 1, replace=False)[0]
if noise_json['duration'] < audio_segment.duration:
raise RuntimeError("The duration of sampled noise audio is smaller "
"than the audio segment to add effects to.")

@ -0,0 +1,169 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the SpecAugment (spectrogram masking) augmentation model."""
import logging
import numpy as np
from deepspeech.frontend.augmentor.base import AugmentorBase
logger = logging.getLogger(__name__)
class SpecAugmentor(AugmentorBase):
    """Augmentation model for Time warping, Frequency masking, Time masking.

    SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
        https://arxiv.org/abs/1904.08779
    SpecAugment on Large Scale Datasets
        https://arxiv.org/abs/1912.05533

    Time warping is not implemented yet; ``transform_feature`` applies
    frequency masking followed by time masking.
    """

    def __init__(self,
                 rng,
                 F,
                 T,
                 n_freq_masks,
                 n_time_masks,
                 p=1.0,
                 W=40,
                 adaptive_number_ratio=0,
                 adaptive_size_ratio=0,
                 max_n_time_masks=20):
        """SpecAugment class.

        Args:
            rng: random generator object exposing ``uniform(low, high)``,
                e.g. ``np.random.RandomState`` or ``random.Random``.
            F (int): parameter for frequency masking
            T (int): parameter for time masking
            n_freq_masks (int): number of frequency masks
            n_time_masks (int): number of time masks
            p (float): parameter for upperbound of the time mask
            W (int): parameter for time warping
            adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
            adaptive_size_ratio (float): adaptive size ratio for time masking
            max_n_time_masks (int): maximum number of time masking
        """
        super().__init__()
        self._rng = rng

        self.W = W
        self.F = F
        self.T = T
        self.n_freq_masks = n_freq_masks
        self.n_time_masks = n_time_masks
        self.p = p

        # adaptive SpecAugment
        self.adaptive_number_ratio = adaptive_number_ratio
        self.adaptive_size_ratio = adaptive_size_ratio
        self.max_n_time_masks = max_n_time_masks

        # In adaptive mode the number / size of time masks is derived from
        # the utterance length in mask_time, so the fixed values are reset.
        if adaptive_number_ratio > 0:
            self.n_time_masks = 0
            logger.info('n_time_masks is set ot zero for adaptive SpecAugment.')
        if adaptive_size_ratio > 0:
            self.T = 0
            logger.info('T is set to zero for adaptive SpecAugment.')

        # (start, end) of the most recently applied masks; None until used.
        self._freq_mask = None
        self._time_mask = None

    def _set_policy(self, W, F, T, n_freq_masks, n_time_masks, p):
        """Overwrite the masking/warping parameters with a preset policy."""
        self.W = W
        self.F = F
        self.T = T
        self.n_freq_masks = n_freq_masks
        self.n_time_masks = n_time_masks
        self.p = p

    def librispeech_basic(self):
        """Switch to the "LibriSpeech basic" preset parameters."""
        self._set_policy(W=80, F=27, T=100, n_freq_masks=1, n_time_masks=1,
                         p=1.0)

    def librispeech_double(self):
        """Switch to the "LibriSpeech double" preset parameters."""
        self._set_policy(W=80, F=27, T=100, n_freq_masks=2, n_time_masks=2,
                         p=1.0)

    def switchboard_mild(self):
        """Switch to the "Switchboard mild" preset parameters."""
        self._set_policy(W=40, F=15, T=70, n_freq_masks=2, n_time_masks=2,
                         p=0.2)

    def switchboard_strong(self):
        """Switch to the "Switchboard strong" preset parameters."""
        self._set_policy(W=40, F=27, T=70, n_freq_masks=2, n_time_masks=2,
                         p=0.2)

    @property
    def freq_mask(self):
        """(start, end) bins of the last frequency mask, or None."""
        return self._freq_mask

    @property
    def time_mask(self):
        """(start, end) frames of the last time mask, or None."""
        return self._time_mask

    def time_warp(self, xs, W=40):
        """Time warping (not implemented).

        Args:
            xs (np.ndarray): feature to warp.
            W (int): time warp parameter.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def mask_freq(self, xs, replace_with_zero=False):
        """Zero out ``n_freq_masks`` random bands along axis 0, in place.

        Args:
            xs (np.ndarray): feature whose first axis is masked.
            replace_with_zero (bool): currently ignored; masked bins are
                always set to zero.

        Returns:
            np.ndarray: the same ``xs`` object, masked in place.
        """
        n_bins = xs.shape[0]
        for _ in range(self.n_freq_masks):
            # Positional arguments work with both
            # np.random.RandomState.uniform and random.Random.uniform.
            f = int(self._rng.uniform(0, self.F))
            # A mask can never be wider than the feature itself; without
            # this clamp the upper bound of the next draw goes negative and
            # the slice below may wrap around.
            f = max(0, min(f, n_bins))
            f_0 = int(self._rng.uniform(0, n_bins - f))
            xs[f_0:f_0 + f, :] = 0
            self._freq_mask = (f_0, f_0 + f)
        return xs

    def mask_time(self, xs, replace_with_zero=False):
        """Zero out random spans along axis 1, in place.

        The number of masks is ``n_time_masks``, or, when
        ``adaptive_number_ratio > 0``, proportional to the number of frames
        and capped at ``max_n_time_masks``. The maximum mask width is ``T``,
        or ``adaptive_size_ratio * n_frames`` when ``adaptive_size_ratio > 0``;
        every mask is further limited to ``int(n_frames * p)`` frames.

        Args:
            xs (np.ndarray): feature whose second axis is masked.
            replace_with_zero (bool): currently ignored; masked frames are
                always set to zero.

        Returns:
            np.ndarray: the same ``xs`` object, masked in place.
        """
        n_frames = xs.shape[1]

        if self.adaptive_number_ratio > 0:
            n_masks = int(n_frames * self.adaptive_number_ratio)
            n_masks = min(n_masks, self.max_n_time_masks)
        else:
            n_masks = self.n_time_masks

        if self.adaptive_size_ratio > 0:
            T = self.adaptive_size_ratio * n_frames
        else:
            T = self.T

        for _ in range(n_masks):
            t = int(self._rng.uniform(0, T))
            # Limit by the p-fraction of the utterance and by the utterance
            # length itself (p > 1 would otherwise allow t > n_frames).
            t = max(0, min(t, int(n_frames * self.p), n_frames))
            t_0 = int(self._rng.uniform(0, n_frames - t))
            xs[:, t_0:t_0 + t] = 0
            self._time_mask = (t_0, t_0 + t)
        return xs

    def transform_feature(self, xs: np.ndarray):
        """Apply frequency masking, then time masking, to a feature.

        Args:
            xs (np.ndarray): `[F, T]`
        Returns:
            xs (np.ndarray): `[F, T]`, the input array masked in place.
        """
        # Time warping is not implemented, so it is not part of the chain.
        xs = self.mask_freq(xs)
        xs = self.mask_time(xs)
        return xs

@ -192,7 +192,7 @@ class ManifestDataset(Dataset):
self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None
self._audio_augmentation_pipeline = AugmentationPipeline(
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=augmentation_config, random_seed=random_seed)
self._speech_featurizer = SpeechFeaturizer(
unit_type=unit_type,
@ -295,11 +295,14 @@ class ManifestDataset(Dataset):
self._subfile_from_tar(audio_file), transcript)
else:
speech_segment = SpeechSegment.from_file(audio_file, transcript)
self._audio_augmentation_pipeline.transform_audio(speech_segment)
# audio augment
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
if self._normalizer:
specgram = self._normalizer.apply(specgram)
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
return specgram, transcript_part
def _instance_reader_creator(self, manifest):

@ -5,6 +5,7 @@ Data augmentation has often been a highly effective technique to boost the deep
Eight optional augmentation components (six audio-level and two feature-level) are provided to be selected, configured and inserted into the processing pipeline.
* Audio
- Volume Perturbation
- Speed Perturbation
- Shifting Perturbation
@ -12,6 +13,10 @@ Six optional augmentation components are provided to be selected, configured and
- Noise Perturbation (need background noise audio files)
- Impulse Response (need impulse audio files)
* Feature
- SpecAugment
- Adaptive SpecAugment
In order to inform the trainer of what augmentation components are needed and what their processing orders are, it is required to prepare in advance an *augmentation configuration file* in [JSON](http://www.json.org/) format. For example:
```
@ -31,6 +36,6 @@ In order to inform the trainer of what augmentation components are needed and wh
When the `augment_conf_file` argument is set to the path of the above example configuration file, every audio clip in every epoch will be processed as follows: with a 60% chance, it will first be speed perturbed with a uniformly random sampled speed rate between 0.95 and 1.05, and then, with an 80% chance, it will be shifted in time with a randomly sampled offset between -5 ms and 5 ms. Finally, this newly synthesized audio clip will be fed into the feature extractor for further training.
For other configuration examples, please refer to `examples/conf/augmentation.config.example`.
For other configuration examples, please refer to `examples/conf/augmentation.example.json`.
Be careful when utilizing the data augmentation technique, as improper augmentation will do harm to the training, due to the enlarged train-test gap.

@ -40,4 +40,4 @@ python3 utils/build_vocab.py \
--manifest_paths examples/librispeech/data/manifest.train
```
It will write a vocabuary file `examples/librispeech/data/eng_vocab.txt` with all transcription text in `examples/librispeech/data/manifest.train`, without vocabulary truncation (`--count_threshold 0`).
It will write a vocabulary file `examples/librispeech/data/vocab.txt` with all transcription text in `examples/librispeech/data/manifest.train`, without vocabulary truncation (`--count_threshold 0`).

@ -8,10 +8,10 @@ To avoid the trouble of environment setup, [running in Docker container](#runnin
## Setup
- Make sure these libraries or tools installed: `pkg-config`, `flac`, `ogg`, `vorbis`, `boost` and `swig`, e.g. installing them via `apt-get`:
- Make sure these libraries or tools are installed: `pkg-config`, `flac`, `ogg`, `vorbis`, `boost`, `sox`, and `swig`, e.g. installing them via `apt-get`:
```bash
sudo apt-get install -y pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
sudo apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
```
or, installing them via `yum`:

@ -1,8 +0,0 @@
[
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
}
]

@ -0,0 +1,10 @@
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
}
]

@ -5,7 +5,7 @@ data:
test_manifest: data/manifest.test
mean_std_filepath: data/mean_std.npz
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.config
augmentation_config: conf/augmentation.json
batch_size: 64 # one gpu
max_duration: 27.0
min_duration: 0.0

@ -0,0 +1,34 @@
[
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 1.0
},
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "specaug",
"params": {
"F": 10,
"T": 50,
"n_freq_masks": 2,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
},
"prob": 0.0
}
]

@ -0,0 +1,110 @@
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
min_input_len: 0.5
max_input_len: 20.0
min_output_len: 0.0
max_output_len: 400
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
delta_delta: False
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 25.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
# network architecture
model:
cmvn_file: "data/mean_std.npz"
cmvn_file_type: "npz"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
normalize_before: True
use_cnn_module: True
cnn_module_kernel: 15
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
training:
n_epoch: 240
accum_grad: 4
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.002
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
decoding:
batch_size: 16
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 10
cutoff_prob: 1.0
cutoff_top_n: 0
num_proc_bsearch: 8
ctc_weight: 0.0 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.

@ -1,8 +0,0 @@
[
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
}
]

@ -1,40 +0,0 @@
[
{
"type": "noise",
"params": {"min_snr_dB": 40,
"max_snr_dB": 50,
"noise_manifest_path": "datasets/manifest.noise"},
"prob": 0.6
},
{
"type": "impulse",
"params": {"impulse_manifest_path": "datasets/manifest.impulse"},
"prob": 0.5
},
{
"type": "speed",
"params": {"min_speed_rate": 0.95,
"max_speed_rate": 1.05,
"num_rates": 3},
"prob": 0.5
},
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
},
{
"type": "volume",
"params": {"min_gain_dBFS": -10,
"max_gain_dBFS": 10},
"prob": 0.0
},
{
"type": "bayesian_normal",
"params": {"target_db": -20,
"prior_db": -20,
"prior_samples": 100},
"prob": 0.0
}
]

@ -0,0 +1,67 @@
[
{
"type": "noise",
"params": {
"min_snr_dB": 40,
"max_snr_dB": 50,
"noise_manifest_path": "datasets/manifest.noise"
},
"prob": 0.6
},
{
"type": "impulse",
"params": {
"impulse_manifest_path": "datasets/manifest.impulse"
},
"prob": 0.5
},
{
"type": "speed",
"params": {
"min_speed_rate": 0.95,
"max_speed_rate": 1.05,
"num_rates": 3
},
"prob": 0.5
},
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "volume",
"params": {
"min_gain_dBFS": -10,
"max_gain_dBFS": 10
},
"prob": 0.0
},
{
"type": "bayesian_normal",
"params": {
"target_db": -20,
"prior_db": -20,
"prior_samples": 100
},
"prob": 0.0
},
{
"type": "specaug",
"params": {
"F": 10,
"T": 50,
"n_freq_masks": 2,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
},
"prob": 0.0
}
]

@ -0,0 +1,10 @@
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
}
]

@ -1,8 +0,0 @@
[
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
}
]

@ -0,0 +1,10 @@
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
}
]

@ -5,7 +5,7 @@ data:
test_manifest: data/manifest.test-clean
mean_std_filepath: data/mean_std.npz
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.config
augmentation_config: conf/augmentation.json
batch_size: 20
max_duration: 27.0
min_duration: 0.0

@ -1,8 +0,0 @@
[
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
}
]

@ -0,0 +1,10 @@
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
}
]

@ -5,7 +5,7 @@ data:
test_manifest: data/manifest.tiny
mean_std_filepath: data/mean_std.npz
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.config
augmentation_config: conf/augmentation.json
batch_size: 4
max_duration: 27.0
min_duration: 0.0

@ -1,8 +0,0 @@
[
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
}
]

@ -0,0 +1,10 @@
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
}
]

@ -7,7 +7,7 @@ data:
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: ""
augmentation_config: conf/augmentation.config
augmentation_config: conf/augmentation.json
batch_size: 4
min_input_len: 0.5
max_input_len: 20.0

Loading…
Cancel
Save