From 718407b77d33fabb2a91ff44fc6488cfbde121fb Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Fri, 20 Aug 2021 02:54:21 +0000 Subject: [PATCH 01/17] add seed --- deepspeech/exps/deepspeech2/bin/train.py | 4 +++ deepspeech/exps/deepspeech2/model.py | 10 +++++++ deepspeech/exps/u2/bin/train.py | 3 +++ deepspeech/exps/u2/model.py | 8 +++++- deepspeech/models/ds2_online/deepspeech2.py | 30 +++++++++++---------- tests/deepspeech2_online_model_test.py | 2 +- 6 files changed, 41 insertions(+), 16 deletions(-) diff --git a/deepspeech/exps/deepspeech2/bin/train.py b/deepspeech/exps/deepspeech2/bin/train.py index 69ff043a..32127022 100644 --- a/deepspeech/exps/deepspeech2/bin/train.py +++ b/deepspeech/exps/deepspeech2/bin/train.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Trainer for DeepSpeech2 model.""" +import os + from paddle import distributed as dist from deepspeech.exps.deepspeech2.config import get_cfg_defaults @@ -53,5 +55,7 @@ if __name__ == "__main__": if args.dump_config: with open(args.dump_config, 'w') as f: print(config, file=f) + if config.training.seed != None: + os.environ.setdefault('FLAGS_cudnn_deterministic', 'True') main(config, args) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 65c905a1..a2bbee5e 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Contains DeepSpeech2 and DeepSpeech2Online model.""" +import os +import random import time from collections import defaultdict from pathlib import Path @@ -53,6 +55,7 @@ class DeepSpeech2Trainer(Trainer): weight_decay=1e-6, # the coeff of weight decay global_grad_clip=5.0, # the global norm clip n_epoch=50, # train epochs + seed=1024, #train seed )) if config is not None: @@ -61,6 +64,13 @@ class DeepSpeech2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) + if config.training.seed != None: + self.set_seed(config.training.seed) + + def set_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + paddle.seed(seed) def train_batch(self, batch_index, batch_data, msg): start = time.time() diff --git a/deepspeech/exps/u2/bin/train.py b/deepspeech/exps/u2/bin/train.py index 9dd0041d..ebd91faa 100644 --- a/deepspeech/exps/u2/bin/train.py +++ b/deepspeech/exps/u2/bin/train.py @@ -52,7 +52,10 @@ if __name__ == "__main__": if args.dump_config: with open(args.dump_config, 'w') as f: print(config, file=f) + if config.training.seed != None: + os.environ.setdefault('FLAGS_cudnn_deterministic', 'True') + main(config, args) # Setting for profiling pr = cProfile.Profile() pr.runcall(main, config, args) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index d661f078..b248c5a6 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -55,7 +55,7 @@ class U2Trainer(Trainer): log_interval=100, # steps accum_grad=1, # accum grad by # steps global_grad_clip=5.0, # the global norm clip - )) + seed=1024, )) default.optim = 'adam' default.optim_conf = CfgNode( dict( @@ -75,6 +75,12 @@ class U2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) + if config.training.seed != None: + self.set_seed(config.training.seed) + + def set_seed(self, seed): + np.random.seed(seed) + paddle.seed(seed) def train_batch(self, batch_index, batch_data, msg): train_conf = self.config.training diff 
--git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index 3083e4b2..e130968b 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -102,13 +102,13 @@ class CRNNEncoder(nn.Layer): Args: x (Tensor): [B, feature_size, D] x_lens (Tensor): [B] - init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size - init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size - Returns: + init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + Return: x (Tensor): encoder outputs, [B, size, D] x_lens (Tensor): encoder length, [B] - final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size - final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] """ if init_state_h_box is not None: init_state_list = None @@ -142,7 +142,7 @@ class CRNNEncoder(nn.Layer): if self.use_gru == True: final_chunk_state_h_box = paddle.concat( final_chunk_state_list, axis=0) - final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box) + final_chunk_state_c_box = init_state_c_box else: final_chunk_state_h_list = [ final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) @@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer): x_lens (Tensor): [B] decoder_chunk_size: The chunk size of decoder Returns: - eouts_list (List of Tensor): The list of encoder 
outputs in chunk_size, [B, chunk_size, D] * num_chunks - eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks - final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size - final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size + eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks + eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks + final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] """ subsampling_rate = self.conv.subsampling_rate receptive_field_length = self.conv.receptive_field_length @@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer): class DeepSpeech2ModelOnline(nn.Layer): """The DeepSpeech2 network structure for online. - :param audio_data: Audio spectrogram data layer. - :type audio_data: Variable - :param text_data: Transcription text data layer. - :type text_data: Variable + :param audio: Audio spectrogram data layer. + :type audio: Variable + :param text: Transcription text data layer. + :type text: Variable :param audio_len: Valid sequence length data layer. :type audio_len: Variable + :param feat_size: feature size for audio. + :type feat_size: int :param dict_size: Dictionary size for tokenized transcription. :type dict_size: int :param num_conv_layers: Number of stacking convolution layers. 
diff --git a/tests/deepspeech2_online_model_test.py b/tests/deepspeech2_online_model_test.py index 87f04887..0e31b85f 100644 --- a/tests/deepspeech2_online_model_test.py +++ b/tests/deepspeech2_online_model_test.py @@ -143,7 +143,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list) decode_max_len = eouts.shape[1] eouts_by_chk = eouts_by_chk[:, :decode_max_len, :] - self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) + self.assertEqual(paddle.allclose(eouts_by_chk, eouts, atol=1e-5), True) self.assertEqual( paddle.allclose(final_state_h_box, final_state_h_box_chk), True) if use_gru == False: From d065824bd3c5e9f7a397e4a0be990509403c7ed8 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Fri, 20 Aug 2021 03:22:52 +0000 Subject: [PATCH 02/17] fix the bug of 'import path error' for ds2 --- .notebook/jit_infer.ipynb | 6 +++--- deepspeech/exps/deepspeech2/bin/deploy/runtime.py | 2 +- deepspeech/exps/deepspeech2/bin/deploy/server.py | 2 +- deepspeech/exps/deepspeech2/bin/tune.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.notebook/jit_infer.ipynb b/.notebook/jit_infer.ipynb index ba50d874..20882c1a 100644 --- a/.notebook/jit_infer.ipynb +++ b/.notebook/jit_infer.ipynb @@ -83,8 +83,8 @@ "from deepspeech.frontend.utility import read_manifest\n", "from deepspeech.utils.utility import add_arguments, print_arguments\n", "\n", - "from deepspeech.models.deepspeech2 import DeepSpeech2Model\n", - "from deepspeech.models.deepspeech2 import DeepSpeech2InferModel\n", + "from deepspeech.models.ds2 import DeepSpeech2Model\n", + "from deepspeech.models.ds2 import DeepSpeech2InferModel\n", "from deepspeech.io.dataset import ManifestDataset\n", "\n", "\n", @@ -669,4 +669,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py index 01f01b65..21ffa6bf 100644 --- 
a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py @@ -23,7 +23,7 @@ from paddle.io import DataLoader from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset -from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.training.cli import default_argument_parser from deepspeech.utils.socket_server import AsrRequestHandler from deepspeech.utils.socket_server import AsrTCPServer diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py index b473a8fd..583e9095 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/server.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py @@ -21,7 +21,7 @@ from paddle.io import DataLoader from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset -from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.training.cli import default_argument_parser from deepspeech.utils.socket_server import AsrRequestHandler from deepspeech.utils.socket_server import AsrTCPServer diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index f10dc27c..94a9b6c4 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -21,7 +21,7 @@ from paddle.io import DataLoader from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset -from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.training.cli import default_argument_parser from deepspeech.utils 
import error_rate from deepspeech.utils.utility import add_arguments From 50f10f37ae8224a5d143ecfc59e31af1d992e695 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 20 Aug 2021 03:28:55 +0000 Subject: [PATCH 03/17] support replace with mean by aug --- deepspeech/__init__.py | 41 +------------------ .../frontend/augmentor/impulse_response.py | 2 +- .../frontend/augmentor/noise_perturb.py | 2 +- .../online_bayesian_normalization.py | 2 +- deepspeech/frontend/augmentor/resample.py | 2 +- .../frontend/augmentor/shift_perturb.py | 2 +- deepspeech/frontend/augmentor/spec_augment.py | 21 +++++++--- .../frontend/augmentor/speed_perturb.py | 2 +- .../frontend/augmentor/volume_perturb.py | 2 +- examples/aishell/s0/conf/augmentation.json | 3 +- examples/aishell/s1/conf/augmentation.json | 3 +- examples/aug_conf/augmentation.json | 10 ----- .../augmentation.json} | 3 +- examples/callcenter/s1/conf/augmentation.json | 3 +- .../librispeech/s0/conf/augmentation.json | 3 +- .../librispeech/s1/conf/augmentation.json | 3 +- .../librispeech/s2/conf/augmentation.json | 3 +- examples/timit/s1/conf/augmentation.json | 3 +- examples/tiny/s0/conf/augmentation.json | 25 +++++++++++ examples/tiny/s1/conf/augmentation.json | 3 +- 20 files changed, 66 insertions(+), 72 deletions(-) delete mode 100644 examples/aug_conf/augmentation.json rename examples/{aug_conf/augmentation.example.json => augmentation/augmentation.json} (94%) diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index 88f81075..fbec5a5e 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -352,45 +352,6 @@ if not hasattr(paddle.Tensor, 'tolist'): "register user tolist to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'tolist', tolist) -########### hcak paddle.nn.functional ############# - - -def glu(x: paddle.Tensor, axis=-1) -> paddle.Tensor: - """The gated linear unit (GLU) activation.""" - a, b = x.split(2, axis=axis) - act_b = F.sigmoid(b) - return a * act_b - - -if not 
hasattr(paddle.nn.functional, 'glu'): - logger.warn( - "register user glu to paddle.nn.functional, remove this when fixed!") - setattr(paddle.nn.functional, 'glu', glu) - -# def softplus(x): -# """Softplus function.""" -# if hasattr(paddle.nn.functional, 'softplus'): -# #return paddle.nn.functional.softplus(x.float()).type_as(x) -# return paddle.nn.functional.softplus(x) -# else: -# raise NotImplementedError - -# def gelu_accurate(x): -# """Gaussian Error Linear Units (GELU) activation.""" -# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py -# if not hasattr(gelu_accurate, "_a"): -# gelu_accurate._a = math.sqrt(2 / math.pi) -# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a * -# (x + 0.044715 * paddle.pow(x, 3)))) - -# def gelu(x): -# """Gaussian Error Linear Units (GELU) activation.""" -# if hasattr(nn.functional, 'gelu'): -# #return nn.functional.gelu(x.float()).type_as(x) -# return nn.functional.gelu(x) -# else: -# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0))) - ########### hcak paddle.nn ############# class GLU(nn.Layer): @@ -401,7 +362,7 @@ class GLU(nn.Layer): self.dim = dim def forward(self, xs): - return glu(xs, dim=self.dim) + return F.glu(xs, dim=self.dim) if not hasattr(paddle.nn, 'GLU'): diff --git a/deepspeech/frontend/augmentor/impulse_response.py b/deepspeech/frontend/augmentor/impulse_response.py index b1a732ad..818251ed 100644 --- a/deepspeech/frontend/augmentor/impulse_response.py +++ b/deepspeech/frontend/augmentor/impulse_response.py @@ -32,7 +32,7 @@ class ImpulseResponseAugmentor(AugmentorBase): def __call__(self, x, uttid=None, train=True): if not train: - return + return x self.transform_audio(x) return x diff --git a/deepspeech/frontend/augmentor/noise_perturb.py b/deepspeech/frontend/augmentor/noise_perturb.py index 8be5931b..790b0c39 100644 --- a/deepspeech/frontend/augmentor/noise_perturb.py +++ b/deepspeech/frontend/augmentor/noise_perturb.py @@ -38,7 
+38,7 @@ class NoisePerturbAugmentor(AugmentorBase): def __call__(self, x, uttid=None, train=True): if not train: - return + return x self.transform_audio(x) return x diff --git a/deepspeech/frontend/augmentor/online_bayesian_normalization.py b/deepspeech/frontend/augmentor/online_bayesian_normalization.py index 4b5e2301..0f9d3ef6 100644 --- a/deepspeech/frontend/augmentor/online_bayesian_normalization.py +++ b/deepspeech/frontend/augmentor/online_bayesian_normalization.py @@ -46,7 +46,7 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase): def __call__(self, x, uttid=None, train=True): if not train: - return + return x self.transform_audio(x) return x diff --git a/deepspeech/frontend/augmentor/resample.py b/deepspeech/frontend/augmentor/resample.py index a8c0c662..509fe003 100644 --- a/deepspeech/frontend/augmentor/resample.py +++ b/deepspeech/frontend/augmentor/resample.py @@ -33,7 +33,7 @@ class ResampleAugmentor(AugmentorBase): def __call__(self, x, uttid=None, train=True): if not train: - return + return x self.transform_audio(x) return x diff --git a/deepspeech/frontend/augmentor/shift_perturb.py b/deepspeech/frontend/augmentor/shift_perturb.py index a76fb51c..8b7439fe 100644 --- a/deepspeech/frontend/augmentor/shift_perturb.py +++ b/deepspeech/frontend/augmentor/shift_perturb.py @@ -33,7 +33,7 @@ class ShiftPerturbAugmentor(AugmentorBase): def __call__(self, x, uttid=None, train=True): if not train: - return + return x self.transform_audio(x) return x diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py index bfa8300a..67b6cfdd 100644 --- a/deepspeech/frontend/augmentor/spec_augment.py +++ b/deepspeech/frontend/augmentor/spec_augment.py @@ -41,7 +41,8 @@ class SpecAugmentor(AugmentorBase): W=40, adaptive_number_ratio=0, adaptive_size_ratio=0, - max_n_time_masks=20): + max_n_time_masks=20, + replace_with_zero=True): """SpecAugment class. Args: rng (random.Random): random generator object. 
@@ -54,9 +55,11 @@ class SpecAugmentor(AugmentorBase): adaptive_number_ratio (float): adaptive multiplicity ratio for time masking adaptive_size_ratio (float): adaptive size ratio for time masking max_n_time_masks (int): maximum number of time masking + replace_with_zero (bool): pad zero on mask if true else use mean """ super().__init__() self._rng = rng + self.replace_with_zero = replace_with_zero self.W = W self.F = F @@ -124,15 +127,18 @@ class SpecAugmentor(AugmentorBase): return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}" def time_warp(xs, W=40): - raise NotImplementedError + return xs def mask_freq(self, xs, replace_with_zero=False): n_bins = xs.shape[0] for i in range(0, self.n_freq_masks): f = int(self._rng.uniform(low=0, high=self.F)) f_0 = int(self._rng.uniform(low=0, high=n_bins - f)) - xs[f_0:f_0 + f, :] = 0 assert f_0 <= f_0 + f + if self.replace_with_zero: + xs[f_0:f_0 + f, :] = 0 + else: + xs[f_0:f_0 + f, :] = xs.mean() self._freq_mask = (f_0, f_0 + f) return xs @@ -154,14 +160,17 @@ class SpecAugmentor(AugmentorBase): t = int(self._rng.uniform(low=0, high=T)) t = min(t, int(n_frames * self.p)) t_0 = int(self._rng.uniform(low=0, high=n_frames - t)) - xs[:, t_0:t_0 + t] = 0 assert t_0 <= t_0 + t + if self.replace_with_zero: + xs[:, t_0:t_0 + t] = 0 + else: + xs[:, t_0:t_0 + t] = xs.mean() self._time_mask = (t_0, t_0 + t) return xs def __call__(self, x, train=True): if not train: - return + return x return self.transform_feature(x) def transform_feature(self, xs: np.ndarray): @@ -171,7 +180,7 @@ class SpecAugmentor(AugmentorBase): Returns: xs (FloatTensor): `[F, T]` """ - # xs = self.time_warp(xs) + xs = self.time_warp(xs) xs = self.mask_freq(xs) xs = self.mask_time(xs) return xs diff --git a/deepspeech/frontend/augmentor/speed_perturb.py b/deepspeech/frontend/augmentor/speed_perturb.py index eec2e551..ce8dfde0 100644 --- a/deepspeech/frontend/augmentor/speed_perturb.py +++ b/deepspeech/frontend/augmentor/speed_perturb.py @@ -81,7 
+81,7 @@ class SpeedPerturbAugmentor(AugmentorBase): def __call__(self, x, uttid=None, train=True): if not train: - return + return x self.transform_audio(x) return x diff --git a/deepspeech/frontend/augmentor/volume_perturb.py b/deepspeech/frontend/augmentor/volume_perturb.py index d08f75c3..70cb2889 100644 --- a/deepspeech/frontend/augmentor/volume_perturb.py +++ b/deepspeech/frontend/augmentor/volume_perturb.py @@ -39,7 +39,7 @@ class VolumePerturbAugmentor(AugmentorBase): def __call__(self, x, uttid=None, train=True): if not train: - return + return x self.transform_audio(x) return x diff --git a/examples/aishell/s0/conf/augmentation.json b/examples/aishell/s0/conf/augmentation.json index 1987ad42..81d110b0 100644 --- a/examples/aishell/s0/conf/augmentation.json +++ b/examples/aishell/s0/conf/augmentation.json @@ -27,7 +27,8 @@ "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, - "max_n_time_masks": 20 + "max_n_time_masks": 20, + "replace_with_zero": true }, "prob": 1.0 } diff --git a/examples/aishell/s1/conf/augmentation.json b/examples/aishell/s1/conf/augmentation.json index 1987ad42..81d110b0 100644 --- a/examples/aishell/s1/conf/augmentation.json +++ b/examples/aishell/s1/conf/augmentation.json @@ -27,7 +27,8 @@ "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, - "max_n_time_masks": 20 + "max_n_time_masks": 20, + "replace_with_zero": true }, "prob": 1.0 } diff --git a/examples/aug_conf/augmentation.json b/examples/aug_conf/augmentation.json deleted file mode 100644 index a1a759e6..00000000 --- a/examples/aug_conf/augmentation.json +++ /dev/null @@ -1,10 +0,0 @@ -[ - { - "type": "shift", - "params": { - "min_shift_ms": -5, - "max_shift_ms": 5 - }, - "prob": 1.0 - } -] diff --git a/examples/aug_conf/augmentation.example.json b/examples/augmentation/augmentation.json similarity index 94% rename from examples/aug_conf/augmentation.example.json rename to examples/augmentation/augmentation.json index efae2e5e..baf2cac3 100644 --- 
a/examples/aug_conf/augmentation.example.json +++ b/examples/augmentation/augmentation.json @@ -60,7 +60,8 @@ "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, - "max_n_time_masks": 20 + "max_n_time_masks": 20, + "replace_with_zero": true }, "prob": 0.0 } diff --git a/examples/callcenter/s1/conf/augmentation.json b/examples/callcenter/s1/conf/augmentation.json index 1987ad42..81d110b0 100644 --- a/examples/callcenter/s1/conf/augmentation.json +++ b/examples/callcenter/s1/conf/augmentation.json @@ -27,7 +27,8 @@ "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, - "max_n_time_masks": 20 + "max_n_time_masks": 20, + "replace_with_zero": true }, "prob": 1.0 } diff --git a/examples/librispeech/s0/conf/augmentation.json b/examples/librispeech/s0/conf/augmentation.json index 1987ad42..81d110b0 100644 --- a/examples/librispeech/s0/conf/augmentation.json +++ b/examples/librispeech/s0/conf/augmentation.json @@ -27,7 +27,8 @@ "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, - "max_n_time_masks": 20 + "max_n_time_masks": 20, + "replace_with_zero": true }, "prob": 1.0 } diff --git a/examples/librispeech/s1/conf/augmentation.json b/examples/librispeech/s1/conf/augmentation.json index c1078393..7dd158eb 100644 --- a/examples/librispeech/s1/conf/augmentation.json +++ b/examples/librispeech/s1/conf/augmentation.json @@ -27,7 +27,8 @@ "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, - "max_n_time_masks": 20 + "max_n_time_masks": 20, + "replace_with_zero": true }, "prob": 1.0 } diff --git a/examples/librispeech/s2/conf/augmentation.json b/examples/librispeech/s2/conf/augmentation.json index 49fe333e..cc8c7e00 100644 --- a/examples/librispeech/s2/conf/augmentation.json +++ b/examples/librispeech/s2/conf/augmentation.json @@ -10,7 +10,8 @@ "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, - "max_n_time_masks": 20 + "max_n_time_masks": 20, + "replace_with_zero": true }, "prob": 1.0 } diff --git 
a/examples/timit/s1/conf/augmentation.json b/examples/timit/s1/conf/augmentation.json index c1078393..7dd158eb 100644 --- a/examples/timit/s1/conf/augmentation.json +++ b/examples/timit/s1/conf/augmentation.json @@ -27,7 +27,8 @@ "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, - "max_n_time_masks": 20 + "max_n_time_masks": 20, + "replace_with_zero": true }, "prob": 1.0 } diff --git a/examples/tiny/s0/conf/augmentation.json b/examples/tiny/s0/conf/augmentation.json index a1a759e6..8f9ff7fd 100644 --- a/examples/tiny/s0/conf/augmentation.json +++ b/examples/tiny/s0/conf/augmentation.json @@ -1,4 +1,13 @@ [ + { + "type": "speed", + "params": { + "min_speed_rate": 0.9, + "max_speed_rate": 1.1, + "num_rates": 3 + }, + "prob": 1.0 + }, { "type": "shift", "params": { @@ -6,5 +15,21 @@ "max_shift_ms": 5 }, "prob": 1.0 + }, + { + "type": "specaug", + "params": { + "F": 10, + "T": 50, + "n_freq_masks": 2, + "n_time_masks": 2, + "p": 1.0, + "W": 80, + "adaptive_number_ratio": 0, + "adaptive_size_ratio": 0, + "max_n_time_masks": 20, + "replace_with_zero": true + }, + "prob": 1.0 } ] diff --git a/examples/tiny/s1/conf/augmentation.json b/examples/tiny/s1/conf/augmentation.json index f26c282e..8f9ff7fd 100644 --- a/examples/tiny/s1/conf/augmentation.json +++ b/examples/tiny/s1/conf/augmentation.json @@ -27,7 +27,8 @@ "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, - "max_n_time_masks": 20 + "max_n_time_masks": 20, + "replace_with_zero": true }, "prob": 1.0 } From 782f6be42d0bed2440695a34681beb1d112c733d Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 20 Aug 2021 09:08:42 +0000 Subject: [PATCH 04/17] (D,T) to (T, D); time warp --- deepspeech/frontend/augmentor/spec_augment.py | 109 ++++++++++++++---- .../frontend/featurizer/audio_featurizer.py | 84 ++++++++------ deepspeech/frontend/normalizer.py | 14 +-- deepspeech/io/collator.py | 3 +- deepspeech/io/collator_st.py | 104 ++++++----------- requirements.txt | 1 + 6 files changed, 178 
insertions(+), 137 deletions(-) diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py index 67b6cfdd..a3f4e268 100644 --- a/deepspeech/frontend/augmentor/spec_augment.py +++ b/deepspeech/frontend/augmentor/spec_augment.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """Contains the volume perturb augmentation model.""" +import random + import numpy as np +from PIL import Image +from PIL.Image import BICUBIC from deepspeech.frontend.augmentor.base import AugmentorBase from deepspeech.utils.log import Log @@ -42,7 +46,8 @@ class SpecAugmentor(AugmentorBase): adaptive_number_ratio=0, adaptive_size_ratio=0, max_n_time_masks=20, - replace_with_zero=True): + replace_with_zero=True, + warp_mode='PIL'): """SpecAugment class. Args: rng (random.Random): random generator object. @@ -56,11 +61,15 @@ class SpecAugmentor(AugmentorBase): adaptive_size_ratio (float): adaptive size ratio for time masking max_n_time_masks (int): maximum number of time masking replace_with_zero (bool): pad zero on mask if true else use mean + warp_mode (str): "PIL" (default, fast, not differentiable) + or "sparse_image_warp" (slow, differentiable) """ super().__init__() self._rng = rng + self.inplace = True self.replace_with_zero = replace_with_zero + self.mode = warp_mode self.W = W self.F = F self.T = T @@ -126,24 +135,80 @@ class SpecAugmentor(AugmentorBase): def __repr__(self): return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}" - def time_warp(xs, W=40): - return xs + def time_warp(self, x, mode='PIL'): + """time warp for spec augment + move random center frame by the random width ~ uniform(-window, window) + + Args: + x (np.ndarray): spectrogram (time, freq) + mode (str): PIL or sparse_image_warp + + Raises: + NotImplementedError: [description] + NotImplementedError: [description] + + Returns: + np.ndarray: time warped spectrogram (time, freq) + 
""" + window = max_time_warp = self.W + if mode == "PIL": + t = x.shape[0] + if t - window <= window: + return x + # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1 + center = random.randrange(window, t - window) + warped = random.randrange(center - window, center + + window) + 1 # 1 ... t - 1 + + left = Image.fromarray(x[:center]).resize((x.shape[1], warped), + BICUBIC) + right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), + BICUBIC) + if self.inplace: + x[:warped] = left + x[warped:] = right + return x + return np.concatenate((left, right), 0) + elif mode == "sparse_image_warp": + raise NotImplementedError('sparse_image_warp') + else: + raise NotImplementedError( + "unknown resize mode: " + mode + + ", choose one from (PIL, sparse_image_warp).") + + def mask_freq(self, x, replace_with_zero=False): + """freq mask - def mask_freq(self, xs, replace_with_zero=False): - n_bins = xs.shape[0] + Args: + x (np.ndarray): spectrogram (time, freq) + replace_with_zero (bool, optional): Defaults to False. + + Returns: + np.ndarray: freq mask spectrogram (time, freq) + """ + n_bins = x.shape[1] for i in range(0, self.n_freq_masks): f = int(self._rng.uniform(low=0, high=self.F)) f_0 = int(self._rng.uniform(low=0, high=n_bins - f)) assert f_0 <= f_0 + f - if self.replace_with_zero: - xs[f_0:f_0 + f, :] = 0 + if replace_with_zero: + x[:, f_0:f_0 + f] = 0 else: - xs[f_0:f_0 + f, :] = xs.mean() + x[:, f_0:f_0 + f] = x.mean() self._freq_mask = (f_0, f_0 + f) - return xs + return x + + def mask_time(self, x, replace_with_zero=False): + """time mask - def mask_time(self, xs, replace_with_zero=False): - n_frames = xs.shape[1] + Args: + x (np.ndarray): spectrogram (time, freq) + replace_with_zero (bool, optional): Defaults to False. 
+ + Returns: + np.ndarray: time mask spectrogram (time, freq) + """ + n_frames = x.shape[0] if self.adaptive_number_ratio > 0: n_masks = int(n_frames * self.adaptive_number_ratio) @@ -161,26 +226,26 @@ class SpecAugmentor(AugmentorBase): t = min(t, int(n_frames * self.p)) t_0 = int(self._rng.uniform(low=0, high=n_frames - t)) assert t_0 <= t_0 + t - if self.replace_with_zero: - xs[:, t_0:t_0 + t] = 0 + if replace_with_zero: + x[t_0:t_0 + t, :] = 0 else: - xs[:, t_0:t_0 + t] = xs.mean() + x[t_0:t_0 + t, :] = x.mean() self._time_mask = (t_0, t_0 + t) - return xs + return x def __call__(self, x, train=True): if not train: return x return self.transform_feature(x) - def transform_feature(self, xs: np.ndarray): + def transform_feature(self, x: np.ndarray): """ Args: - xs (FloatTensor): `[F, T]` + x (np.ndarray): `[T, F]` Returns: - xs (FloatTensor): `[F, T]` + x (np.ndarray): `[T, F]` """ - xs = self.time_warp(xs) - xs = self.mask_freq(xs) - xs = self.mask_time(xs) - return xs + x = self.time_warp(x, self.mode) + x = self.mask_freq(x, self.replace_with_zero) + x = self.mask_time(x, self.replace_with_zero) + return x diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py b/deepspeech/frontend/featurizer/audio_featurizer.py index 11c1fa2d..7e9efa36 100644 --- a/deepspeech/frontend/featurizer/audio_featurizer.py +++ b/deepspeech/frontend/featurizer/audio_featurizer.py @@ -167,32 +167,6 @@ class AudioFeaturizer(object): raise ValueError("Unknown specgram_type %s. " "Supported values: linear." 
% self._specgram_type) - def _compute_linear_specgram(self, - samples, - sample_rate, - stride_ms=10.0, - window_ms=20.0, - max_freq=None, - eps=1e-14): - """Compute the linear spectrogram from FFT energy.""" - if max_freq is None: - max_freq = sample_rate / 2 - if max_freq > sample_rate / 2: - raise ValueError("max_freq must not be greater than half of " - "sample rate.") - if stride_ms > window_ms: - raise ValueError("Stride size must not be greater than " - "window size.") - stride_size = int(0.001 * sample_rate * stride_ms) - window_size = int(0.001 * sample_rate * window_ms) - specgram, freqs = self._specgram_real( - samples, - window_size=window_size, - stride_size=stride_size, - sample_rate=sample_rate) - ind = np.where(freqs <= max_freq)[0][-1] + 1 - return np.log(specgram[:ind, :] + eps) - def _specgram_real(self, samples, window_size, stride_size, sample_rate): """Compute the spectrogram for samples from a real signal.""" # extract strided windows @@ -217,26 +191,65 @@ class AudioFeaturizer(object): freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) return fft, freqs + def _compute_linear_specgram(self, + samples, + sample_rate, + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + eps=1e-14): + """Compute the linear spectrogram from FFT energy. + + Args: + samples ([type]): [description] + sample_rate ([type]): [description] + stride_ms (float, optional): [description]. Defaults to 10.0. + window_ms (float, optional): [description]. Defaults to 20.0. + max_freq ([type], optional): [description]. Defaults to None. + eps ([type], optional): [description]. Defaults to 1e-14. 
+ + Raises: + ValueError: [description] + ValueError: [description] + + Returns: + np.ndarray: log spectrogram, (time, freq) + """ + if max_freq is None: + max_freq = sample_rate / 2 + if max_freq > sample_rate / 2: + raise ValueError("max_freq must not be greater than half of " + "sample rate.") + if stride_ms > window_ms: + raise ValueError("Stride size must not be greater than " + "window size.") + stride_size = int(0.001 * sample_rate * stride_ms) + window_size = int(0.001 * sample_rate * window_ms) + specgram, freqs = self._specgram_real( + samples, + window_size=window_size, + stride_size=stride_size, + sample_rate=sample_rate) + ind = np.where(freqs <= max_freq)[0][-1] + 1 + # (freq, time) + spec = np.log(specgram[:ind, :] + eps) + return np.transpose(spec) + def _concat_delta_delta(self, feat): """append delat, delta-delta feature. Args: - feat (np.ndarray): (D, T) + feat (np.ndarray): (T, D) Returns: - np.ndarray: feat with delta-delta, (3*D, T) + np.ndarray: feat with delta-delta, (T, 3*D) """ - feat = np.transpose(feat) # Deltas d_feat = delta(feat, 2) # Deltas-Deltas dd_feat = delta(feat, 2) - # transpose - feat = np.transpose(feat) - d_feat = np.transpose(d_feat) - dd_feat = np.transpose(dd_feat) # concat above three features - concat_feat = np.concatenate((feat, d_feat, dd_feat)) + concat_feat = np.concatenate((feat, d_feat, dd_feat), axis=1) return concat_feat def _compute_mfcc(self, @@ -292,7 +305,6 @@ class AudioFeaturizer(object): ceplifter=22, useEnergy=True, winfunc='povey') - mfcc_feat = np.transpose(mfcc_feat) if delta_delta: mfcc_feat = self._concat_delta_delta(mfcc_feat) return mfcc_feat @@ -346,8 +358,6 @@ class AudioFeaturizer(object): remove_dc_offset=True, preemph=0.97, wintype='povey') - - fbank_feat = np.transpose(fbank_feat) if delta_delta: fbank_feat = self._concat_delta_delta(fbank_feat) return fbank_feat diff --git a/deepspeech/frontend/normalizer.py b/deepspeech/frontend/normalizer.py index 287b51e5..73b3a4ba 100644 --- 
a/deepspeech/frontend/normalizer.py +++ b/deepspeech/frontend/normalizer.py @@ -40,21 +40,21 @@ class CollateFunc(object): number = 0 for item in batch: audioseg = AudioSegment.from_file(item['feat']) - feat = self.feature_func(audioseg) #(D, T) + feat = self.feature_func(audioseg) #(T, D) - sums = np.sum(feat, axis=1) + sums = np.sum(feat, axis=0) if mean_stat is None: mean_stat = sums else: mean_stat += sums - square_sums = np.sum(np.square(feat), axis=1) + square_sums = np.sum(np.square(feat), axis=0) if var_stat is None: var_stat = square_sums else: var_stat += square_sums - number += feat.shape[1] + number += feat.shape[0] return number, mean_stat, var_stat @@ -120,7 +120,7 @@ class FeatureNormalizer(object): """Normalize features to be of zero mean and unit stddev. :param features: Input features to be normalized. - :type features: ndarray, shape (D, T) + :type features: ndarray, shape (T, D) :param eps: added to stddev to provide numerical stablibity. :type eps: float :return: Normalized features. @@ -131,8 +131,8 @@ class FeatureNormalizer(object): def _read_mean_std_from_file(self, filepath, eps=1e-20): """Load mean and std from file.""" mean, istd = load_cmvn(filepath, filetype='json') - self._mean = np.expand_dims(mean, axis=-1) - self._istd = np.expand_dims(istd, axis=-1) + self._mean = np.expand_dims(mean, axis=0) + self._istd = np.expand_dims(istd, axis=0) def write_to_file(self, filepath): """Write the mean and stddev to the file. 
diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 4900350e..df300479 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -242,7 +242,6 @@ class SpeechCollator(): # specgram augment specgram = self._augmentation_pipeline.transform_feature(specgram) - specgram = specgram.transpose([1, 0]) return specgram, transcript_part def __call__(self, batch): @@ -250,7 +249,7 @@ class SpeechCollator(): Args: batch ([List]): batch is (audio, text) - audio (np.ndarray) shape (D, T) + audio (np.ndarray) shape (T, D) text (List[int] or str): shape (U,) Returns: diff --git a/deepspeech/io/collator_st.py b/deepspeech/io/collator_st.py index 1ee36190..28573366 100644 --- a/deepspeech/io/collator_st.py +++ b/deepspeech/io/collator_st.py @@ -217,6 +217,34 @@ class SpeechCollator(): return self._local_data.tar2object[tarpath].extractfile( self._local_data.tar2info[tarpath][filename]) + @property + def manifest(self): + return self._manifest + + @property + def vocab_size(self): + return self._speech_featurizer.vocab_size + + @property + def vocab_list(self): + return self._speech_featurizer.vocab_list + + @property + def vocab_dict(self): + return self._speech_featurizer.vocab_dict + + @property + def text_feature(self): + return self._speech_featurizer.text_feature + + @property + def feature_size(self): + return self._speech_featurizer.feature_size + + @property + def stride_ms(self): + return self._speech_featurizer.stride_ms + def process_utterance(self, audio_file, translation): """Load, augment, featurize and normalize for speech data. 
@@ -244,7 +272,6 @@ class SpeechCollator(): # specgram augment specgram = self._augmentation_pipeline.transform_feature(specgram) - specgram = specgram.transpose([1, 0]) return specgram, translation_part def __call__(self, batch): @@ -252,7 +279,7 @@ class SpeechCollator(): Args: batch ([List]): batch is (audio, text) - audio (np.ndarray) shape (D, T) + audio (np.ndarray) shape (T, D) text (List[int] or str): shape (U,) Returns: @@ -296,34 +323,6 @@ class SpeechCollator(): text_lens = np.array(text_lens).astype(np.int64) return utts, padded_audios, audio_lens, padded_texts, text_lens - @property - def manifest(self): - return self._manifest - - @property - def vocab_size(self): - return self._speech_featurizer.vocab_size - - @property - def vocab_list(self): - return self._speech_featurizer.vocab_list - - @property - def vocab_dict(self): - return self._speech_featurizer.vocab_dict - - @property - def text_feature(self): - return self._speech_featurizer.text_feature - - @property - def feature_size(self): - return self._speech_featurizer.feature_size - - @property - def stride_ms(self): - return self._speech_featurizer.stride_ms - class TripletSpeechCollator(SpeechCollator): def process_utterance(self, audio_file, translation, transcript): @@ -355,7 +354,6 @@ class TripletSpeechCollator(SpeechCollator): # specgram augment specgram = self._augmentation_pipeline.transform_feature(specgram) - specgram = specgram.transpose([1, 0]) return specgram, translation_part, transcript_part def __call__(self, batch): @@ -363,7 +361,7 @@ class TripletSpeechCollator(SpeechCollator): Args: batch ([List]): batch is (audio, text) - audio (np.ndarray) shape (D, T) + audio (np.ndarray) shape (T, D) text (List[int] or str): shape (U,) Returns: @@ -524,49 +522,19 @@ class KaldiPrePorocessedCollator(SpeechCollator): :rtype: tuple of (2darray, list) """ specgram = kaldiio.load_mat(audio_file) - specgram = specgram.transpose([1, 0]) assert specgram.shape[ - 0] == self._feat_dim, 'expect 
feat dim {}, but got {}'.format( - self._feat_dim, specgram.shape[0]) + 1] == self._feat_dim, 'expect feat dim {}, but got {}'.format( + self._feat_dim, specgram.shape[1]) # specgram augment specgram = self._augmentation_pipeline.transform_feature(specgram) - specgram = specgram.transpose([1, 0]) if self._keep_transcription_text: return specgram, translation else: text_ids = self._text_featurizer.featurize(translation) return specgram, text_ids - @property - def manifest(self): - return self._manifest - - @property - def vocab_size(self): - return self._text_featurizer.vocab_size - - @property - def vocab_list(self): - return self._text_featurizer.vocab_list - - @property - def vocab_dict(self): - return self._text_featurizer.vocab_dict - - @property - def text_feature(self): - return self._text_featurizer - - @property - def feature_size(self): - return self._feat_dim - - @property - def stride_ms(self): - return self._stride_ms - class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator): def process_utterance(self, audio_file, translation, transcript): @@ -583,15 +551,13 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator): :rtype: tuple of (2darray, (list, list)) """ specgram = kaldiio.load_mat(audio_file) - specgram = specgram.transpose([1, 0]) assert specgram.shape[ - 0] == self._feat_dim, 'expect feat dim {}, but got {}'.format( - self._feat_dim, specgram.shape[0]) + 1] == self._feat_dim, 'expect feat dim {}, but got {}'.format( + self._feat_dim, specgram.shape[1]) # specgram augment specgram = self._augmentation_pipeline.transform_feature(specgram) - specgram = specgram.transpose([1, 0]) if self._keep_transcription_text: return specgram, translation, transcript else: @@ -604,7 +570,7 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator): Args: batch ([List]): batch is (audio, text) - audio (np.ndarray) shape (D, T) + audio (np.ndarray) shape (T, D) translation (List[int] or str): shape (U,) transcription 
(List[int] or str): shape (V,) diff --git a/requirements.txt b/requirements.txt index af2600e0..08f2f258 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ coverage gpustat kaldiio +Pillow pre-commit pybind11 resampy==0.2.2 From aab02997f920d543c0ecf36b81bfaa032f46186f Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 23 Aug 2021 06:47:36 +0000 Subject: [PATCH 05/17] fix specaug config --- deepspeech/frontend/augmentor/spec_augment.py | 2 ++ examples/aishell/s0/conf/augmentation.json | 11 ++++++----- examples/aishell/s1/conf/augmentation.json | 3 ++- examples/augmentation/augmentation.json | 9 +++++---- examples/librispeech/s0/conf/augmentation.json | 3 ++- examples/librispeech/s1/conf/augmentation.json | 3 ++- examples/librispeech/s2/conf/augmentation.json | 3 ++- examples/timit/s1/conf/augmentation.json | 3 ++- examples/tiny/s0/conf/augmentation.json | 12 +++++++----- examples/tiny/s1/conf/augmentation.json | 3 ++- 10 files changed, 32 insertions(+), 20 deletions(-) diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py index a3f4e268..7c23b628 100644 --- a/deepspeech/frontend/augmentor/spec_augment.py +++ b/deepspeech/frontend/augmentor/spec_augment.py @@ -245,6 +245,8 @@ class SpecAugmentor(AugmentorBase): Returns: x (np.ndarray): `[T, F]` """ + assert isinstance(x, np.ndarray) + assert x.ndim == 2 x = self.time_warp(x, self.mode) x = self.mask_freq(x, self.replace_with_zero) x = self.mask_time(x, self.replace_with_zero) diff --git a/examples/aishell/s0/conf/augmentation.json b/examples/aishell/s0/conf/augmentation.json index 81d110b0..39afe4e6 100644 --- a/examples/aishell/s0/conf/augmentation.json +++ b/examples/aishell/s0/conf/augmentation.json @@ -19,17 +19,18 @@ { "type": "specaug", "params": { - "F": 10, - "T": 50, + "W": 5, + "warp_mode": "PIL", + "F": 30, "n_freq_masks": 2, + "T": 40, "n_time_masks": 2, "p": 1.0, - "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, 
"max_n_time_masks": 20, - "replace_with_zero": true + "replace_with_zero": false }, "prob": 1.0 } -] +] \ No newline at end of file diff --git a/examples/aishell/s1/conf/augmentation.json b/examples/aishell/s1/conf/augmentation.json index 81d110b0..d0409b14 100644 --- a/examples/aishell/s1/conf/augmentation.json +++ b/examples/aishell/s1/conf/augmentation.json @@ -28,7 +28,8 @@ "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true + "replace_with_zero": true, + "warp_mode": "PIL" }, "prob": 1.0 } diff --git a/examples/augmentation/augmentation.json b/examples/augmentation/augmentation.json index baf2cac3..c99299d6 100644 --- a/examples/augmentation/augmentation.json +++ b/examples/augmentation/augmentation.json @@ -52,17 +52,18 @@ { "type": "specaug", "params": { + "W": 80, + "warp_mode": "PIL", "F": 10, - "T": 50, "n_freq_masks": 2, + "T": 50, "n_time_masks": 2, "p": 1.0, - "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true + "replace_with_zero": false }, - "prob": 0.0 + "prob": 1.0 } ] diff --git a/examples/librispeech/s0/conf/augmentation.json b/examples/librispeech/s0/conf/augmentation.json index 81d110b0..d0409b14 100644 --- a/examples/librispeech/s0/conf/augmentation.json +++ b/examples/librispeech/s0/conf/augmentation.json @@ -28,7 +28,8 @@ "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true + "replace_with_zero": true, + "warp_mode": "PIL" }, "prob": 1.0 } diff --git a/examples/librispeech/s1/conf/augmentation.json b/examples/librispeech/s1/conf/augmentation.json index 7dd158eb..8e6e9704 100644 --- a/examples/librispeech/s1/conf/augmentation.json +++ b/examples/librispeech/s1/conf/augmentation.json @@ -28,7 +28,8 @@ "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true + "replace_with_zero": true, + "warp_mode": "PIL" }, "prob": 1.0 } diff 
--git a/examples/librispeech/s2/conf/augmentation.json b/examples/librispeech/s2/conf/augmentation.json index cc8c7e00..e20fc199 100644 --- a/examples/librispeech/s2/conf/augmentation.json +++ b/examples/librispeech/s2/conf/augmentation.json @@ -11,7 +11,8 @@ "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true + "replace_with_zero": true, + "warp_mode": "PIL" }, "prob": 1.0 } diff --git a/examples/timit/s1/conf/augmentation.json b/examples/timit/s1/conf/augmentation.json index 7dd158eb..8e6e9704 100644 --- a/examples/timit/s1/conf/augmentation.json +++ b/examples/timit/s1/conf/augmentation.json @@ -28,7 +28,8 @@ "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true + "replace_with_zero": true, + "warp_mode": "PIL" }, "prob": 1.0 } diff --git a/examples/tiny/s0/conf/augmentation.json b/examples/tiny/s0/conf/augmentation.json index 8f9ff7fd..83705516 100644 --- a/examples/tiny/s0/conf/augmentation.json +++ b/examples/tiny/s0/conf/augmentation.json @@ -6,7 +6,7 @@ "max_speed_rate": 1.1, "num_rates": 3 }, - "prob": 1.0 + "prob": 0.0 }, { "type": "shift", @@ -19,16 +19,18 @@ { "type": "specaug", "params": { - "F": 10, - "T": 50, + "W": 5, + "warp_mode": "PIL", + "F": 30, "n_freq_masks": 2, + "T": 40, "n_time_masks": 2, "p": 1.0, - "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true + "replace_with_zero": true, + "warp_mode": "PIL" }, "prob": 1.0 } diff --git a/examples/tiny/s1/conf/augmentation.json b/examples/tiny/s1/conf/augmentation.json index 8f9ff7fd..6010c2e4 100644 --- a/examples/tiny/s1/conf/augmentation.json +++ b/examples/tiny/s1/conf/augmentation.json @@ -28,7 +28,8 @@ "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true + "replace_with_zero": true, + "warp_mode": "PIL" }, "prob": 1.0 } From 27daa92a81132d26bbf78ae79469e1a8cc27ef16 Mon Sep 
17 00:00:00 2001 From: Hui Zhang Date: Mon, 23 Aug 2021 07:00:26 +0000 Subject: [PATCH 06/17] using to_static --- deepspeech/exps/u2_kaldi/bin/test.py | 1 + deepspeech/models/u2_st.py | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/deepspeech/exps/u2_kaldi/bin/test.py b/deepspeech/exps/u2_kaldi/bin/test.py index 457672c0..c5064ec5 100644 --- a/deepspeech/exps/u2_kaldi/bin/test.py +++ b/deepspeech/exps/u2_kaldi/bin/test.py @@ -13,6 +13,7 @@ # limitations under the License. """Evaluation for U2 model.""" import cProfile +from yacs.config import CfgNode from deepspeech.training.cli import default_argument_parser from deepspeech.utils.dynamic_import import dynamic_import diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py index 99420a89..b725cc35 100644 --- a/deepspeech/models/u2_st.py +++ b/deepspeech/models/u2_st.py @@ -417,32 +417,32 @@ class U2STBaseModel(nn.Layer): best_hyps = best_hyps[:, 1:] return best_hyps - @jit.export + @jit.to_static def subsampling_rate(self) -> int: """ Export interface for c++ call, return subsampling_rate of the model """ return self.encoder.embed.subsampling_rate - @jit.export + @jit.to_static def right_context(self) -> int: """ Export interface for c++ call, return right_context of the model """ return self.encoder.embed.right_context - @jit.export + @jit.to_static def sos_symbol(self) -> int: """ Export interface for c++ call, return sos symbol id of the model """ return self.sos - @jit.export + @jit.to_static def eos_symbol(self) -> int: """ Export interface for c++ call, return eos symbol id of the model """ return self.eos - @jit.export + @jit.to_static def forward_encoder_chunk( self, xs: paddle.Tensor, @@ -472,7 +472,7 @@ class U2STBaseModel(nn.Layer): xs, offset, required_cache_size, subsampling_cache, elayers_output_cache, conformer_cnn_cache) - @jit.export + @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: """ Export interface for c++ call, apply 
linear transform and log softmax before ctc @@ -483,7 +483,7 @@ class U2STBaseModel(nn.Layer): """ return self.ctc.log_softmax(xs) - @jit.export + @jit.to_static def forward_attention_decoder( self, hyps: paddle.Tensor, From 2c75c923b9a64f1a6e5c92babbfc71f693abf1af Mon Sep 17 00:00:00 2001 From: TianYuan Date: Mon, 23 Aug 2021 08:07:35 +0000 Subject: [PATCH 07/17] fix_mfa --- deepspeech/decoders/swig/setup.py | 5 +- examples/aishell/s0/conf/augmentation.json | 2 +- examples/thchs30/a0/local/data.sh | 38 +++++++------ examples/thchs30/a0/local/gen_word2phone.py | 56 +++++++++++++------ .../thchs30/a0/local/reorganize_thchs30.py | 9 +-- examples/thchs30/a0/run.sh | 15 +++-- tools/extras/install_mfa.sh | 2 +- 7 files changed, 79 insertions(+), 48 deletions(-) diff --git a/deepspeech/decoders/swig/setup.py b/deepspeech/decoders/swig/setup.py index 86af475a..3da5ce8b 100644 --- a/deepspeech/decoders/swig/setup.py +++ b/deepspeech/decoders/swig/setup.py @@ -84,8 +84,9 @@ FILES = glob.glob('kenlm/util/*.cc') \ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc') FILES = [ - fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc') - or fn.endswith('unittest.cc')) + fn for fn in FILES + if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith( + 'unittest.cc')) ] LIBS = ['stdc++'] diff --git a/examples/aishell/s0/conf/augmentation.json b/examples/aishell/s0/conf/augmentation.json index 39afe4e6..ac8a1c53 100644 --- a/examples/aishell/s0/conf/augmentation.json +++ b/examples/aishell/s0/conf/augmentation.json @@ -33,4 +33,4 @@ }, "prob": 1.0 } -] \ No newline at end of file +] diff --git a/examples/thchs30/a0/local/data.sh b/examples/thchs30/a0/local/data.sh index 169367ac..8614a041 100644 --- a/examples/thchs30/a0/local/data.sh +++ b/examples/thchs30/a0/local/data.sh @@ -20,27 +20,33 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then echo "Prepare THCHS-30 failed. Terminated." 
exit 1 fi - fi -# dump manifest to data/ -python3 ${MAIN_ROOT}/utils/dump_manifest.py --manifest-path=data/manifest.train --output-dir=data - -# copy files to data/dict to gen word.lexicon -cp ${TARGET_DIR}/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1 -cp ${TARGET_DIR}/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2 +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # dump manifest to data/ + python3 ${MAIN_ROOT}/utils/dump_manifest.py --manifest-path=data/manifest.train --output-dir=data +fi -# copy phone.lexicon to data/dict -cp ${TARGET_DIR}/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # copy files to data/dict to gen word.lexicon + cp ${TARGET_DIR}/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1 + cp ${TARGET_DIR}/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2 + # copy phone.lexicon to data/dict + cp ${TARGET_DIR}/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon +fi -# gen word.lexicon -python local/gen_word2phone.py --root-dir=data/dict --output-dir=data/dict +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # gen word.lexicon + python local/gen_word2phone.py --lexicon-files="data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2" --output-path=data/dict/word.lexicon +fi -# reorganize dataset for MFA -if [ ! -d $EXP_DIR/thchs30_corpus ]; then - echo "reorganizing thchs30 corpus..." - python local/reorganize_thchs30.py --root-dir=data --output-dir=data/thchs30_corpus --script-type=$LEXICON_NAME - echo "reorganization done." +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # reorganize dataset for MFA + if [ ! -d $EXP_DIR/thchs30_corpus ]; then + echo "reorganizing thchs30 corpus..." + python local/reorganize_thchs30.py --root-dir=data --output-dir=data/thchs30_corpus --script-type=$LEXICON_NAME + echo "reorganization done." + fi fi echo "THCHS-30 data preparation done." 
diff --git a/examples/thchs30/a0/local/gen_word2phone.py b/examples/thchs30/a0/local/gen_word2phone.py index cd584fcd..9bc0249b 100644 --- a/examples/thchs30/a0/local/gen_word2phone.py +++ b/examples/thchs30/a0/local/gen_word2phone.py @@ -18,6 +18,7 @@ file2: THCHS-30/resource/dict/lexicon.txt import argparse from collections import defaultdict from pathlib import Path +from typing import List from typing import Union # key: (cn, ('ee', 'er4')),value: count @@ -34,7 +35,7 @@ def is_Chinese(ch): return False -def proc_line(line): +def proc_line(line: str): line = line.strip() if is_Chinese(line[0]): line_list = line.split() @@ -49,20 +50,25 @@ def proc_line(line): cn_phones_counter[(cn, phones)] += 1 -def gen_lexicon(root_dir: Union[str, Path], output_dir: Union[str, Path]): - root_dir = Path(root_dir).expanduser() - output_dir = Path(output_dir).expanduser() - output_dir.mkdir(parents=True, exist_ok=True) - file1 = root_dir / "lm_word_lexicon_1" - file2 = root_dir / "lm_word_lexicon_2" - write_file = output_dir / "word.lexicon" +""" +example lines of output +the first column is a Chinese character +the second is the probability of this pronunciation +and the rest are the phones of this pronunciation +一 0.22 ii i1↩ +一 0.45 ii i4↩ +一 0.32 ii i2↩ +一 0.01 ii i5 +""" + + +def gen_lexicon(lexicon_files: List[Union[str, Path]], + output_path: Union[str, Path]): + for file_path in lexicon_files: + with open(file_path, "r") as f1: + for line in f1: + proc_line(line) - with open(file1, "r") as f1: - for line in f1: - proc_line(line) - with open(file2, "r") as f2: - for line in f2: - proc_line(line) for key in cn_phones_counter: cn = key[0] cn_counter[cn].append((key[1], cn_phones_counter[key])) @@ -75,7 +81,8 @@ def gen_lexicon(root_dir: Union[str, Path], output_dir: Union[str, Path]): p = round(p, 2) if p > 0: cn_counter_p[key].append((item[0], p)) - with open(write_file, "w") as wf: + + with open(output_path, "w") as wf: for key in cn_counter_p: phone_p_list = 
cn_counter_p[key] for item in phone_p_list: @@ -87,8 +94,21 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( description="Gen Chinese characters to phone lexicon for THCHS-30 dataset" ) + # A line of word_lexicon: + # 一丁点 ii i4 d ing1 d ian3 + # the first is word, and the rest are the phones of the word, and the len of phones is twice of the word's len + parser.add_argument( + "--lexicon-files", + type=str, + default="data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2", + help="lm_word_lexicon files") parser.add_argument( - "--root-dir", type=str, help="dir to thchs30 lm_word_lexicons") - parser.add_argument("--output-dir", type=str, help="path to save outputs") + "--output-path", + type=str, + default="data/dict/word.lexicon", + help="path to save output word2phone lexicon") args = parser.parse_args() - gen_lexicon(args.root_dir, args.output_dir) + lexicon_files = args.lexicon_files.split(" ") + output_path = Path(args.output_path).expanduser() + + gen_lexicon(lexicon_files, output_path) diff --git a/examples/thchs30/a0/local/reorganize_thchs30.py b/examples/thchs30/a0/local/reorganize_thchs30.py index 9df6bc6a..c7c6248b 100644 --- a/examples/thchs30/a0/local/reorganize_thchs30.py +++ b/examples/thchs30/a0/local/reorganize_thchs30.py @@ -58,8 +58,6 @@ def write_lab(root_dir: Union[str, Path], def reorganize_thchs30(root_dir: Union[str, Path], output_dir: Union[str, Path]=None, script_type='phone'): - root_dir = Path(root_dir).expanduser() - output_dir = Path(output_dir).expanduser() output_dir.mkdir(parents=True, exist_ok=True) link_wav(root_dir, output_dir) write_lab(root_dir, output_dir, script_type) @@ -72,12 +70,15 @@ if __name__ == "__main__": parser.add_argument( "--output-dir", type=str, - help="path to save outputs(audio and transcriptions)") + help="path to save outputs (audio and transcriptions)") parser.add_argument( "--script-type", type=str, default="phone", help="type of lab ('word'/'syllable'/'phone')") + args = 
parser.parse_args() - reorganize_thchs30(args.root_dir, args.output_dir, args.script_type) + root_dir = Path(args.root_dir).expanduser() + output_dir = Path(args.output_dir).expanduser() + reorganize_thchs30(root_dir, output_dir, args.script_type) diff --git a/examples/thchs30/a0/run.sh b/examples/thchs30/a0/run.sh index 53f96b37..5081b612 100755 --- a/examples/thchs30/a0/run.sh +++ b/examples/thchs30/a0/run.sh @@ -14,14 +14,17 @@ source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; # gen lexicon relink gen dump if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then # prepare data - bash ./local/data.sh $LEXICON_NAME|| exit -1 + echo "Start prepare thchs30 data for MFA ..." + bash ./local/data.sh $LEXICON_NAME || exit -1 fi -# run MFA -if [ ! -d "$EXP_DIR/thchs30_alignment" ]; then - echo "Start MFA training..." - mfa_train_and_align data/thchs30_corpus data/dict/$LEXICON_NAME.lexicon $EXP_DIR/thchs30_alignment -o $EXP_DIR/thchs30_model --clean --verbose --temp_directory exp/.mfa_train_and_align --num_jobs $NUM_JOBS - echo "training done! \nresults: $EXP_DIR/thchs30_alignment \nmodel: $EXP_DIR/thchs30_model\n" +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # run MFA + if [ ! -d "$EXP_DIR/thchs30_alignment" ]; then + echo "Start MFA training ..." + mfa_train_and_align data/thchs30_corpus data/dict/$LEXICON_NAME.lexicon $EXP_DIR/thchs30_alignment -o $EXP_DIR/thchs30_model --clean --verbose --temp_directory exp/.mfa_train_and_align --num_jobs $NUM_JOBS + echo "MFA training done! 
\nresults: $EXP_DIR/thchs30_alignment \nmodel: $EXP_DIR/thchs30_model\n" + fi fi diff --git a/tools/extras/install_mfa.sh b/tools/extras/install_mfa.sh index b0a4cf99..ae126fa6 100755 --- a/tools/extras/install_mfa.sh +++ b/tools/extras/install_mfa.sh @@ -4,7 +4,7 @@ test -d Montreal-Forced-Aligner || git clone https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner.git -pushd Montreal-Forced-Aligner && git checkout v2.0.0a7 && python setup.py install +pushd Montreal-Forced-Aligner && python setup.py install && popd test -d kaldi || { echo "need install kaldi first"; exit 1;} From 561d5cf085b49baf27b47f13b15074d654acbce2 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 23 Aug 2021 12:11:43 +0000 Subject: [PATCH 08/17] refactor feature, dict and argument for new config format --- .flake8 | 4 + deepspeech/exps/deepspeech2/bin/export.py | 3 + deepspeech/exps/deepspeech2/bin/test.py | 3 + deepspeech/exps/u2/bin/alignment.py | 3 + deepspeech/exps/u2/bin/export.py | 3 + deepspeech/exps/u2/bin/test.py | 3 + deepspeech/exps/u2_kaldi/bin/test.py | 9 +++ deepspeech/exps/u2_kaldi/model.py | 32 +++++--- deepspeech/exps/u2_st/bin/export.py | 3 + deepspeech/exps/u2_st/bin/test.py | 3 + deepspeech/frontend/featurizer/__init__.py | 3 + .../frontend/featurizer/audio_featurizer.py | 2 +- .../frontend/featurizer/speech_featurizer.py | 2 +- .../frontend/featurizer/text_featurizer.py | 73 +++++++------------ deepspeech/frontend/utility.py | 50 ++++++++++--- deepspeech/training/cli.py | 7 -- examples/aishell/s0/conf/augmentation.json | 2 +- examples/librispeech/s2/conf/transformer.yaml | 10 +-- examples/librispeech/s2/local/align.sh | 13 ++-- examples/librispeech/s2/local/export.sh | 3 +- examples/librispeech/s2/local/test.sh | 19 +++-- examples/librispeech/s2/run.sh | 5 +- examples/tiny/s0/conf/augmentation.json | 3 +- 23 files changed, 158 insertions(+), 100 deletions(-) diff --git a/.flake8 b/.flake8 index 72289943..44685f23 100644 --- a/.flake8 +++ b/.flake8 @@ -42,6 
+42,10 @@ ignore = # these ignores are from flake8-comprehensions; please fix! C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 + +per-file-ignores = + */__init__.py: F401 + # Specify the list of error codes you wish Flake8 to report. select = E, diff --git a/deepspeech/exps/deepspeech2/bin/export.py b/deepspeech/exps/deepspeech2/bin/export.py index f8764fde..7962d4fc 100644 --- a/deepspeech/exps/deepspeech2/bin/export.py +++ b/deepspeech/exps/deepspeech2/bin/export.py @@ -30,6 +30,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save jit model to + parser.add_argument( + "--export_path", type=str, help="path of the jit model to save") parser.add_argument("--model_type") args = parser.parse_args() if args.model_type is None: diff --git a/deepspeech/exps/deepspeech2/bin/test.py b/deepspeech/exps/deepspeech2/bin/test.py index 376e18e3..f2fd3a39 100644 --- a/deepspeech/exps/deepspeech2/bin/test.py +++ b/deepspeech/exps/deepspeech2/bin/test.py @@ -31,6 +31,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() parser.add_argument("--model_type") + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) if args.model_type is None: diff --git a/deepspeech/exps/u2/bin/alignment.py b/deepspeech/exps/u2/bin/alignment.py index c1c9582f..cef9d1ab 100644 --- a/deepspeech/exps/u2/bin/alignment.py +++ b/deepspeech/exps/u2/bin/alignment.py @@ -30,6 +30,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2/bin/export.py b/deepspeech/exps/u2/bin/export.py index 292c7838..3dc41b70 100644 --- a/deepspeech/exps/u2/bin/export.py +++ 
b/deepspeech/exps/u2/bin/export.py @@ -30,6 +30,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save jit model to + parser.add_argument( + "--export_path", type=str, help="path of the jit model to save") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2/bin/test.py b/deepspeech/exps/u2/bin/test.py index c47f932c..f6127675 100644 --- a/deepspeech/exps/u2/bin/test.py +++ b/deepspeech/exps/u2/bin/test.py @@ -34,6 +34,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2_kaldi/bin/test.py b/deepspeech/exps/u2_kaldi/bin/test.py index c5064ec5..93a29ab1 100644 --- a/deepspeech/exps/u2_kaldi/bin/test.py +++ b/deepspeech/exps/u2_kaldi/bin/test.py @@ -13,6 +13,7 @@ # limitations under the License. """Evaluation for U2 model.""" import cProfile + from yacs.config import CfgNode from deepspeech.training.cli import default_argument_parser @@ -54,6 +55,14 @@ if __name__ == "__main__": type=str, default='test', help='run mode, e.g. 
test, align, export') + parser.add_argument( + '--dict-path', type=str, default=None, help='dict path.') + # save asr result to + parser.add_argument( + "--result-file", type=str, help="path of save the asr result") + # save jit model to + parser.add_argument( + "--export-path", type=str, help="path of the jit model to save") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 60f070a3..4f6ff4cb 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -25,6 +25,8 @@ import paddle from paddle import distributed as dist from yacs.config import CfgNode +from deepspeech.frontend.featurizer import TextFeaturizer +from deepspeech.frontend.utility import load_dict from deepspeech.io.dataloader import BatchDataLoader from deepspeech.models.u2 import U2Model from deepspeech.training.optimizer import OptimizerFactory @@ -80,8 +82,8 @@ class U2Trainer(Trainer): def train_batch(self, batch_index, batch_data, msg): train_conf = self.config.training start = time.time() - utt, audio, audio_len, text, text_len = batch_data + utt, audio, audio_len, text, text_len = batch_data loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) # loss div by `batch_size * accum_grad` @@ -124,6 +126,7 @@ class U2Trainer(Trainer): valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 + for i, batch in enumerate(self.valid_loader): utt, audio, audio_len, text, text_len = batch loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, @@ -305,10 +308,8 @@ class U2Trainer(Trainer): model_conf.output_dim = self.train_loader.vocab_size model_conf.freeze() model = U2Model.from_config(model_conf) - if self.parallel: model = paddle.DataParallel(model) - logger.info(f"{model}") layer_tools.print_params(model, logger.info) @@ -379,13 +380,13 @@ class U2Tester(U2Trainer): def __init__(self, config, args): super().__init__(config, args) 
- def ordid2token(self, texts, texts_len): + def id2token(self, texts, texts_len, text_feature): """ ord() id to chr() chr """ trans = [] for text, n in zip(texts, texts_len): n = n.numpy().item() ids = text[:n] - trans.append(''.join([chr(i) for i in ids])) + trans.append(text_feature.defeaturize(ids.numpy().tolist())) return trans def compute_metrics(self, @@ -401,8 +402,11 @@ class U2Tester(U2Trainer): error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer start_time = time.time() - text_feature = self.test_loader.collate_fn.text_feature - target_transcripts = self.ordid2token(texts, texts_len) + text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) + target_transcripts = self.id2token(texts, texts_len, text_feature) result_transcripts = self.model.decode( audio, audio_len, @@ -450,7 +454,7 @@ class U2Tester(U2Trainer): self.model.eval() logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") - stride_ms = self.test_loader.collate_fn.stride_ms + stride_ms = self.config.collator.stride_ms error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 @@ -525,8 +529,9 @@ class U2Tester(U2Trainer): self.model.eval() logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}") - stride_ms = self.config.collate.stride_ms - token_dict = self.align_loader.collate_fn.vocab_list + stride_ms = self.config.collator.stride_ms + token_dict = self.args.char_list + with open(self.args.result_file, 'w') as fout: # one example in batch for i, batch in enumerate(self.align_loader): @@ -613,6 +618,11 @@ class U2Tester(U2Trainer): except KeyboardInterrupt: sys.exit(-1) + def setup_dict(self): + # load dictionary for debug log + self.args.char_list = load_dict(self.args.dict_path, + "maskctc" in self.args.model_name) + def setup(self): """Setup the experiment. 
""" @@ -624,6 +634,8 @@ class U2Tester(U2Trainer): self.setup_dataloader() self.setup_model() + self.setup_dict() + self.iteration = 0 self.epoch = 0 diff --git a/deepspeech/exps/u2_st/bin/export.py b/deepspeech/exps/u2_st/bin/export.py index f566ba5b..c7eb5d03 100644 --- a/deepspeech/exps/u2_st/bin/export.py +++ b/deepspeech/exps/u2_st/bin/export.py @@ -30,6 +30,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save jit model to + parser.add_argument( + "--export_path", type=str, help="path of the jit model to save") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2_st/bin/test.py b/deepspeech/exps/u2_st/bin/test.py index d66c7a26..81197dec 100644 --- a/deepspeech/exps/u2_st/bin/test.py +++ b/deepspeech/exps/u2_st/bin/test.py @@ -34,6 +34,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/frontend/featurizer/__init__.py b/deepspeech/frontend/featurizer/__init__.py index 185a92b8..6992700d 100644 --- a/deepspeech/frontend/featurizer/__init__.py +++ b/deepspeech/frontend/featurizer/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from .audio_featurizer import AudioFeaturizer #noqa: F401 +from .speech_featurizer import SpeechFeaturizer +from .text_featurizer import TextFeaturizer diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py b/deepspeech/frontend/featurizer/audio_featurizer.py index 7e9efa36..4c40c847 100644 --- a/deepspeech/frontend/featurizer/audio_featurizer.py +++ b/deepspeech/frontend/featurizer/audio_featurizer.py @@ -18,7 +18,7 @@ from python_speech_features import logfbank from python_speech_features import mfcc -class AudioFeaturizer(object): +class AudioFeaturizer(): """Audio featurizer, for extracting features from audio contents of AudioSegment or SpeechSegment. diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index 0fbbc564..5082850d 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -16,7 +16,7 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer -class SpeechFeaturizer(object): +class SpeechFeaturizer(): """Speech featurizer, for extracting features from both audio and transcript contents of SpeechSegment. 
diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index 1ba6ac7f..e4364f70 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -14,12 +14,19 @@ """Contains the text featurizer class.""" import sentencepiece as spm -from deepspeech.frontend.utility import EOS -from deepspeech.frontend.utility import UNK +from ..utility import EOS +from ..utility import load_dict +from ..utility import UNK +__all__ = ["TextFeaturizer"] -class TextFeaturizer(object): - def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None): + +class TextFeaturizer(): + def __init__(self, + unit_type, + vocab_filepath, + spm_model_prefix=None, + maskctc=False): """Text featurizer, for processing or extracting features from text. Currently, it supports char/word/sentence-piece level tokenizing and conversion into @@ -34,11 +41,12 @@ class TextFeaturizer(object): assert unit_type in ('char', 'spm', 'word') self.unit_type = unit_type self.unk = UNK + self.maskctc = maskctc + if vocab_filepath: - self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file( - vocab_filepath) - self.unk_id = self._vocab_list.index(self.unk) - self.eos_id = self._vocab_list.index(EOS) + self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file( + vocab_filepath, maskctc) + self.vocab_size = len(self.vocab_list) if unit_type == 'spm': spm_model = spm_model_prefix + '.model' @@ -67,7 +75,7 @@ class TextFeaturizer(object): """Convert text string to a list of token indices. Args: - text (str): Text to process. + text (str): Text. Returns: List[int]: List of token indices. 
@@ -75,8 +83,8 @@ class TextFeaturizer(object): tokens = self.tokenize(text) ids = [] for token in tokens: - token = token if token in self._vocab_dict else self.unk - ids.append(self._vocab_dict[token]) + token = token if token in self.vocab_dict else self.unk + ids.append(self.vocab_dict[token]) return ids def defeaturize(self, idxs): @@ -87,7 +95,7 @@ class TextFeaturizer(object): idxs (List[int]): List of token indices. Returns: - str: Text to process. + str: Text. """ tokens = [] for idx in idxs: @@ -97,33 +105,6 @@ class TextFeaturizer(object): text = self.detokenize(tokens) return text - @property - def vocab_size(self): - """Return the vocabulary size. - - :return: Vocabulary size. - :rtype: int - """ - return len(self._vocab_list) - - @property - def vocab_list(self): - """Return the vocabulary in list. - - Returns: - List[str]: tokens. - """ - return self._vocab_list - - @property - def vocab_dict(self): - """Return the vocabulary in dict. - - Returns: - Dict[str, int]: token str -> int - """ - return self._vocab_dict - def char_tokenize(self, text): """Character tokenizer. 
@@ -206,14 +187,16 @@ class TextFeaturizer(object): return decode(tokens) - def _load_vocabulary_from_file(self, vocab_filepath): + def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool): """Load vocabulary from file.""" - vocab_lines = [] - with open(vocab_filepath, 'r', encoding='utf-8') as file: - vocab_lines.extend(file.readlines()) - vocab_list = [line[:-1] for line in vocab_lines] + vocab_list = load_dict(vocab_filepath, maskctc) + assert vocab_list is not None + id2token = dict( [(idx, token) for (idx, token) in enumerate(vocab_list)]) token2id = dict( [(token, idx) for (idx, token) in enumerate(vocab_list)]) - return token2id, id2token, vocab_list + + unk_id = vocab_list.index(UNK) + eos_id = vocab_list.index(EOS) + return token2id, id2token, vocab_list, unk_id, eos_id diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index b2dd9601..3d0683b0 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -15,6 +15,9 @@ import codecs import json import math +from typing import List +from typing import Optional +from typing import Text import numpy as np @@ -23,16 +26,35 @@ from deepspeech.utils.log import Log logger = Log(__name__).getlog() __all__ = [ - "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs", - "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", "EOS", "UNK", - "BLANK" + "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", + "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", + "EOS", "UNK", "BLANK", "MASKCTC" ] IGNORE_ID = -1 -SOS = "<sos/eos>" +# `sos` and `eos` using same token +SOS = "<eos>" EOS = SOS UNK = "<unk>" BLANK = "<blank>" +MASKCTC = "<mask>" + + +def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]: + if dict_path is None: + return None + + with open(dict_path, "r") as f: + dictionary = f.readlines() + char_list = [entry.split(" ")[0] for entry in dictionary] + if BLANK not in char_list: + char_list.insert(0, 
BLANK) + if EOS not in char_list: + char_list.append(EOS) + # for non-autoregressive maskctc model + if maskctc and MASKCTC not in char_list: + char_list.append(MASKCTC) + return char_list def read_manifest( @@ -47,12 +69,20 @@ def read_manifest( Args: manifest_path ([type]): Manifest file to load and parse. - max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf'). - min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0. - max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0. - min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0. - max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0. - min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05. + max_input_len ([type], optional): maximum output seq length, + in seconds for raw wav, in frame numbers for feature data. + Defaults to float('inf'). + min_input_len (float, optional): minimum input seq length, + in seconds for raw wav, in frame numbers for feature data. + Defaults to 0.0. + max_output_len (float, optional): maximum input seq length, + in modeling units. Defaults to 500.0. + min_output_len (float, optional): minimum input seq length, + in modeling units. Defaults to 0.0. + max_output_input_ratio (float, optional): + maximum output seq length/output seq length ratio. Defaults to 10.0. + min_output_input_ratio (float, optional): + minimum output seq length/output seq length ratio. Defaults to 0.05. Raises: IOError: If failed to parse the manifest. 
diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index b83d989d..9d145645 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -47,18 +47,11 @@ def default_argument_parser(): # data and output parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.") parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.") - # parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.") parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.") # load from saved checkpoint parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load") - # save jit model to - parser.add_argument("--export_path", type=str, help="path of the jit model to save") - - # save asr result to - parser.add_argument("--result_file", type=str, help="path of save the asr result") - # running parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"], help="device type to use, cpu and gpu are supported.") diff --git a/examples/aishell/s0/conf/augmentation.json b/examples/aishell/s0/conf/augmentation.json index 39afe4e6..ac8a1c53 100644 --- a/examples/aishell/s0/conf/augmentation.json +++ b/examples/aishell/s0/conf/augmentation.json @@ -33,4 +33,4 @@ }, "prob": 1.0 } -] \ No newline at end of file +] diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml index 7710d706..ded4f240 100644 --- a/examples/librispeech/s2/conf/transformer.yaml +++ b/examples/librispeech/s2/conf/transformer.yaml @@ -3,17 +3,11 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test-clean - min_input_len: 0.5 # second - max_input_len: 20.0 # second - min_output_len: 0.0 # tokens - max_output_len: 400.0 # tokens - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 collator: - vocab_filepath: 
data/vocab.txt + vocab_filepath: data/train_960_unigram5000_units.txt unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' + spm_model_prefix: 'data/train_960_unigram5000' mean_std_filepath: "" augmentation_config: conf/augmentation.json batch_size: 64 diff --git a/examples/librispeech/s2/local/align.sh b/examples/librispeech/s2/local/align.sh index 94146ccf..b3d8fa5f 100755 --- a/examples/librispeech/s2/local/align.sh +++ b/examples/librispeech/s2/local/align.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path dict_path ckpt_path_prefix" exit -1 fi @@ -13,7 +13,8 @@ if [ ${ngpu} == 0 ];then device=cpu fi config_path=$1 -ckpt_prefix=$2 +dict_path=$2 +ckpt_prefix=$3 batch_size=1 output_dir=${ckpt_prefix} @@ -22,11 +23,13 @@ mkdir -p ${output_dir} # align dump in `result_file` # .tier, .TextGrid dump in `dir of result_file` python3 -u ${BIN_DIR}/test.py \ ---run_mode 'align' \ +--model-name 'u2_kaldi' \ +--run-mode 'align' \ +--dict-path ${dict_path} \ --device ${device} \ --nproc 1 \ --config ${config_path} \ ---result_file ${output_dir}/${type}.align \ +--result-file ${output_dir}/${type}.align \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.batch_size ${batch_size} diff --git a/examples/librispeech/s2/local/export.sh b/examples/librispeech/s2/local/export.sh index 7e42e011..efa70a2b 100755 --- a/examples/librispeech/s2/local/export.sh +++ b/examples/librispeech/s2/local/export.sh @@ -18,7 +18,8 @@ if [ ${ngpu} == 0 ];then fi python3 -u ${BIN_DIR}/test.py \ ---run_mode 'export' \ +--model-name 'u2_kaldi' \ +--run-mode 'export' \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh index 762211c2..efd06f35 100755 --- a/examples/librispeech/s2/local/test.sh +++ b/examples/librispeech/s2/local/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ $# != 2 ];then - 
echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path dict_path ckpt_path_prefix" exit -1 fi @@ -14,7 +14,8 @@ if [ ${ngpu} == 0 ];then fi config_path=$1 -ckpt_prefix=$2 +dict_path=$2 +ckpt_prefix=$3 chunk_mode=false if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then @@ -38,11 +39,13 @@ for type in attention ctc_greedy_search; do batch_size=64 fi python3 -u ${BIN_DIR}/test.py \ - --run_mode test \ + --model-name u2_kaldi \ + --run-mode test \ + --dict-path ${dict_path} \ --device ${device} \ --nproc 1 \ --config ${config_path} \ - --result_file ${ckpt_prefix}.${type}.rsl \ + --result-file ${ckpt_prefix}.${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} @@ -56,11 +59,13 @@ for type in ctc_prefix_beam_search attention_rescoring; do echo "decoding ${type}" batch_size=1 python3 -u ${BIN_DIR}/test.py \ - --run_mode test \ + --model-name u2_kaldi \ + --run-mode test \ + --dict-path ${dict_path} \ --device ${device} \ --nproc 1 \ --config ${config_path} \ - --result_file ${ckpt_prefix}.${type}.rsl \ + --result-file ${ckpt_prefix}.${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh index def10ab0..26398dd1 100755 --- a/examples/librispeech/s2/run.sh +++ b/examples/librispeech/s2/run.sh @@ -5,6 +5,7 @@ source path.sh stage=0 stop_stage=100 conf_path=conf/transformer.yaml +dict_path=data/train_960_unigram5000_units.txt avg_num=5 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -29,12 +30,12 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && 
[ ${stop_stage} -ge 4 ]; then # ctc alignment of test data - CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then diff --git a/examples/tiny/s0/conf/augmentation.json b/examples/tiny/s0/conf/augmentation.json index 83705516..4480307b 100644 --- a/examples/tiny/s0/conf/augmentation.json +++ b/examples/tiny/s0/conf/augmentation.json @@ -29,8 +29,7 @@ "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true, - "warp_mode": "PIL" + "replace_with_zero": true }, "prob": 1.0 } From 2d3b2aed05c36cc173dd8370dab986b1f0cf6513 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 23 Aug 2021 14:06:10 +0000 Subject: [PATCH 09/17] add seed in argparse --- deepspeech/exps/deepspeech2/bin/train.py | 4 --- deepspeech/exps/deepspeech2/model.py | 9 ------- deepspeech/training/cli.py | 33 +++++++++++++----------- deepspeech/training/trainer.py | 9 +++++++ examples/aishell/s0/local/train.sh | 12 ++++++++- examples/aishell/s1/local/train.sh | 12 ++++++++- examples/callcenter/s1/local/train.sh | 12 ++++++++- examples/librispeech/s0/local/train.sh | 12 ++++++++- examples/librispeech/s1/local/train.sh | 12 ++++++++- examples/librispeech/s2/local/train.sh | 12 ++++++++- examples/ted_en_zh/t0/local/train.sh | 12 ++++++++- examples/timit/s1/local/train.sh | 12 ++++++++- examples/tiny/s0/local/train.sh | 12 ++++++++- examples/tiny/s1/local/train.sh | 12 ++++++++- 14 files changed, 137 insertions(+), 38 deletions(-) diff --git a/deepspeech/exps/deepspeech2/bin/train.py b/deepspeech/exps/deepspeech2/bin/train.py index bb0bd43a..69ff043a 100644 --- a/deepspeech/exps/deepspeech2/bin/train.py +++ b/deepspeech/exps/deepspeech2/bin/train.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # 
limitations under the License. """Trainer for DeepSpeech2 model.""" -import os - from paddle import distributed as dist from deepspeech.exps.deepspeech2.config import get_cfg_defaults @@ -55,7 +53,5 @@ if __name__ == "__main__": if args.dump_config: with open(args.dump_config, 'w') as f: print(config, file=f) - if config.training.seed is not None: - os.environ.setdefault('FLAGS_cudnn_deterministic', 'True') main(config, args) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 1bd4c722..65c905a1 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Contains DeepSpeech2 and DeepSpeech2Online model.""" -import random import time from collections import defaultdict from pathlib import Path @@ -54,7 +53,6 @@ class DeepSpeech2Trainer(Trainer): weight_decay=1e-6, # the coeff of weight decay global_grad_clip=5.0, # the global norm clip n_epoch=50, # train epochs - seed=1024, #train seed )) if config is not None: @@ -63,13 +61,6 @@ class DeepSpeech2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) - if config.training.seed is not None: - self.set_seed(config.training.seed) - - def set_seed(self, seed): - np.random.seed(seed) - random.seed(seed) - paddle.seed(seed) def train_batch(self, batch_index, batch_data, msg): start = time.time() diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index b83d989d..d3b85355 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -16,23 +16,23 @@ import argparse def default_argument_parser(): r"""A simple yet genral argument parser for experiments with parakeet. - - This is used in examples with parakeet. And it is intended to be used by - other experiments with parakeet. It requires a minimal set of command line + + This is used in examples with parakeet. 
And it is intended to be used by + other experiments with parakeet. It requires a minimal set of command line arguments to start a training script. - - The ``--config`` and ``--opts`` are used for overwrite the deault + + The ``--config`` and ``--opts`` are used for overwrite the deault configuration. - - The ``--data`` and ``--output`` specifies the data path and output path. - Resuming training from existing progress at the output directory is the + + The ``--data`` and ``--output`` specifies the data path and output path. + Resuming training from existing progress at the output directory is the intended default behavior. - + The ``--checkpoint_path`` specifies the checkpoint to load from. - + The ``--device`` and ``--nprocs`` specifies how to run the training. - - + + See Also -------- parakeet.training.experiment @@ -53,10 +53,10 @@ def default_argument_parser(): # load from saved checkpoint parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load") - # save jit model to + # save jit model to parser.add_argument("--export_path", type=str, help="path of the jit model to save") - # save asr result to + # save asr result to parser.add_argument("--result_file", type=str, help="path of save the asr result") # running @@ -65,10 +65,13 @@ def default_argument_parser(): parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.") # overwrite extra config and default config - # parser.add_argument("--opts", nargs=argparse.REMAINDER, + # parser.add_argument("--opts", nargs=argparse.REMAINDER, # help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") parser.add_argument("--opts", type=str, default=[], nargs='+', help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") + + parser.add_argument("--seed", type=int, default=None, + help="seed to use for paddle, np and random. 
The default value is None") # yapd: enable return parser diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 209e2240..2ab7eac0 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random import time from pathlib import Path +import numpy as np import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter @@ -93,6 +95,13 @@ class Trainer(): self.checkpoint_dir = None self.iteration = 0 self.epoch = 0 + if args.seed is not None: + self.set_seed(args.seed) + + def set_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + paddle.seed(seed) def setup(self): """Setup the experiment. diff --git a/examples/aishell/s0/local/train.sh b/examples/aishell/s0/local/train.sh index c6a63180..d42e51fa 100755 --- a/examples/aishell/s0/local/train.sh +++ b/examples/aishell/s0/local/train.sh @@ -19,12 +19,22 @@ fi mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} +--model_type ${model_type} \ +--seed ${seed} + +if [ ${seed} ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/aishell/s1/local/train.sh b/examples/aishell/s1/local/train.sh index f3eb98da..ec17054a 100755 --- a/examples/aishell/s1/local/train.sh +++ b/examples/aishell/s1/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." 
mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +if [ ${seed} ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/callcenter/s1/local/train.sh b/examples/callcenter/s1/local/train.sh index f750568a..928c6492 100755 --- a/examples/callcenter/s1/local/train.sh +++ b/examples/callcenter/s1/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +if [ ${seed} ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/librispeech/s0/local/train.sh b/examples/librispeech/s0/local/train.sh index 039b9cea..dcd21df3 100755 --- a/examples/librispeech/s0/local/train.sh +++ b/examples/librispeech/s0/local/train.sh @@ -20,12 +20,22 @@ echo "using ${device}..." mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} +--model_type ${model_type} \ +--seed ${seed} + +if [ ${seed} ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/librispeech/s1/local/train.sh b/examples/librispeech/s1/local/train.sh index f3eb98da..ec17054a 100755 --- a/examples/librispeech/s1/local/train.sh +++ b/examples/librispeech/s1/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." 
mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +if [ ${seed} ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/librispeech/s2/local/train.sh b/examples/librispeech/s2/local/train.sh index f3eb98da..ec17054a 100755 --- a/examples/librispeech/s2/local/train.sh +++ b/examples/librispeech/s2/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +if [ ${seed} ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/ted_en_zh/t0/local/train.sh b/examples/ted_en_zh/t0/local/train.sh index f3eb98da..ec17054a 100755 --- a/examples/ted_en_zh/t0/local/train.sh +++ b/examples/ted_en_zh/t0/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +if [ ${seed} ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/timit/s1/local/train.sh b/examples/timit/s1/local/train.sh index f3eb98da..ec17054a 100755 --- a/examples/timit/s1/local/train.sh +++ b/examples/timit/s1/local/train.sh @@ -19,11 +19,21 @@ echo "using ${device}..." 
mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +if [ ${seed} ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/tiny/s0/local/train.sh b/examples/tiny/s0/local/train.sh index c6a63180..d42e51fa 100755 --- a/examples/tiny/s0/local/train.sh +++ b/examples/tiny/s0/local/train.sh @@ -19,12 +19,22 @@ fi mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---model_type ${model_type} +--model_type ${model_type} \ +--seed ${seed} + +if [ ${seed} ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/tiny/s1/local/train.sh b/examples/tiny/s1/local/train.sh index f6bd2c98..2fb3a95a 100755 --- a/examples/tiny/s1/local/train.sh +++ b/examples/tiny/s1/local/train.sh @@ -18,11 +18,21 @@ fi mkdir -p exp +seed=1024 +if [ ${seed} ]; then + export FLAGS_cudnn_deterministic=True +fi + python3 -u ${BIN_DIR}/train.py \ --device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +if [ ${seed} ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" 
From fd3491ba1b5a0ccfac7667a0bfc9048d1de792cc Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 24 Aug 2021 02:56:41 +0000 Subject: [PATCH 10/17] fix dataloader batchsize and minibatchsize --- deepspeech/exps/u2_kaldi/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 4f6ff4cb..46e5b4d9 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -228,7 +228,7 @@ class U2Trainer(Trainer): maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0, - mini_batch_size=1, + mini_batch_size=self.args.nprocs, batch_count='auto', batch_bins=0, batch_frames_in=0, @@ -247,7 +247,7 @@ class U2Trainer(Trainer): maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0, - mini_batch_size=1, + mini_batch_size=self.args.nprocs, batch_count='auto', batch_bins=0, batch_frames_in=0, @@ -263,7 +263,7 @@ class U2Trainer(Trainer): json_file=config.data.test_manifest, train_mode=False, sortagrad=False, - batch_size=config.collator.batch_size, + batch_size=config.decoding.batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0, @@ -282,7 +282,7 @@ class U2Trainer(Trainer): json_file=config.data.test_manifest, train_mode=False, sortagrad=False, - batch_size=config.collator.batch_size, + batch_size=config.decoding.batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0, From 8215bd0e7915fb8e139281c15056fe3fa39f01f9 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 24 Aug 2021 05:46:39 +0000 Subject: [PATCH 11/17] fix load vocab; zero W for not warptime --- deepspeech/frontend/augmentor/spec_augment.py | 3 +++ deepspeech/frontend/utility.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py index 7c23b628..a9bb043d 100644 --- a/deepspeech/frontend/augmentor/spec_augment.py +++ 
b/deepspeech/frontend/augmentor/spec_augment.py @@ -151,6 +151,9 @@ class SpecAugmentor(AugmentorBase): np.ndarray: time warped spectrogram (time, freq) """ window = max_time_warp = self.W + if window == 0: + return x + if mode == "PIL": t = x.shape[0] if t - window <= window: diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index 3d0683b0..72dfc98d 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -46,7 +46,7 @@ def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]: with open(dict_path, "r") as f: dictionary = f.readlines() - char_list = [entry.split(" ")[0] for entry in dictionary] + char_list = [entry.strip().split(" ")[0] for entry in dictionary] if BLANK not in char_list: char_list.insert(0, BLANK) if EOS not in char_list: From 3d9aebfaa3373d9ee03ccf06f03bfcf07196c42c Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 24 Aug 2021 06:07:47 +0000 Subject: [PATCH 12/17] fix specaug; add data static --- examples/aishell/s0/README.md | 10 +++++++++- examples/aishell/s0/conf/augmentation.json | 8 ++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index 6ce39b23..eedf92c9 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -1,10 +1,18 @@ # Aishell-1 +## Data +| Data Subset | Duration in Seconds | +| data/manifest.train | 1.23 ~ 14.53125 | +| data/manifest.dev | 1.645 ~ 12.533 | +| data/manifest.test | 1.859125 ~ 14.6999375 | + +`jq '.feat_shape[0]' data/manifest.train | sort -un` + ## Deepspeech2 | Model | Params | Release | Config | Test set | Loss | CER | | --- | --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382,0.073507 | +| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382 | | DeepSpeech2 | 58.4M 
| 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | | DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | diff --git a/examples/aishell/s0/conf/augmentation.json b/examples/aishell/s0/conf/augmentation.json index ac8a1c53..6f249242 100644 --- a/examples/aishell/s0/conf/augmentation.json +++ b/examples/aishell/s0/conf/augmentation.json @@ -19,17 +19,17 @@ { "type": "specaug", "params": { - "W": 5, + "W": 0, "warp_mode": "PIL", - "F": 30, + "F": 10, "n_freq_masks": 2, - "T": 40, + "T": 50, "n_time_masks": 2, "p": 1.0, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": false + "replace_with_zero": true }, "prob": 1.0 } From 715e90a9dfb6fa3fab98d6b7c29a99b4570d789f Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 24 Aug 2021 07:42:52 +0000 Subject: [PATCH 13/17] fix librispeech s0 specaug --- examples/librispeech/s0/conf/augmentation.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/librispeech/s0/conf/augmentation.json b/examples/librispeech/s0/conf/augmentation.json index d0409b14..31c481c8 100644 --- a/examples/librispeech/s0/conf/augmentation.json +++ b/examples/librispeech/s0/conf/augmentation.json @@ -19,17 +19,17 @@ { "type": "specaug", "params": { + "W": 0, + "warp_mode": "PIL", "F": 10, - "T": 50, "n_freq_masks": 2, + "T": 50, "n_time_masks": 2, "p": 1.0, - "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true, - "warp_mode": "PIL" + "replace_with_zero": true }, "prob": 1.0 } From d1db859657f99f373beec44876d652eec2d05983 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 24 Aug 2021 09:14:40 +0000 Subject: [PATCH 14/17] fix dataloader pickle bugs --- deepspeech/decoders/swig/setup.py | 2 ++ deepspeech/exps/u2_kaldi/model.py | 3 -- 
deepspeech/frontend/augmentor/spec_augment.py | 2 +- deepspeech/io/converter.py | 2 +- deepspeech/io/dataloader.py | 24 +++++++++++---- deepspeech/io/dataset.py | 11 +++---- examples/aishell/s0/conf/augmentation.json | 2 +- .../librispeech/s2/conf/augmentation.json | 10 +++---- examples/librispeech/s2/conf/transformer.yaml | 29 +++++++++---------- 9 files changed, 47 insertions(+), 38 deletions(-) diff --git a/deepspeech/decoders/swig/setup.py b/deepspeech/decoders/swig/setup.py index 3da5ce8b..8fb79296 100644 --- a/deepspeech/decoders/swig/setup.py +++ b/deepspeech/decoders/swig/setup.py @@ -83,11 +83,13 @@ FILES = glob.glob('kenlm/util/*.cc') \ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc') +# yapf: disable FILES = [ fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith( 'unittest.cc')) ] +# yapf: enable LIBS = ['stdc++'] if platform.system() != 'Darwin': diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 46e5b4d9..6a932d75 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -171,10 +171,7 @@ class U2Trainer(Trainer): if from_scratch: # save init model, i.e. 
0 epoch self.save(tag='init') - self.lr_scheduler.step(self.iteration) - if self.parallel: - self.train_loader.batch_sampler.set_epoch(self.epoch) logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py index a9bb043d..26c94d41 100644 --- a/deepspeech/frontend/augmentor/spec_augment.py +++ b/deepspeech/frontend/augmentor/spec_augment.py @@ -153,7 +153,7 @@ class SpecAugmentor(AugmentorBase): window = max_time_warp = self.W if window == 0: return x - + if mode == "PIL": t = x.shape[0] if t - window <= window: diff --git a/deepspeech/io/converter.py b/deepspeech/io/converter.py index 3bfcc1b1..b80c7b20 100644 --- a/deepspeech/io/converter.py +++ b/deepspeech/io/converter.py @@ -43,7 +43,7 @@ class CustomConverter(): batch (list): The batch to transform. Returns: - tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor) + tuple(np.ndarray, nn.ndarray, nn.ndarray) """ # batch should be located in list diff --git a/deepspeech/io/dataloader.py b/deepspeech/io/dataloader.py index 115fe461..a35a0bc0 100644 --- a/deepspeech/io/dataloader.py +++ b/deepspeech/io/dataloader.py @@ -43,6 +43,18 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]], return feat_dim, vocab_size +def batch_collate(x): + """de-tuple. 
+ + Args: + x (List[Tuple]): [(utts, xs, ilens, ys, olens)] + + Returns: + Tuple: (utts, xs, ilens, ys, olens) + """ + return x[0] + + class BatchDataLoader(): def __init__(self, json_file: str, @@ -120,15 +132,15 @@ class BatchDataLoader(): # actual bathsize is included in a list # default collate function converts numpy array to pytorch tensor # we used an empty collate function instead which returns list - self.dataset = TransformDataset( - self.minibaches, - lambda data: self.converter([self.reader(data, return_uttid=True)])) + self.dataset = TransformDataset(self.minibaches, self.converter, + self.reader) + self.dataloader = DataLoader( dataset=self.dataset, batch_size=1, - shuffle=not self.use_sortagrad if train_mode else False, - collate_fn=lambda x: x[0], - num_workers=n_iter_processes, ) + shuffle=not self.use_sortagrad if self.train_mode else False, + collate_fn=batch_collate, + num_workers=self.n_iter_processes, ) def __repr__(self): echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> " diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 74c08b46..d1fe0470 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -129,15 +129,16 @@ class TransformDataset(Dataset): Args: data: list object from make_batchset - transfrom: transform function - + converter: batch function + reader: read data """ - def __init__(self, data, transform): + def __init__(self, data, converter, reader): """Init function.""" super().__init__() self.data = data - self.transform = transform + self.converter = converter + self.reader = reader def __len__(self): """Len function.""" @@ -145,4 +146,4 @@ class TransformDataset(Dataset): def __getitem__(self, idx): """[] operator.""" - return self.transform(self.data[idx]) + return self.converter([self.reader(self.data[idx], return_uttid=True)]) diff --git a/examples/aishell/s0/conf/augmentation.json b/examples/aishell/s0/conf/augmentation.json index 6f249242..31c481c8 
100644 --- a/examples/aishell/s0/conf/augmentation.json +++ b/examples/aishell/s0/conf/augmentation.json @@ -29,7 +29,7 @@ "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true + "replace_with_zero": true }, "prob": 1.0 } diff --git a/examples/librispeech/s2/conf/augmentation.json b/examples/librispeech/s2/conf/augmentation.json index e20fc199..3b14b9d0 100644 --- a/examples/librispeech/s2/conf/augmentation.json +++ b/examples/librispeech/s2/conf/augmentation.json @@ -2,17 +2,17 @@ { "type": "specaug", "params": { - "F": 10, - "T": 50, + "W": 5, + "warp_mode": "PIL", + "F": 30, "n_freq_masks": 2, + "T": 40, "n_time_masks": 2, "p": 1.0, - "W": 80, "adaptive_number_ratio": 0, "adaptive_size_ratio": 0, "max_n_time_masks": 20, - "replace_with_zero": true, - "warp_mode": "PIL" + "replace_with_zero": false }, "prob": 1.0 } diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml index ded4f240..f7c27d1f 100644 --- a/examples/librispeech/s2/conf/transformer.yaml +++ b/examples/librispeech/s2/conf/transformer.yaml @@ -8,26 +8,23 @@ collator: vocab_filepath: data/train_960_unigram5000_units.txt unit_type: 'spm' spm_model_prefix: 'data/train_960_unigram5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 64 - raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank feat_dim: 83 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None stride_ms: 10.0 window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle + sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs + batch_size: 32 + maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced + maxlen_out: 150 # if output 
length > maxlen-out, batchsize is automatically reduced + minibatches: 0 # for debug + batch_count: auto + batch_bins: 0 + batch_frames_in: 0 + batch_frames_out: 0 + batch_frames_inout: 0 + augmentation_config: conf/augmentation.json num_workers: 2 + subsampling_factor: 1 + num_encs: 1 # network architecture From cfdca210ff243a45afa96a64c6ba42bf2586d5eb Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 24 Aug 2021 12:24:59 +0000 Subject: [PATCH 15/17] chaner style updater --- deepspeech/training/extensions/__init__.py | 28 +++ deepspeech/training/extensions/evaluator.py | 58 ++++++ deepspeech/training/extensions/extension.py | 41 ++++ deepspeech/training/extensions/snapshot.py | 102 ++++++++++ deepspeech/training/extensions/visualizer.py | 24 +++ deepspeech/training/reporter.py | 131 +++++++++++++ deepspeech/training/triggers/__init__.py | 13 ++ .../training/triggers/interval_trigger.py | 24 +++ deepspeech/training/triggers/limit_trigger.py | 17 ++ deepspeech/training/triggers/time_trigger.py | 17 ++ deepspeech/training/updaters/__init__.py | 0 .../training/updaters/standard_updater.py | 179 ++++++++++++++++++ deepspeech/training/updaters/trainer.py | 171 +++++++++++++++++ deepspeech/training/updaters/updater.py | 82 ++++++++ requirements.txt | 1 + 15 files changed, 888 insertions(+) create mode 100644 deepspeech/training/extensions/__init__.py create mode 100644 deepspeech/training/extensions/evaluator.py create mode 100644 deepspeech/training/extensions/extension.py create mode 100644 deepspeech/training/extensions/snapshot.py create mode 100644 deepspeech/training/extensions/visualizer.py create mode 100644 deepspeech/training/reporter.py create mode 100644 deepspeech/training/triggers/__init__.py create mode 100644 deepspeech/training/triggers/interval_trigger.py create mode 100644 deepspeech/training/triggers/limit_trigger.py create mode 100644 deepspeech/training/triggers/time_trigger.py create mode 100644 deepspeech/training/updaters/__init__.py create 
from typing import Dict

import paddle
from paddle.io import DataLoader
from paddle.nn import Layer

# ``extension`` is a sibling module inside this package. Python 3 has no
# implicit relative imports, so it must be imported explicitly relative,
# exactly like the ``..reporter`` imports below.
from . import extension
from ..reporter import DictSummary
from ..reporter import report
from ..reporter import scope


class StandardEvaluator(extension.Extension):
    """Extension that runs a model over a dataloader and reports mean metrics.

    Every batch is evaluated inside its own observation scope, so whatever the
    model (or ``evaluate_core``) reports for that batch is collected, then
    averaged over all batches with ``DictSummary``.

    Args:
        model (Layer): the model to evaluate; stored under the key ``"main"``.
        dataloader (DataLoader): iterable yielding evaluation batches.
    """

    trigger = (1, 'epoch')
    default_name = 'validation'
    priority = extension.PRIORITY_WRITER

    name = None

    def __init__(self, model: Layer, dataloader: DataLoader):
        # it is designed to hold multiple models
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        self.model = model

        # dataloaders
        self.dataloader = dataloader

    def evaluate_core(self, batch):
        """Evaluate a single batch; override to compute and report metrics."""
        # compute
        self.model(batch)  # you may report here

    def evaluate(self):
        """Evaluate every batch of the dataloader.

        Returns:
            dict: mean of each value reported while evaluating the batches.
        """
        # switch to eval mode
        for model in self.models.values():
            model.eval()

        # to average evaluation metrics
        summary = DictSummary()
        for batch in self.dataloader:
            observation = {}
            with scope(observation):
                # main evaluation computation here.
                with paddle.no_grad():
                    self.evaluate_core(batch)
            summary.add(observation)
        summary = summary.compute_mean()
        return summary

    def __call__(self, trainer=None):
        """Evaluate and report the averaged metrics to the active observation.

        When used as a trainer extension the metrics land in the trainer's
        observation; otherwise the caller can open its own ``scope``.
        """
        summary = self.evaluate()
        for k, v in summary.items():
            report(k, v)
def load_records(records_fp):
    """Load every record of a JSON-lines file into a list."""
    with jsonlines.open(records_fp, 'r') as reader:
        records = list(reader)
    return records


class Snapshot(extension.Extension):
    """An extension to make snapshots of the updater object inside the trainer.

    A snapshot is taken by calling the updater's ``save`` method. Each
    snapshot is recorded (time, path, iteration, epoch) in
    ``<trainer.out>/checkpoints/records.jsonl``; when more than ``max_size``
    snapshots are recorded, the oldest file is deleted.

    Args:
        max_size (int): number of snapshots to keep; ``-1`` keeps all of them.
        snapshot_on_error (bool): also take a snapshot when training raises.
    """

    trigger = (1, 'epoch')
    priority = -100
    default_name = "snapshot"

    def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
        self.records: List[Dict[str, Any]] = []
        self.max_size = max_size
        self._snapshot_on_error = snapshot_on_error
        self._save_all = (max_size == -1)
        # resolved in ``initialize`` once the trainer's output dir is known
        self.checkpoint_dir = None

    def initialize(self, trainer: Trainer):
        """Set up the checkpoint directory and resume from the latest record."""
        self.checkpoint_dir = trainer.out / "checkpoints"

        # load existing records
        record_path: Path = self.checkpoint_dir / "records.jsonl"
        if record_path.exists():
            logger.debug("Loading from an existing checkpoint dir")
            self.records = load_records(record_path)
            # An existing but empty records file means there is nothing to
            # resume from; indexing [-1] would raise IndexError.
            if self.records:
                trainer.updater.load(self.records[-1]['path'])

    def on_error(self, trainer, exc, tb):
        """Take a snapshot on failure when ``snapshot_on_error`` is set."""
        if self._snapshot_on_error:
            self.save_checkpoint_and_update(trainer)

    def __call__(self, trainer: Trainer):
        self.save_checkpoint_and_update(trainer)

    def full(self):
        """Whether the number of snapshots it keeps track of is greater
        than the max_size."""
        return (not self._save_all) and len(self.records) > self.max_size

    @rank_zero_only
    def save_checkpoint_and_update(self, trainer: Trainer):
        """Save a new snapshot and remove the oldest snapshot if needed."""
        iteration = trainer.updater.state.iteration
        epoch = trainer.updater.state.epoch
        # Compare by value: ``is`` on a str literal tests object identity,
        # which only works by accident of interning.
        num = epoch if self.trigger[1] == 'epoch' else iteration
        path = self.checkpoint_dir / f"{num}.pdz"

        # add the new one
        trainer.updater.save(path)
        record = {
            "time": str(datetime.now()),
            'path': str(path.resolve()),  # use absolute path
            'iteration': iteration,
            'epoch': epoch,
        }
        self.records.append(record)

        # remove the earliest
        if self.full():
            earliest_record = self.records.pop(0)
            os.remove(earliest_record["path"])

        # update the record file
        record_path = self.checkpoint_dir / "records.jsonl"
        with jsonlines.open(record_path, 'w') as writer:
            for record in self.records:
                # jsonlines.open may return a Writer or a Reader
                writer.write(record)  # pylint: disable=no-member
class Summary():
    """Online summarization of a sequence of scalars.

    Keeps the weighted sum, the weighted sum of squares and the total weight,
    from which the mean and the standard deviation are derived on demand.
    """

    def __init__(self):
        self._x = 0.0
        self._x2 = 0.0
        self._n = 0

    def add(self, value, weight=1):
        """Add a scalar value.

        Args:
            value: scalar value to accumulate.
            weight: optional weight of the value. Default is 1 (integer).
        """
        self._x += weight * value
        self._x2 += weight * value * value
        self._n += weight

    def compute_mean(self):
        """Return the weighted mean of the values added so far.

        Raises:
            ZeroDivisionError: if the accumulated weight is zero.
        """
        x, n = self._x, self._n
        return x / n

    def make_statistics(self):
        """Return the mean and the standard deviation of the added values.

        Returns:
            tuple: ``(mean, std)``.
        """
        x, n = self._x, self._n
        mean = x / n
        # E[x^2] - E[x]^2 can come out a hair below zero through float
        # rounding (e.g. several identical values); clamp it so math.sqrt
        # does not raise "math domain error".
        var = max(self._x2 / n - mean * mean, 0.0)
        std = math.sqrt(var)
        return mean, std


class DictSummary():
    """Online summarization of a sequence of dictionaries.

    Keeps one ``Summary`` per key, created the first time the key is seen.
    """

    def __init__(self):
        self._summaries = defaultdict(Summary)

    def add(self, d):
        """Add a dictionary of scalars.

        Args:
            d (dict): values to accumulate, keyed by name. When a value is a
                tuple it is read as ``(value, weight)``; otherwise the weight
                is 1.
        """
        summaries = self._summaries
        for k, v in d.items():
            w = 1
            if isinstance(v, tuple):
                # Read both fields before rebinding ``v``; assigning
                # ``v = v[0]`` first would make ``v[1]`` index the scalar.
                v, w = v[0], v[1]
            summaries[k].add(v, weight=w)

    def compute_mean(self):
        """Create a dictionary of mean values.

        Returns:
            dict: the mean of every entry added to the summary.
        """
        return {
            name: summary.compute_mean()
            for name, summary in self._summaries.items()
        }

    def make_statistics(self):
        """Create a dictionary of statistics.

        For an entry named ``'key'`` the mean is stored under ``'key'`` and
        the standard deviation under ``'key.std'``.

        Returns:
            dict: statistics of all entries.
        """
        stats = {}
        for name, summary in self._summaries.items():
            mean, std = summary.make_statistics()
            stats[name] = mean
            stats[name + '.std'] = std

        return stats
class TimeTrigger():
    """Trigger based on a fixed time interval.

    The trigger fires when the trainer's ``elapsed_time`` has passed the next
    scheduled time, then schedules the following one a period later.

    Args:
        period (float): Interval time. It is given in seconds.

    Raises:
        ValueError: if ``period`` is not positive.
    """

    def __init__(self, period):
        # Same contract as IntervalTrigger / LimitTrigger: a non-positive
        # period would make the trigger fire on every single call.
        if period <= 0:
            raise ValueError("period should be a positive number.")
        self._period = period
        self._next_time = self._period

    def __call__(self, trainer):
        """Return True when ``trainer.elapsed_time`` passed the next deadline."""
        if self._next_time < trainer.elapsed_time:
            self._next_time += self._period
            return True
        else:
            return False
class StandardUpdater(UpdaterBase):
    """An example of over-simplification. Things may not be that simple, but
    you can subclass it to fit your need.

    Args:
        model (Layer): the model to train; stored under the key ``"main"``.
        optimizer (Optimizer): its optimizer; stored under the key ``"main"``.
        dataloader (DataLoader): source of training batches.
        init_state (Optional[UpdaterState]): state to start from; a fresh
            ``UpdaterState`` is created when omitted.
    """

    def __init__(self,
                 model: Layer,
                 optimizer: Optimizer,
                 dataloader: DataLoader,
                 init_state: Optional[UpdaterState]=None):
        # it is designed to hold multiple models
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        self.model = model

        # it is designed to hold multiple optimizers
        optimizers = {"main": optimizer}
        self.optimizer = optimizer
        self.optimizers: Dict[str, Optimizer] = optimizers

        # dataloaders
        self.dataloader = dataloader

        # init state
        if init_state is None:
            self.state = UpdaterState()
        else:
            self.state = init_state

        self.train_iterator = iter(dataloader)

    def update(self):
        """Run one training step, then advance the iteration/epoch counters."""
        # We increase the iteration index after updating and before extension.
        # Here are the reasons.

        # 1. Snapshotting (as well as other extensions, like visualizer) is
        #    executed after a step of updating;
        # 2. We decide to increase the iteration index after updating and
        #    before any extension is executed.
        # 3. We do not increase the iteration after extension because we
        #    prefer a consistent resume behavior: when loading from a
        #    `snapshot_iter_100.pdz` the next step to train is `101`,
        #    naturally. But if the iteration were increased after the
        #    extensions (including snapshot), a `snapshot_iter_99` would be
        #    loaded, and an extra increment of the iteration index would be
        #    needed before training to avoid repeating iteration `99`, which
        #    had been done before snapshotting.
        # 4. Thus the iteration index represents "how many iterations have
        #    been done so far."
        # NOTE: use report to capture the correct value. If you want to
        # report the learning rate used for a step, you must report it before
        # the learning rate scheduler's step() has been called. In paddle's
        # convention, we do not use an extension to change the learning rate,
        # so if you want to report it, do it in the updater.

        # Then here comes the next question. When is the proper time to
        # increase the epoch index? Since all extensions are executed after
        # updating, right after updating is the proper time.
        # 1. If we increased the epoch index before updating, an extension
        #    based on epoch would miss the correct timing. It could only be
        #    triggered after an extra update.
        # 2. Theoretically, when an epoch is done, the epoch index should be
        #    increased. So it is increased after updating.
        # 3. Thus, the epoch index represents "how many epochs have been done
        #    so far." So it starts from 0.

        # switch to training mode
        for model in self.models.values():
            model.train()

        # training for a step is implemented here
        batch = self.read_batch()
        self.update_core(batch)

        self.state.iteration += 1
        if self.updates_per_epoch is not None:
            if self.state.iteration % self.updates_per_epoch == 0:
                self.state.epoch += 1

    def update_core(self, batch):
        """A simple case for a training step. Basic assumptions are:
        Single model;
        Single optimizer;
        A batch from the dataloader is just the input of the model;
        The model returns a single loss, or a dict containing several losses.
        Parameters update at every batch, no gradient accumulation.
        """
        loss = self.model(*batch)

        if isinstance(loss, Tensor):
            loss_dict = {"main": loss}
        else:
            # Dict[str, Tensor]
            loss_dict = loss
            if "main" not in loss_dict:
                # the main loss defaults to the sum of all returned losses
                loss_dict["main"] = sum(loss.values())

        for name, loss_item in loss_dict.items():
            report(name, float(loss_item))

        # paddle optimizers expose clear_grad()/step(); they have no
        # clear_gradient()/update() methods.
        self.optimizer.clear_grad()
        loss_dict["main"].backward()
        self.optimizer.step()

    @property
    def updates_per_epoch(self):
        """Number of updates per epoch, determined by the length of the
        dataloader; ``None`` when the dataloader has no length."""
        try:
            return len(self.dataloader)
        except TypeError:
            # Only the "no __len__" case is handled; any other error
            # propagates (a ``return`` inside ``finally`` would hide it).
            logger.debug("This dataloader has no __len__.")
            return None

    def new_epoch(self):
        """Start a new epoch."""
        # NOTE: all batch sampler for distributed training should
        # subclass DistributedBatchSampler and implement `set_epoch` method
        if hasattr(self.dataloader, "batch_sampler"):
            batch_sampler = self.dataloader.batch_sampler
            if isinstance(batch_sampler, DistributedBatchSampler):
                batch_sampler.set_epoch(self.state.epoch)
        self.train_iterator = iter(self.dataloader)

    def read_batch(self):
        """Read a batch from the data loader, auto renew when data is exhausted."""
        with timer() as t:
            try:
                batch = next(self.train_iterator)
            except StopIteration:
                self.new_epoch()
                batch = next(self.train_iterator)
            logger.debug(
                f"Read a batch takes {t.elapse}s.")  # replace it with logger
        return batch

    def state_dict(self):
        """State dict of a Updater, model, optimizer and updater state are included."""
        state_dict = super().state_dict()
        for name, model in self.models.items():
            state_dict[f"{name}_params"] = model.state_dict()
        for name, optim in self.optimizers.items():
            state_dict[f"{name}_optimizer"] = optim.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """Set state dict for a Updater. Parameters of models, states for
        optimizers and UpdaterState are restored."""
        for name, model in self.models.items():
            model.set_state_dict(state_dict[f"{name}_params"])
        for name, optim in self.optimizers.items():
            optim.set_state_dict(state_dict[f"{name}_optimizer"])
        super().set_state_dict(state_dict)
-> function name when it is a function \ + # -> error + + if name is None: + name = getattr(extension, 'name', None) + if name is None: + name = getattr(extension, 'default_name', None) + if name is None: + name = getattr(extension, '__name__', None) + if name is None: + raise ValueError("Name is not given for the extension.") + if name == 'training': + raise ValueError("training is a reserved name.") + + if trigger is None: + trigger = getattr(extension, 'trigger', (1, 'iteration')) + trigger = get_trigger(trigger) + + if priority is None: + priority = getattr(extension, 'priority', PRIORITY_READER) + + # add suffix to avoid nameing conflict + ordinal = 0 + modified_name = name + while modified_name in self.extensions: + ordinal += 1 + modified_name = f"{name}_{ordinal}" + extension.name = modified_name + + self.extensions[modified_name] = _ExtensionEntry(extension, trigger, + priority) + + def get_extension(self, name): + """get extension by name.""" + extensions = self.extensions + if name in extensions: + return extensions[name].extension + else: + raise ValueError(f'extension {name} not found') + + def run(self): + if self._done: + raise RuntimeError("Training is already done!.") + + self.out.mkdir(parents=True, exist_ok=True) + + # sort extensions by priorities once + extension_order = sorted( + self.extensions.keys(), + key=lambda name: self.extensions[name].priority, + reverse=True) + extensions = [(name, self.extensions[name]) for name in extension_order] + + # initializing all extensions + for name, entry in extensions: + if hasattr(entry.extension, "initialize"): + entry.extension.initialize(self) + + update = self.updater.update # training step + stop_trigger = self.stop_trigger + + # display only one progress bar + max_iteration = None + if isinstance(stop_trigger, LimitTrigger): + if stop_trigger.unit == 'epoch': + max_epoch = self.stop_trigger.limit + updates_per_epoch = getattr(self.updater, "updates_per_epoch", + None) + max_iteration = max_epoch * 
updates_per_epoch if updates_per_epoch else None + else: + max_iteration = self.stop_trigger.limit + + p = tqdm.tqdm(initial=self.updater.state.iteration, total=max_iteration) + + try: + while not stop_trigger(self): + self.observation = {} + # set observation as the report target + # you can use report freely in Updater.update() + + # updating parameters and state + with scope(self.observation): + update() + p.update() + + # execute extension when necessary + for name, entry in extensions: + if entry.trigger(self): + entry.extension(self) + + # print("###", self.observation) + except Exception as e: + f = sys.stderr + f.write(f"Exception in main training loop: {e}\n") + f.write("Traceback (most recent call last):\n") + traceback.print_tb(sys.exc_info()[2]) + f.write( + "Trainer extensions will try to handle the extension. Then all extensions will finalize." + ) + + # capture the exception in the mian training loop + exc_info = sys.exc_info() + + # try to handle it + for name, entry in extensions: + if hasattr(entry.extension, "on_error"): + try: + entry.extension.on_error(self, e, sys.exc_info()[2]) + except Exception as ee: + f.write(f"Exception in error handler: {ee}\n") + f.write('Traceback (most recent call last):\n') + traceback.print_tb(sys.exc_info()[2]) + + # raise exception in main training loop + six.reraise(*exc_info) + finally: + for name, entry in extensions: + if hasattr(entry.extension, "finalize"): + entry.extension.finalize(self) \ No newline at end of file diff --git a/deepspeech/training/updaters/updater.py b/deepspeech/training/updaters/updater.py new file mode 100644 index 00000000..548042d6 --- /dev/null +++ b/deepspeech/training/updaters/updater.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +import paddle + +from deepspeech.utils.log import Log + +__all__ = ["UpdaterBase", "UpdaterState"] + +logger = Log(__name__).getlog() + + +@dataclass +class UpdaterState: + iteration: int = 0 + epoch: int = 0 + + +class UpdaterBase(): + """An updater is the abstraction of how a model is trained given the + dataloader and the optimizer. + The `update_core` method is a step in the training loop with only necessary + operations (get a batch, forward and backward, update the parameters). + Other stuffs are made extensions. Visualization, saving, loading and + periodical validation and evaluation are not considered here. + But even in such simplist case, things are not that simple. There is an + attempt to standardize this process and requires only the model and + dataset and do all the stuffs automatically. But this may hurt flexibility. + If we assume a batch yield from the dataloader is just the input to the + model, we will find that some model requires more arguments, or just some + keyword arguments. But this prevents us from over-simplifying it. + From another perspective, the batch may includes not just the input, but + also the target. But the model's forward method may just need the input. + We can pass a dict or a super-long tuple to the model and let it pick what + it really needs. But this is an abuse of lazy interface. + After all, we care about how a model is trained. But just how the model is + used for inference. We want to control how a model is trained. 
We just + don't want to be messed up with other auxiliary code. + So the best practice is to define a model and define a updater for it. + """ + + def __init__(self, init_state=None): + if init_state is None: + self.state = UpdaterState() + else: + self.state = init_state + + def update(self, batch): + raise NotImplementedError( + "Implement your own `update` method for training a step.") + + def state_dict(self): + state_dict = { + "epoch": self.state.epoch, + "iteration": self.state.iteration, + } + return state_dict + + def set_state_dict(self, state_dict): + self.state.epoch = state_dict["epoch"] + self.state.iteration = state_dict["iteration"] + + def save(self, path): + logger.debug(f"Saving to {path}.") + archive = self.state_dict() + paddle.save(archive, str(path)) + + def load(self, path): + logger.debug(f"Loading from {path}.") + archive = paddle.load(str(path)) + self.set_state_dict(archive) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 08f2f258..1ed5525e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ tensorboardX textgrid typeguard yacs +jsonlines \ No newline at end of file From 14ac7806584ac3531304f3288580636631a0aa13 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 25 Aug 2021 02:22:48 +0000 Subject: [PATCH 16/17] fix trainer when dataloader not using batch_sampler --- deepspeech/training/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 2ab7eac0..866be552 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -181,7 +181,7 @@ class Trainer(): """Reset the train loader seed and increment `epoch`. """ self.epoch += 1 - if self.parallel: + if self.parallel and hasattr(self.train_loader, "batch_sampler"): self.train_loader.batch_sampler.set_epoch(self.epoch) def train(self): @@ -191,7 +191,7 @@ class Trainer(): # save init model, i.e. 
0 epoch self.save(tag='init', infos=None) self.lr_scheduler.step(self.epoch) - if self.parallel: + if self.parallel and hasattr(self.train_loader, "batch_sampler"): self.train_loader.batch_sampler.set_epoch(self.epoch) logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") From 673cc4a081bd04dd7136aac99d36aa93678e3410 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 25 Aug 2021 11:24:48 +0000 Subject: [PATCH 17/17] seed all with log; and format --- deepspeech/training/cli.py | 2 +- deepspeech/training/extensions/__init__.py | 17 ++++++++++++-- deepspeech/training/extensions/evaluator.py | 17 ++++++++++++-- deepspeech/training/extensions/extension.py | 17 +++++++++++--- deepspeech/training/extensions/snapshot.py | 22 ++++++++++++++----- deepspeech/training/extensions/visualizer.py | 15 ++++++++++++- deepspeech/training/reporter.py | 15 ++++++++++++- deepspeech/training/trainer.py | 16 ++++++-------- deepspeech/training/triggers/__init__.py | 17 +++++++++++++- .../training/triggers/interval_trigger.py | 16 +++++++++++++- deepspeech/training/triggers/limit_trigger.py | 16 +++++++++++++- deepspeech/training/triggers/time_trigger.py | 17 +++++++++++++- deepspeech/training/updaters/__init__.py | 13 +++++++++++ .../training/updaters/standard_updater.py | 19 +++++++++++++--- deepspeech/training/updaters/trainer.py | 15 ++++++++++++- deepspeech/training/updaters/updater.py | 3 ++- deepspeech/utils/utility.py | 12 +++++++++- examples/aishell/s0/README.md | 4 +--- requirements.txt | 2 +- 19 files changed, 217 insertions(+), 38 deletions(-) diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index ecd7a8f2..7f4bb804 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -64,7 +64,7 @@ def default_argument_parser(): help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") parser.add_argument("--seed", type=int, default=None, - help="seed to use for paddle, np and random. 
The default value is None") + help="seed to use for paddle, np and random. None or 0 for random, else set seed.") # yapd: enable return parser diff --git a/deepspeech/training/extensions/__init__.py b/deepspeech/training/extensions/__init__.py index 7ea7470e..6ad04155 100644 --- a/deepspeech/training/extensions/__init__.py +++ b/deepspeech/training/extensions/__init__.py @@ -1,8 +1,21 @@ - +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable from .extension import Extension + def make_extension(trigger: Callable=None, default_name: str=None, priority: int=None, @@ -25,4 +38,4 @@ def make_extension(trigger: Callable=None, ext.initialize = initializer return ext - return decorator \ No newline at end of file + return decorator diff --git a/deepspeech/training/extensions/evaluator.py b/deepspeech/training/extensions/evaluator.py index ffb7b3a2..96ff967f 100644 --- a/deepspeech/training/extensions/evaluator.py +++ b/deepspeech/training/extensions/evaluator.py @@ -1,10 +1,23 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Dict +import extension import paddle from paddle.io import DataLoader from paddle.nn import Layer -import extension from ..reporter import DictSummary from ..reporter import report from ..reporter import scope @@ -55,4 +68,4 @@ class StandardEvaluator(extension.Extension): # or otherwise, you can use your own observation summary = self.evaluate() for k, v in summary.items(): - report(k, v) \ No newline at end of file + report(k, v) diff --git a/deepspeech/training/extensions/extension.py b/deepspeech/training/extensions/extension.py index f8fcede3..02f92495 100644 --- a/deepspeech/training/extensions/extension.py +++ b/deepspeech/training/extensions/extension.py @@ -1,5 +1,16 @@ -from typing import Callable - +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. PRIORITY_WRITER = 300 PRIORITY_EDITOR = 200 PRIORITY_READER = 100 @@ -38,4 +49,4 @@ class Extension(): """Action that is executed when training is done. For example, visualizers would need to be closed. 
""" - pass \ No newline at end of file + pass diff --git a/deepspeech/training/extensions/snapshot.py b/deepspeech/training/extensions/snapshot.py index a15537a0..cb4e6dfb 100644 --- a/deepspeech/training/extensions/snapshot.py +++ b/deepspeech/training/extensions/snapshot.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os from datetime import datetime from pathlib import Path @@ -7,11 +20,10 @@ from typing import List import jsonlines -from deepspeech.training.updaters.trainer import Trainer from deepspeech.training.extensions import extension -from deepspeech.utils.mp_tools import rank_zero_only - +from deepspeech.training.updaters.trainer import Trainer from deepspeech.utils.log import Log +from deepspeech.utils.mp_tools import rank_zero_only logger = Log(__name__).getlog() @@ -75,7 +87,7 @@ class Snapshot(extension.Extension): """Saving new snapshot and remove the oldest snapshot if needed.""" iteration = trainer.updater.state.iteration epoch = trainer.updater.state.epoch - num = epoch if self.trigger[1] is 'epoch' else iteration + num = epoch if self.trigger[1] == 'epoch' else iteration path = self.checkpoint_dir / f"{num}.pdz" # add the new one @@ -99,4 +111,4 @@ class Snapshot(extension.Extension): with jsonlines.open(record_path, 'w') as writer: for record in self.records: # jsonlines.open may return a Writer or a Reader - writer.write(record) # pylint: disable=no-member 
\ No newline at end of file + writer.write(record) # pylint: disable=no-member diff --git a/deepspeech/training/extensions/visualizer.py b/deepspeech/training/extensions/visualizer.py index 92e07704..b69e94aa 100644 --- a/deepspeech/training/extensions/visualizer.py +++ b/deepspeech/training/extensions/visualizer.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from deepspeech.training.extensions import extension from deepspeech.training.updaters.trainer import Trainer @@ -21,4 +34,4 @@ class VisualDL(extension.Extension): self.writer.add_scalar(k, v, step=trainer.updater.state.iteration) def finalize(self, trainer): - self.writer.close() \ No newline at end of file + self.writer.close() diff --git a/deepspeech/training/reporter.py b/deepspeech/training/reporter.py index a5f79fb0..66a81ade 100644 --- a/deepspeech/training/reporter.py +++ b/deepspeech/training/reporter.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import contextlib import math from collections import defaultdict @@ -128,4 +141,4 @@ class DictSummary(): stats[name] = mean stats[name + '.std'] = std - return stats \ No newline at end of file + return stats diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 866be552..3a922c6f 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import random import time from pathlib import Path -import numpy as np import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter @@ -23,6 +21,7 @@ from tensorboardX import SummaryWriter from deepspeech.utils import mp_tools from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log +from deepspeech.utils.utility import seed_all __all__ = ["Trainer"] @@ -95,13 +94,10 @@ class Trainer(): self.checkpoint_dir = None self.iteration = 0 self.epoch = 0 - if args.seed is not None: - self.set_seed(args.seed) - def set_seed(self, seed): - np.random.seed(seed) - random.seed(seed) - paddle.seed(seed) + if args.seed: + seed_all(args.seed) + logger.info(f"Set seed {args.seed}") def setup(self): """Setup the experiment. 
@@ -182,7 +178,9 @@ class Trainer(): """ self.epoch += 1 if self.parallel and hasattr(self.train_loader, "batch_sampler"): - self.train_loader.batch_sampler.set_epoch(self.epoch) + batch_sampler = self.train_loader.batch_sampler + if isinstance(batch_sampler, paddle.io.DistributedBatchSampler): + batch_sampler.set_epoch(self.epoch) def train(self): """The training process control by epoch.""" diff --git a/deepspeech/training/triggers/__init__.py b/deepspeech/training/triggers/__init__.py index 9da7e615..1a7c4292 100644 --- a/deepspeech/training/triggers/__init__.py +++ b/deepspeech/training/triggers/__init__.py @@ -1,8 +1,23 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .interval_trigger import IntervalTrigger + def never_fail_trigger(trainer): return False + def get_trigger(trigger): if trigger is None: return never_fail_trigger @@ -10,4 +25,4 @@ def get_trigger(trigger): return trigger else: trigger = IntervalTrigger(*trigger) - return trigger \ No newline at end of file + return trigger diff --git a/deepspeech/training/triggers/interval_trigger.py b/deepspeech/training/triggers/interval_trigger.py index ef80379c..1e04afad 100644 --- a/deepspeech/training/triggers/interval_trigger.py +++ b/deepspeech/training/triggers/interval_trigger.py @@ -1,3 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + class IntervalTrigger(): """A Predicate to do something every N cycle.""" @@ -21,4 +35,4 @@ class IntervalTrigger(): fire = index // self.period != last_index // self.period self.last_index = index - return fire \ No newline at end of file + return fire diff --git a/deepspeech/training/triggers/limit_trigger.py b/deepspeech/training/triggers/limit_trigger.py index ce13f940..ecd527ac 100644 --- a/deepspeech/training/triggers/limit_trigger.py +++ b/deepspeech/training/triggers/limit_trigger.py @@ -1,3 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ class LimitTrigger(): """A Predicate to decide whether to stop.""" @@ -14,4 +28,4 @@ class LimitTrigger(): state = trainer.updater.state index = getattr(state, self.unit) fire = index >= self.limit - return fire \ No newline at end of file + return fire diff --git a/deepspeech/training/triggers/time_trigger.py b/deepspeech/training/triggers/time_trigger.py index 6232a12d..ea8fe562 100644 --- a/deepspeech/training/triggers/time_trigger.py +++ b/deepspeech/training/triggers/time_trigger.py @@ -1,3 +1,18 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + class TimeTrigger(): """Trigger based on a fixed time interval. This trigger accepts iterations with a given interval time. @@ -14,4 +29,4 @@ class TimeTrigger(): self._next_time += self._period return True else: - return False \ No newline at end of file + return False diff --git a/deepspeech/training/updaters/__init__.py b/deepspeech/training/updaters/__init__.py index e69de29b..185a92b8 100644 --- a/deepspeech/training/updaters/__init__.py +++ b/deepspeech/training/updaters/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/deepspeech/training/updaters/standard_updater.py b/deepspeech/training/updaters/standard_updater.py index 062029ff..fc758e93 100644 --- a/deepspeech/training/updaters/standard_updater.py +++ b/deepspeech/training/updaters/standard_updater.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Dict from typing import Optional @@ -11,13 +24,13 @@ from timer import timer from deepspeech.training.reporter import report from deepspeech.training.updaters.updater import UpdaterBase from deepspeech.training.updaters.updater import UpdaterState - from deepspeech.utils.log import Log __all__ = ["StandardUpdater"] logger = Log(__name__).getlog() + class StandardUpdater(UpdaterBase): """An example of over-simplification. Things may not be that simple, but you can subclass it to fit your need. 
@@ -142,7 +155,7 @@ class StandardUpdater(UpdaterBase): """Start a new epoch.""" # NOTE: all batch sampler for distributed training should # subclass DistributedBatchSampler and implement `set_epoch` method - if hasattr(self.dataloader, "batch_sampler") + if hasattr(self.dataloader, "batch_sampler"): batch_sampler = self.dataloader.batch_sampler if isinstance(batch_sampler, DistributedBatchSampler): batch_sampler.set_epoch(self.state.epoch) @@ -176,4 +189,4 @@ class StandardUpdater(UpdaterBase): model.set_state_dict(state_dict[f"{name}_params"]) for name, optim in self.optimizers.items(): optim.set_state_dict(state_dict[f"{name}_optimizer"]) - super().set_state_dict(state_dict) \ No newline at end of file + super().set_state_dict(state_dict) diff --git a/deepspeech/training/updaters/trainer.py b/deepspeech/training/updaters/trainer.py index c7562ff0..954ce260 100644 --- a/deepspeech/training/updaters/trainer.py +++ b/deepspeech/training/updaters/trainer.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import sys import traceback from collections import OrderedDict @@ -168,4 +181,4 @@ class Trainer(): finally: for name, entry in extensions: if hasattr(entry.extension, "finalize"): - entry.extension.finalize(self) \ No newline at end of file + entry.extension.finalize(self) diff --git a/deepspeech/training/updaters/updater.py b/deepspeech/training/updaters/updater.py index 548042d6..66fdc2bb 100644 --- a/deepspeech/training/updaters/updater.py +++ b/deepspeech/training/updaters/updater.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass + import paddle from deepspeech.utils.log import Log @@ -79,4 +80,4 @@ class UpdaterBase(): def load(self, path): logger.debug(f"Loading from {path}.") archive = paddle.load(str(path)) - self.set_state_dict(archive) \ No newline at end of file + self.set_state_dict(archive) diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index a0639e06..e18fc1f7 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -15,9 +15,19 @@ import distutils.util import math import os +import random from typing import List -__all__ = ['print_arguments', 'add_arguments', "log_add"] +import numpy as np +import paddle + +__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"] + + +def seed_all(seed: int=210329): + np.random.seed(seed) + random.seed(seed) + paddle.seed(seed) def print_arguments(args, info=None): diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index eedf92c9..537496a6 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -3,11 +3,9 @@ ## Data | Data Subset | Duration in Seconds | | data/manifest.train | 1.23 ~ 14.53125 | -| data/manifest.dev | 1.645 ~ 12.533 | +| data/manifest.dev | 1.645 ~ 12.533 | | data/manifest.test | 1.859125 ~ 14.6999375 | -`jq '.feat_shape[0]' data/manifest.train | sort -un` - ## Deepspeech2 | Model | 
Params | Release | Config | Test set | Loss | CER | diff --git a/requirements.txt b/requirements.txt index 1ed5525e..7c3da37e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ coverage gpustat +jsonlines kaldiio Pillow pre-commit @@ -15,4 +16,3 @@ tensorboardX textgrid typeguard yacs -jsonlines \ No newline at end of file