From 9ad688c6aa10e2fff5d4530c3b7f45a221fcee46 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 14 Apr 2021 06:51:23 +0000 Subject: [PATCH] speed perturb with sox --- .notebook/python_test.ipynb | 316 +++++++++++++++++- deepspeech/frontend/audio.py | 15 +- .../frontend/augmentor/speed_perturb.py | 85 +++-- examples/aishell/s0/local/data.sh | 2 +- examples/aishell/s1/conf/augmentation.config | 8 - examples/aishell/s1/local/train.sh | 4 - examples/aug_conf/augmentation.config.example | 3 +- examples/tiny/s1/conf/conformer.yaml | 10 +- requirements.txt | 1 + setup.sh | 2 +- 10 files changed, 398 insertions(+), 48 deletions(-) delete mode 100644 examples/aishell/s1/conf/augmentation.config diff --git a/.notebook/python_test.ipynb b/.notebook/python_test.ipynb index 50d5a8331..af55de5a4 100644 --- a/.notebook/python_test.ipynb +++ b/.notebook/python_test.ipynb @@ -637,7 +637,7 @@ { "cell_type": "code", "execution_count": 59, - "id": "light-drill", + "id": "threaded-grove", "metadata": {}, "outputs": [ { @@ -657,10 +657,322 @@ "get_default_args(io.open)" ] }, + { + "cell_type": "code", + "execution_count": 35, + "id": "equal-vanilla", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: sox in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (1.4.1)\n", + "Requirement already satisfied: numpy>=1.9.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from sox) (1.20.1)\n", + "Requirement already satisfied: librosa in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (0.8.0)\n", + "Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.24.1)\n", + "Requirement already satisfied: numba>=0.43.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.52.0)\n", + "Requirement already satisfied: pooch>=1.0 in 
/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.3.0)\n", + "Requirement already satisfied: scipy>=1.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.2.1)\n", + "Requirement already satisfied: numpy>=1.15.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.20.1)\n", + "Requirement already satisfied: decorator>=3.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (4.4.2)\n", + "Requirement already satisfied: resampy>=0.2.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.2.2)\n", + "Requirement already satisfied: audioread>=2.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (2.1.9)\n", + "Requirement already satisfied: soundfile>=0.9.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.9.0.post1)\n", + "Requirement already satisfied: joblib>=0.14 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.0.1)\n", + "Requirement already satisfied: setuptools in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from numba>=0.43.0->librosa) (51.0.0)\n", + "Requirement already satisfied: llvmlite<0.36,>=0.35.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from numba>=0.43.0->librosa) (0.35.0)\n", + "Requirement already satisfied: appdirs in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (1.4.4)\n", + "Requirement already satisfied: packaging in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (20.9)\n", + "Requirement already satisfied: requests in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (2.25.1)\n", + "Requirement already satisfied: six>=1.3 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from 
resampy>=0.2.2->librosa) (1.15.0)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from scikit-learn!=0.19.0,>=0.14.0->librosa) (2.1.0)\n", + "Requirement already satisfied: cffi>=0.6 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from soundfile>=0.9.0->librosa) (1.14.4)\n", + "Requirement already satisfied: pycparser in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from cffi>=0.6->soundfile>=0.9.0->librosa) (2.20)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from packaging->pooch>=1.0->librosa) (2.4.7)\n", + "Requirement already satisfied: idna<3,>=2.5 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (2.10)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (2020.12.5)\n", + "Requirement already satisfied: chardet<5,>=3.0.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (4.0.0)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (1.26.3)\n" + ] + } + ], + "source": [ + "!pip install sox\n", + "!pip install librosa" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "gorgeous-stanford", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import sox\n", + "tfm = sox.Transformer()\n", + "sample_rate = 44100\n", + "y = np.sin(2 * np.pi * 440.0 * np.arange(sample_rate * 1.0) / sample_rate)\n", + "print(y.dtype.type)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "geological-actor", + "metadata": 
{}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0. 0.06264832 0.12505052 ... -0.18696144 -0.12505052\n", + " -0.06264832]\n" + ] + } + ], + "source": [ + "output_array = tfm.build_array(input_array=y, sample_rate_in=sample_rate)\n", + "print(output_array)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "miniature-ethnic", + "metadata": {}, + "outputs": [], + "source": [ + "tfm.build_array?" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "honest-clarity", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['8svx', 'aif', 'aifc', 'aiff', 'aiffc', 'al', 'amb', 'amr-nb', 'amr-wb', 'anb', 'au', 'avr', 'awb', 'caf', 'cdda', 'cdr', 'cvs', 'cvsd', 'cvu', 'dat', 'dvms', 'f32', 'f4', 'f64', 'f8', 'fap', 'flac', 'fssd', 'gsm', 'gsrt', 'hcom', 'htk', 'ima', 'ircam', 'la', 'lpc', 'lpc10', 'lu', 'mat', 'mat4', 'mat5', 'maud', 'nist', 'ogg', 'paf', 'prc', 'pvf', 'raw', 's1', 's16', 's2', 's24', 's3', 's32', 's4', 's8', 'sb', 'sd2', 'sds', 'sf', 'sl', 'sln', 'smp', 'snd', 'sndfile', 'sndr', 'sndt', 'sou', 'sox', 'sph', 'sw', 'txw', 'u1', 'u16', 'u2', 'u24', 'u3', 'u32', 'u4', 'u8', 'ub', 'ul', 'uw', 'vms', 'voc', 'vorbis', 'vox', 'w64', 'wav', 'wavpcm', 'wv', 'wve', 'xa', 'xi']\n" + ] + } + ], + "source": [ + "print(sox.core._get_valid_formats())" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "environmental-stewart", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "float64\n", + "(59471,)\n", + "16000\n", + "(54065,)\n", + "1.0999907518727459\n" + ] + } + ], + "source": [ + "import soundfile as sf\n", + "wav='/workspace/DeepSpeech-2.x/examples/aishell/s1/../../..//examples/dataset/aishell/data_aishell/wav/dev/S0724/BAC009S0724W0190.wav'\n", + "samples, sr = sf.read(wav)\n", + "print(samples.dtype)\n", + "print(samples.shape)\n", + "print(sr)\n", + "tfm = sox.Transformer()\n", + 
"tfm.speed(1.1)\n", + "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", + "output_array.dtype\n", + "print(output_array.shape)\n", + "print(len(samples)/len(output_array))" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "trying-brazil", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import IPython.display as ipd\n", + "ipd.Audio(wav) # load a local WAV file" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "chronic-interval", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ipd.Audio(output_array, rate=sr) # load a NumPy array" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "widespread-basin", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tfm = sox.Transformer()\n", + "tfm.speed(0.9)\n", + "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", + "ipd.Audio(output_array, rate=sr) # load a NumPy array" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "clinical-lighting", + "metadata": {}, + "outputs": [], + "source": [ + "import librosa\n", + "x, sr = librosa.load(wav, sr=16000)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "federal-supervision", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "float32\n", + "float64\n" + ] + } + ], + "source": [ + "print(x.dtype)\n", + "print(samples.dtype)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 30, + "id": "parallel-trademark", + "metadata": {}, + "outputs": [], + "source": [ + "sf.read?" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "extended-fishing", + "metadata": {}, + "outputs": [], + "source": [ + "librosa.load?" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "baking-auckland", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.allclose(x, samples)" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "protective-belgium", + "id": "changed-storage", "metadata": {}, "outputs": [], "source": [] diff --git a/deepspeech/frontend/audio.py b/deepspeech/frontend/audio.py index 3ed50e766..10a26db2b 100644 --- a/deepspeech/frontend/audio.py +++ b/deepspeech/frontend/audio.py @@ -22,6 +22,7 @@ import resampy from scipy import signal import random import copy +import sox class AudioSegment(object): @@ -323,11 +324,15 @@ class AudioSegment(object): """ if speed_rate <= 0: raise ValueError("speed_rate should be greater than zero.") - old_length = self._samples.shape[0] - new_length = int(old_length / speed_rate) - old_indices = np.arange(old_length) - new_indices = np.linspace(start=0, stop=old_length, num=new_length) - self._samples = np.interp(new_indices, old_indices, self._samples) + # old_length = self._samples.shape[0] + # new_length = int(old_length / speed_rate) + # old_indices = np.arange(old_length) + # new_indices = np.linspace(start=0, stop=old_length, num=new_length) + # self._samples = np.interp(new_indices, old_indices, self._samples) + tfm = sox.Transformer() + tfm.speed(speed_rate) + self._samples = tfm.build_array( + input_array=self._samples, sample_rate_in=self._sample_rate) def normalize(self, target_db=-20, max_gain_db=300.0): """Normalize audio to be of the desired RMS value in decibels. 
diff --git a/deepspeech/frontend/augmentor/speed_perturb.py b/deepspeech/frontend/augmentor/speed_perturb.py index 6518382db..f3fbdd629 100644 --- a/deepspeech/frontend/augmentor/speed_perturb.py +++ b/deepspeech/frontend/augmentor/speed_perturb.py @@ -13,35 +13,71 @@ # limitations under the License. """Contain the speech perturbation augmentation model.""" +import numpy as np from deepspeech.frontend.augmentor.base import AugmentorBase class SpeedPerturbAugmentor(AugmentorBase): - """Augmentation model for adding speed perturbation. - - See reference paper here: - http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf - - :param rng: Random generator object. - :type rng: random.Random - :param min_speed_rate: Lower bound of new speed rate to sample and should - not be smaller than 0.9. - :type min_speed_rate: float - :param max_speed_rate: Upper bound of new speed rate to sample and should - not be larger than 1.1. - :type max_speed_rate: float - """ - - def __init__(self, rng, min_speed_rate, max_speed_rate): + """Augmentation model for adding speed perturbation.""" + + def __init__(self, rng, min_speed_rate=0.9, max_speed_rate=1.1, + num_rates=3): + """Speed perturbation. + + The speed perturbation in kaldi uses sox-speed instead of sox-tempo, + and sox-speed just resamples the input, + i.e. both pitch and tempo are changed. + + "Why use speed option instead of tempo -s in SoX for speed perturbation" + https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8 + + Sox speed: + https://pysox.readthedocs.io/en/latest/api.html#sox.transform.Transformer + + See reference paper here: + http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf + + Espnet: + https://espnet.github.io/espnet/_modules/espnet/transform/perturb.html + + Nemo: + https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/perturb.py#L92 + + Args: + rng (random.Random): Random generator object. 
+ min_speed_rate (float): Lower bound of new speed rate to sample and should + not be smaller than 0.9. + max_speed_rate (float): Upper bound of new speed rate to sample and should + not be larger than 1.1. + num_rates (int, optional): Number of discrete rates to allow. + Can be a positive or negative integer. Defaults to 3. + If a positive integer greater than 0 is provided, the range of + speed rates will be discretized into `num_rates` values. + If a negative integer or 0 is provided, the full range of speed rates + will be sampled uniformly. + Note: If a positive integer is provided and the resultant discretized + range of rates contains the value '1.0', then samples with rate=1.0 + will not be augmented at all and are simply skipped. This is to avoid unnecessary + augmentation and increased computation time. Effective augmentation chance + in such a case is `prob * ((num_rates - 1) / num_rates) * 100`%, + where `prob` is the global probability of a sample being augmented. + + Raises: + ValueError: if min_speed_rate < 0.9 or max_speed_rate > 1.1. + """ if min_speed_rate < 0.9: raise ValueError( "Sampling speed below 0.9 can cause unnatural effects") if max_speed_rate > 1.1: raise ValueError( "Sampling speed above 1.1 can cause unnatural effects") - self._min_speed_rate = min_speed_rate - self._max_speed_rate = max_speed_rate + self._min_rate = min_speed_rate + self._max_rate = max_speed_rate self._rng = rng + self._num_rates = num_rates + if num_rates > 0: + self._rates = np.linspace( + self._min_rate, self._max_rate, self._num_rates, endpoint=True) def transform_audio(self, audio_segment): """Sample a new speed rate from the given range and @@ -52,6 +88,13 @@ class SpeedPerturbAugmentor(AugmentorBase): :param audio_segment: Audio segment to add effects to. 
:type audio_segment: AudioSegment|SpeechSegment """ - sampled_speed = self._rng.uniform(self._min_speed_rate, - self._max_speed_rate) - audio_segment.change_speed(sampled_speed) + if self._num_rates < 0: + speed_rate = self._rng.uniform(self._min_rate, self._max_rate) + else: + speed_rate = self._rng.choice(self._rates) + + # Skip perturbation in case of identity speed rate + if speed_rate == 1.0: + return + + audio_segment.change_speed(speed_rate) diff --git a/examples/aishell/s0/local/data.sh b/examples/aishell/s0/local/data.sh index 85acb23ea..fb2700083 100644 --- a/examples/aishell/s0/local/data.sh +++ b/examples/aishell/s0/local/data.sh @@ -25,7 +25,7 @@ python3 ${MAIN_ROOT}/utils/build_vocab.py \ --unit_type="char" \ --count_threshold=0 \ --vocab_path="data/vocab.txt" \ ---manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw" +--manifest_paths "data/manifest.train.raw" if [ $? -ne 0 ]; then echo "Build vocabulary failed. Terminated." diff --git a/examples/aishell/s1/conf/augmentation.config b/examples/aishell/s1/conf/augmentation.config deleted file mode 100644 index 6c24da549..000000000 --- a/examples/aishell/s1/conf/augmentation.config +++ /dev/null @@ -1,8 +0,0 @@ -[ - { - "type": "shift", - "params": {"min_shift_ms": -5, - "max_shift_ms": 5}, - "prob": 1.0 - } -] diff --git a/examples/aishell/s1/local/train.sh b/examples/aishell/s1/local/train.sh index c286566a8..8ed5010ee 100644 --- a/examples/aishell/s1/local/train.sh +++ b/examples/aishell/s1/local/train.sh @@ -1,9 +1,5 @@ #! /usr/bin/env bash -# train model -# if you wish to resume from an exists model, uncomment --init_from_pretrained_model -export FLAGS_sync_nccl_allreduce=0 - ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));') echo "using $ngpu gpus..." 
diff --git a/examples/aug_conf/augmentation.config.example b/examples/aug_conf/augmentation.config.example index 21ed6ee10..2902125ab 100644 --- a/examples/aug_conf/augmentation.config.example +++ b/examples/aug_conf/augmentation.config.example @@ -14,7 +14,8 @@ { "type": "speed", "params": {"min_speed_rate": 0.95, - "max_speed_rate": 1.05}, + "max_speed_rate": 1.05, + "num_rates": 3}, "prob": 0.5 }, { diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index 7d4303660..e4c6f33c1 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -9,12 +9,12 @@ data: mean_std_filepath: "" augmentation_config: conf/augmentation.config batch_size: 4 - max_input_len: 27.0 - min_input_len: 0.0 - max_output_len: .INF + min_input_len: 0.5 + max_input_len: 20.0 min_output_len: 0.0 - max_output_input_ratio: .INF - min_output_input_ratio: 0.0 + max_output_len: 400 + min_output_input_ratio: 0.05 + max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 diff --git a/requirements.txt b/requirements.txt index 40b4cef38..2d021a4b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ resampy==0.2.2 scipy==1.2.1 sentencepiece SoundFile==0.9.0.post1 +sox tensorboardX typeguard yacs diff --git a/setup.sh b/setup.sh index c681583b8..a58bd7967 100644 --- a/setup.sh +++ b/setup.sh @@ -7,7 +7,7 @@ fi if [ -e /etc/lsb-release ];then #${SUDO} apt-get update - ${SUDO} apt-get install -y pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev + ${SUDO} apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev fi # install python dependencies