From 907c93392faa6a3d0695adcf95a99fbe0bc33a61 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 6 Apr 2021 09:22:41 +0000 Subject: [PATCH] refactor cmvn, test --- deepspeech/__init__.py | 71 +----------- deepspeech/frontend/normalizer.py | 18 +-- deepspeech/frontend/utility.py | 108 +++++++++++++++++- deepspeech/models/u2.py | 9 +- deepspeech/modules/mask.py | 8 +- deepspeech/utils/cmvn.py | 93 --------------- deepspeech/utils/tensor_utils.py | 79 +++++++++---- examples/tiny/README.md | 1 + tests/deepspeech2_model_test.py | 105 +++++++++++++++++ ...{test_error_rate.py => error_rate_test.py} | 0 tests/mask_test.py | 39 +++++++ tests/network_test.py | 99 ---------------- tests/u2_model_test.py | 60 ++++++++++ 13 files changed, 391 insertions(+), 299 deletions(-) delete mode 100644 deepspeech/utils/cmvn.py create mode 100644 tests/deepspeech2_model_test.py rename tests/{test_error_rate.py => error_rate_test.py} (100%) create mode 100644 tests/mask_test.py delete mode 100644 tests/network_test.py create mode 100644 tests/u2_model_test.py diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index 5d840148a..ab5f0e137 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -414,73 +414,4 @@ if not hasattr(paddle.nn, 'ConstantPad2d'): if not hasattr(paddle.jit, 'export'): logger.warn("register user export to paddle.jit, remove this when fixed!") - setattr(paddle.jit, 'export', paddle.jit.to_static) - - -########### hcak paddle.nn.utils ############# -def pad_sequence(sequences: List[paddle.Tensor], - batch_first: bool=False, - padding_value: float=0.0) -> paddle.Tensor: - r"""Pad a list of variable length Tensors with ``padding_value`` - - ``pad_sequence`` stacks a list of Tensors along a new dimension, - and pads them to equal length. For example, if the input is list of - sequences with size ``L x *`` and if batch_first is False, and ``T x B x *`` - otherwise. - - `B` is batch size. It is equal to the number of elements in ``sequences``. - `T` is length of the longest sequence. - `L` is length of the sequence. - `*` is any number of trailing dimensions, including none. - - Example: - >>> from paddle.nn.utils.rnn import pad_sequence - >>> a = paddle.ones(25, 300) - >>> b = paddle.ones(22, 300) - >>> c = paddle.ones(15, 300) - >>> pad_sequence([a, b, c]).size() - paddle.Tensor([25, 3, 300]) - - Note: - This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` - where `T` is the length of the longest sequence. This function assumes - trailing dimensions and type of all the Tensors in sequences are same. - - Args: - sequences (list[Tensor]): list of variable length sequences. - batch_first (bool, optional): output will be in ``B x T x *`` if True, or in - ``T x B x *`` otherwise - padding_value (float, optional): value for padded elements. Default: 0. - - Returns: - Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. 
- Tensor of size ``B x T x *`` otherwise - """ - - # assuming trailing dimensions and type of all the Tensors - # in sequences are same and fetching those from sequences[0] - max_size = sequences[0].size() - trailing_dims = max_size[1:] - max_len = max([s.size(0) for s in sequences]) - if batch_first: - out_dims = (len(sequences), max_len) + trailing_dims - else: - out_dims = (max_len, len(sequences)) + trailing_dims - - out_tensor = sequences[0].new_full(out_dims, padding_value) - for i, tensor in enumerate(sequences): - length = tensor.size(0) - # use index notation to prevent duplicate references to the tensor - if batch_first: - out_tensor[i, :length, ...] = tensor - else: - out_tensor[:length, i, ...] = tensor - - return out_tensor - - -if not hasattr(paddle.nn.utils, 'rnn.pad_sequence'): - logger.warn( - "register user rnn.pad_sequence to paddle.nn.utils, remove this when fixed!" - ) - setattr(paddle.nn.utils, 'rnn.pad_sequence', pad_sequence) + setattr(paddle.jit, 'export', paddle.jit.to_static) \ No newline at end of file diff --git a/deepspeech/frontend/normalizer.py b/deepspeech/frontend/normalizer.py index 8e50566c6..f8ee52f03 100644 --- a/deepspeech/frontend/normalizer.py +++ b/deepspeech/frontend/normalizer.py @@ -57,17 +57,17 @@ class FeatureNormalizer(object): else: self._read_mean_std_from_file(mean_std_filepath) - def apply(self, features, eps=1e-14): + def apply(self, features): """Normalize features to be of zero mean and unit stddev. :param features: Input features to be normalized. - :type features: ndarray + :type features: ndarray, shape (D, T) :param eps: added to stddev to provide numerical stablibity. :type eps: float :return: Normalized features. :rtype: ndarray """ - return (features - self._mean) / (self._std + eps) + return (features - self._mean) * self._istd def write_to_file(self, filepath): """Write the mean and stddev to the file. @@ -77,11 +77,13 @@ class FeatureNormalizer(object): """ np.savez(filepath, mean=self._mean, std=self._std) - def _read_mean_std_from_file(self, filepath): + def _read_mean_std_from_file(self, filepath, eps=1e-20): """Load mean and std from file.""" npzfile = np.load(filepath) self._mean = npzfile["mean"] - self._std = npzfile["std"] + std = npzfile["std"] + std = np.clip(std, eps, None) + self._istd = 1.0 / std def _compute_mean_std(self, manifest_path, featurize_func, num_samples): """Compute mean and std from randomly sampled instances.""" @@ -92,6 +94,6 @@ class FeatureNormalizer(object): features.append( featurize_func( AudioSegment.from_file(instance["audio_filepath"]))) - features = np.hstack(features) - self._mean = np.mean(features, axis=1).reshape([-1, 1]) - self._std = np.std(features, axis=1).reshape([-1, 1]) + features = np.hstack(features) #(D, T) + self._mean = np.mean(features, axis=1).reshape([-1, 1]) #(D, 1) + self._std = np.std(features, axis=1).reshape([-1, 1]) #(D, 1) diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index a9e0e5c51..de602cb97 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Contains data helper functions.""" - import numpy as np import math import json @@ -20,11 +19,19 @@ import codecs import os import tarfile import time +import logging from threading import Thread from multiprocessing import Process, Manager, Value from paddle.dataset.common import md5file +logger = logging.getLogger(__name__) + +__all__ = [ + "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs", + "mean_dbfs", "gain_db_to_ratio", "normalize_audio" +] + def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0): """Load and parse manifest file. @@ -134,3 +141,102 @@ def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103): return np.maximum( np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)), 1.0), -1.0) + + +def _load_json_cmvn(json_cmvn_file): + """ Load the json format cmvn stats file and calculate cmvn + Args: + json_cmvn_file: cmvn stats file in json format + Returns: + a numpy array of [means, vars] + """ + with open(json_cmvn_file) as f: + cmvn_stats = json.load(f) + + means = cmvn_stats['mean_stat'] + variance = cmvn_stats['var_stat'] + count = cmvn_stats['frame_num'] + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn + + +def _load_kaldi_cmvn(kaldi_cmvn_file): + """ Load the kaldi format cmvn stats file and calculate cmvn + Args: + kaldi_cmvn_file: kaldi text style global cmvn file, which + is generated by: + compute-cmvn-stats --binary=false scp:feats.scp global_cmvn + Returns: + a numpy array of [means, vars] + """ + means = [] + variance = [] + with open(kaldi_cmvn_file, 'r') as fid: + # kaldi binary file start with '\0B' + if fid.read(2) == '\0B': + logger.error('kaldi cmvn binary file is not supported, please ' + 'recompute it by: compute-cmvn-stats --binary=false ' + ' scp:feats.scp global_cmvn') + sys.exit(1) + fid.seek(0) + arr = fid.read().split() + assert (arr[0] == '[') + assert (arr[-2] == '0') + assert (arr[-1] == ']') + feat_dim = int((len(arr) - 2 - 2) / 2) + for i in range(1, feat_dim + 1): + means.append(float(arr[i])) + count = float(arr[feat_dim + 1]) + for i in range(feat_dim + 2, 2 * feat_dim + 2): + variance.append(float(arr[i])) + + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn + + +def _load_npz_cmvn(npz_cmvn_file, eps=1e-20): + npzfile = np.load(npz_cmvn_file) + means = npzfile["mean"] + std = npzfile["std"] + std = np.clip(std, eps, None) + variance = 1.0 / std + cmvn = np.array([means, variance]) + return cmvn + + +def load_cmvn(cmvn_file: str, filetype: str): + """load cmvn from file. + + Args: + cmvn_file (str): cmvn path. + filetype (str): file type, optional[npz, json, kaldi]. + + Raises: + ValueError: file type not support. 
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray]: mean, istd
+    """
+    filetype = filetype.lower()
+    assert filetype in ['npz', 'json', 'kaldi'], filetype
+    if filetype == "json":
+        cmvn = _load_json_cmvn(cmvn_file)
+    elif filetype == "kaldi":
+        cmvn = _load_kaldi_cmvn(cmvn_file)
+    elif filetype == "npz":
+        cmvn = _load_npz_cmvn(cmvn_file)
+    else:
+        raise ValueError(f"cmvn file type not supported: {filetype}")
+    return cmvn[0], cmvn[1]
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index db923001e..9ecbc0177 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -29,7 +29,6 @@ from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
-from paddle.nn.utils.rnn import pad_sequence
 from deepspeech.modules.mask import make_pad_mask
 from deepspeech.modules.mask import mask_finished_preds
 from deepspeech.modules.mask import mask_finished_scores
@@ -42,13 +41,15 @@ from deepspeech.modules.ctc import CTCDecoder
 from deepspeech.modules.decoder import TransformerDecoder
 from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
 
+from deepspeech.frontend.utility import load_cmvn
+
 from deepspeech.utils import checkpoint
 from deepspeech.utils import layer_tools
-from deepspeech.utils.cmvn import load_cmvn
 from deepspeech.utils.utility import log_add
 from deepspeech.utils.tensor_utils import IGNORE_ID
 from deepspeech.utils.tensor_utils import add_sos_eos
 from deepspeech.utils.tensor_utils import th_accuracy
+from deepspeech.utils.tensor_utils import pad_sequence
 from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
 
 logger = logging.getLogger(__name__)
@@ -635,7 +636,7 @@ class U2TransformerModel(U2Model):
     def __init__(configs: dict):
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
-                                   configs['is_json_cmvn'])
+                                   configs['cmvn_file_type'])
             global_cmvn = GlobalCMVN(
                 paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
         else:
@@ -666,7 +667,7 @@ class U2ConformerModel(U2Model):
     def __init__(configs: dict):
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
-                                   configs['is_json_cmvn'])
+                                   configs['cmvn_file_type'])
             global_cmvn = GlobalCMVN(
                 paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
         else:
diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py
index e38d75f8f..4dd709aba 100644
--- a/deepspeech/modules/mask.py
+++ b/deepspeech/modules/mask.py
@@ -21,7 +21,11 @@ from paddle.nn import initializer as I
 
 logger = logging.getLogger(__name__)
 
-__all__ = ['sequence_mask']
+__all__ = [
+    'sequence_mask', "make_pad_mask", "make_non_pad_mask", "subsequent_mask",
+    "subsequent_chunk_mask", "add_optional_chunk_mask", "mask_finished_scores",
+    "mask_finished_preds"
+]
 
 
 def sequence_mask(x_len, max_len=None, dtype='float32'):
@@ -93,7 +97,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
                  [1, 1, 1, 0, 0],
                  [1, 1, 0, 0, 0]]
     """
-    return ~make_pad_mask(lengths)
+    return make_pad_mask(lengths).logical_not()
 
 
 def subsequent_mask(size: int) -> paddle.Tensor:
diff --git a/deepspeech/utils/cmvn.py b/deepspeech/utils/cmvn.py
deleted file mode 100644
index 5c5573ee9..000000000
--- a/deepspeech/utils/cmvn.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import math -import logging -import numpy as np - -logger = logging.getLogger(__name__) - -__all__ = ['load_cmvn'] - - -def _load_json_cmvn(json_cmvn_file): - """ Load the json format cmvn stats file and calculate cmvn - Args: - json_cmvn_file: cmvn stats file in json format - Returns: - a numpy array of [means, vars] - """ - with open(json_cmvn_file) as f: - cmvn_stats = json.load(f) - - means = cmvn_stats['mean_stat'] - variance = cmvn_stats['var_stat'] - count = cmvn_stats['frame_num'] - for i in range(len(means)): - means[i] /= count - variance[i] = variance[i] / count - means[i] * means[i] - if variance[i] < 1.0e-20: - variance[i] = 1.0e-20 - variance[i] = 1.0 / math.sqrt(variance[i]) - cmvn = np.array([means, variance]) - return cmvn - - -def _load_kaldi_cmvn(kaldi_cmvn_file): - """ Load the kaldi format cmvn stats file and calculate cmvn - Args: - kaldi_cmvn_file: kaldi text style global cmvn file, which - is generated by: - compute-cmvn-stats --binary=false scp:feats.scp global_cmvn - Returns: - a numpy array of [means, vars] - """ - means = [] - variance = [] - with open(kaldi_cmvn_file, 'r') as fid: - # kaldi binary file start with '\0B' - if fid.read(2) == '\0B': - logger.error('kaldi cmvn binary file is not supported, please ' - 'recompute it by: compute-cmvn-stats --binary=false ' - ' scp:feats.scp global_cmvn') - sys.exit(1) - fid.seek(0) - arr = fid.read().split() - assert (arr[0] == '[') - assert (arr[-2] == '0') - assert (arr[-1] == ']') - feat_dim = int((len(arr) - 2 - 2) / 2) - for i in range(1, feat_dim + 1): - means.append(float(arr[i])) - count = float(arr[feat_dim + 1]) - for i in range(feat_dim + 2, 2 * feat_dim + 2): - variance.append(float(arr[i])) - - for i in range(len(means)): - means[i] /= count - variance[i] = variance[i] / count - means[i] * means[i] - if variance[i] < 1.0e-20: - variance[i] = 1.0e-20 - variance[i] = 1.0 / math.sqrt(variance[i]) - cmvn = np.array([means, variance]) - return cmvn - - -def load_cmvn(cmvn_file, is_json): - if is_json: - cmvn = _load_json_cmvn(cmvn_file) - else: - cmvn = _load_kaldi_cmvn(cmvn_file) - return cmvn[0], cmvn[1] diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 627f51630..9f67c1a61 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -20,35 +20,70 @@ import paddle logger = logging.getLogger(__name__) -__all__ = ["pad_list", "add_sos_eos", "th_accuracy"] +__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"] IGNORE_ID = -1 -def pad_list(xs: List[paddle.Tensor], pad_value: int): - """Perform padding for the list of tensors. +def pad_sequence(sequences: List[paddle.Tensor], + batch_first: bool=False, + padding_value: float=0.0) -> paddle.Tensor: + r"""Pad a list of variable length Tensors with ``padding_value`` + + ``pad_sequence`` stacks a list of Tensors along a new dimension, + and pads them to equal length. For example, if the input is list of + sequences with size ``L x *`` and if batch_first is False, and ``T x B x *`` + otherwise. + + `B` is batch size. It is equal to the number of elements in ``sequences``. 
+ `T` is length of the longest sequence. + `L` is length of the sequence. + `*` is any number of trailing dimensions, including none. + + Example: + >>> from paddle.nn.utils.rnn import pad_sequence + >>> a = paddle.ones(25, 300) + >>> b = paddle.ones(22, 300) + >>> c = paddle.ones(15, 300) + >>> pad_sequence([a, b, c]).size() + paddle.Tensor([25, 3, 300]) + + Note: + This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` + where `T` is the length of the longest sequence. This function assumes + trailing dimensions and type of all the Tensors in sequences are same. + Args: - xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. - pad_value (float): Value for padding. + sequences (list[Tensor]): list of variable length sequences. + batch_first (bool, optional): output will be in ``B x T x *`` if True, or in + ``T x B x *`` otherwise + padding_value (float, optional): value for padded elements. Default: 0. + Returns: - Tensor: Padded tensor (B, Tmax, `*`). - Examples: - >>> x = [paddle.ones(4), paddle.ones(2), paddle.ones(1)] - >>> x - [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] - >>> pad_list(x, 0) - tensor([[1., 1., 1., 1.], - [1., 1., 0., 0.], - [1., 0., 0., 0.]]) + Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. + Tensor of size ``B x T x *`` otherwise """ - n_batch = len(xs) - max_len = max([x.size(0) for x in xs]) - pad = paddle.zeros(n_batch, max_len, dtype=xs[0].dtype) - pad = pad.fill_(pad_value) - for i in range(n_batch): - pad[i, :xs[i].size(0)] = xs[i] - - return pad + + # assuming trailing dimensions and type of all the Tensors + # in sequences are same and fetching those from sequences[0] + max_size = sequences[0].size() + trailing_dims = max_size[1:] + max_len = max([s.size(0) for s in sequences]) + if batch_first: + out_dims = (len(sequences), max_len) + trailing_dims + else: + out_dims = (max_len, len(sequences)) + trailing_dims + + out_tensor = sequences[0].new_full(out_dims, padding_value) + for i, tensor in enumerate(sequences): + length = tensor.size(0) + # use index notation to prevent duplicate references to the tensor + if batch_first: + out_tensor[i, :length, ...] = tensor + else: + out_tensor[:length, i, ...] = tensor + + return out_tensor def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, diff --git a/examples/tiny/README.md b/examples/tiny/README.md index e109e1ae4..6766f59a2 100644 --- a/examples/tiny/README.md +++ b/examples/tiny/README.md @@ -1 +1,2 @@ * s0 for deepspeech2 +* s1 for U2 diff --git a/tests/deepspeech2_model_test.py b/tests/deepspeech2_model_test.py new file mode 100644 index 000000000..8ada42c35 --- /dev/null +++ b/tests/deepspeech2_model_test.py @@ -0,0 +1,105 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import numpy as np +import unittest +from deepspeech.models.deepspeech2 import DeepSpeech2Model + + +class TestDeepSpeech2Model(unittest.TestCase): + def setUp(self): + paddle.set_device('cpu') + + self.batch_size = 2 + self.feat_dim = 161 + max_len = 64 + + #(B, T, D) + audio = np.random.randn(self.batch_size, max_len, self.feat_dim) + audio_len = np.random.randint( + max_len, size=self.batch_size, dtype='int32') + audio_len[-1] = max_len + #(B, U) + text = np.array([[1, 2], [1, 2]], dtype='int32') + text_len = np.array([2] * self.batch_size, dtype='int32') + + self.audio = paddle.to_tensor(audio, dtype='float32') + self.audio_len = paddle.to_tensor(audio_len, dtype='int64') + self.text = paddle.to_tensor(text, dtype='int32') + self.text_len = paddle.to_tensor(text_len, dtype='int64') + + def test_ds2_1(self): + model = DeepSpeech2Model( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=False, ) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_2(self): + model = DeepSpeech2Model( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=True, + share_rnn_weights=False, ) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_3(self): + model = DeepSpeech2Model( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True, ) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_4(self): + model = DeepSpeech2Model( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=True, + share_rnn_weights=True, ) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_5(self): + model = DeepSpeech2Model( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=False, ) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_error_rate.py b/tests/error_rate_test.py similarity index 100% rename from tests/test_error_rate.py rename to tests/error_rate_test.py diff --git a/tests/mask_test.py b/tests/mask_test.py new file mode 100644 index 000000000..f5e9cd7cb --- /dev/null +++ b/tests/mask_test.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import numpy as np +import unittest +from deepspeech.modules.mask import sequence_mask +from deepspeech.modules.mask import make_non_pad_mask + + +class TestU2Model(unittest.TestCase): + def setUp(self): + paddle.set_device('cpu') + self.lengths = paddle.to_tensor([5, 3, 2]) + self.masks = np.array( + [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 0, 0]], ) + + def test_sequence_mask(self): + res = sequence_mask(self.lengths) + self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist()) + + def test_make_non_pad_mask(self): + res = make_non_pad_mask(self.lengths) + self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/network_test.py b/tests/network_test.py deleted file mode 100644 index ae86c9c43..000000000 --- a/tests/network_test.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -import numpy as np - -from deepspeech.models.deepspeech2 import DeepSpeech2Model - -if __name__ == '__main__': - batch_size = 2 - feat_dim = 161 - max_len = 100 - audio = np.random.randn(batch_size, feat_dim, max_len) - audio_len = np.random.randint(100, size=batch_size, dtype='int32') - audio_len[-1] = 100 - text = np.array([[1, 2], [1, 2]], dtype='int32') - text_len = np.array([2] * batch_size, dtype='int32') - - audio = paddle.to_tensor(audio, dtype='float32') - audio_len = paddle.to_tensor(audio_len, dtype='int64') - text = paddle.to_tensor(text, dtype='int32') - text_len = paddle.to_tensor(text_len, dtype='int64') - - print(audio.shape) - print(audio_len.shape) - print(text.shape) - print(text_len.shape) - print("-----------------") - - model = DeepSpeech2Model( - feat_size=feat_dim, - dict_size=10, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=False, ) - logits, probs, logits_len = model(audio, audio_len, text, text_len) - print('probs.shape', probs.shape) - print("-----------------") - - model2 = DeepSpeech2Model( - feat_size=feat_dim, - dict_size=10, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=True, - share_rnn_weights=False, ) - logits, probs, logits_len = model2(audio, audio_len, text, text_len) - print('probs.shape', probs.shape) - print("-----------------") - - model3 = DeepSpeech2Model( - feat_size=feat_dim, - dict_size=10, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True, ) - logits, probs, logits_len = model3(audio, audio_len, text, text_len) - print('probs.shape', probs.shape) - print("-----------------") - - model4 = DeepSpeech2Model( - feat_size=feat_dim, - dict_size=10, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=True, - share_rnn_weights=True, ) - logits, probs, logits_len = model4(audio, audio_len, text, text_len) - print('probs.shape', probs.shape) - print("-----------------") - - model5 = DeepSpeech2Model( - feat_size=feat_dim, - 
dict_size=10,
-        num_conv_layers=2,
-        num_rnn_layers=3,
-        rnn_size=1024,
-        use_gru=False,
-        share_rnn_weights=False, )
-    logits, probs, logits_len = model5(audio, audio_len, text, text_len)
-    print('probs.shape', probs.shape)
-    print("-----------------")
diff --git a/tests/u2_model_test.py b/tests/u2_model_test.py
new file mode 100644
index 000000000..e2230a394
--- /dev/null
+++ b/tests/u2_model_test.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numpy as np
+import unittest
+from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.u2 import U2TransformerModel
+from deepspeech.models.u2 import U2ConformerModel
+
+
+class TestU2Model(unittest.TestCase):
+    def setUp(self):
+        paddle.set_device('cpu')
+
+        self.batch_size = 2
+        self.feat_dim = 161
+        max_len = 100
+        #(B, T, D)
+        audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
+        audio_len = np.random.randint(
+            max_len, size=self.batch_size, dtype='int32')
+        audio_len[-1] = max_len
+        #(B, U)
+        text = np.array([[1, 2], [1, 2]], dtype='int32')
+        text_len = np.array([2] * self.batch_size, dtype='int32')
+
+        self.audio = paddle.to_tensor(audio, dtype='float32')
+        self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
+        self.text = paddle.to_tensor(text, dtype='int32')
+        self.text_len = paddle.to_tensor(text_len, dtype='int64')
+
+    def test_ds2_1(self):
+        # TODO: build U2TransformerModel / U2ConformerModel from a config
+        # fixture; reuse the DeepSpeech2 forward as a smoke test for now.
+        model = DeepSpeech2Model(
+            feat_size=self.feat_dim,
+            dict_size=10,
+            num_conv_layers=2,
+            num_rnn_layers=3,
+            rnn_size=1024,
+            use_gru=False,
+            share_rnn_weights=False, )
+        loss = model(self.audio, self.audio_len, self.text, self.text_len)
+        self.assertEqual(loss.numel(), 1)
+
+
+if __name__ == '__main__':
+    unittest.main()
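
Reviewer note (not part of the patch): a minimal usage sketch of the refactored CMVN loading path introduced above. It assumes an .npz stats file holding (D, 1) "mean"/"std" arrays, i.e. the layout FeatureNormalizer.write_to_file() produces; the file name and feature dimension below are illustrative only.

    import numpy as np
    from deepspeech.frontend.utility import load_cmvn

    feat_dim = 80  # illustrative feature dimension
    # stats file as written by FeatureNormalizer.write_to_file()
    np.savez('mean_std.npz',
             mean=np.random.randn(feat_dim, 1).astype('float32'),
             std=np.abs(np.random.randn(feat_dim, 1)).astype('float32') + 0.1)

    # load_cmvn returns (mean, istd); for the npz case istd = 1 / clip(std, eps),
    # for json/kaldi stats it is 1 / sqrt(var).
    mean, istd = load_cmvn('mean_std.npz', 'npz')

    feats = np.random.randn(feat_dim, 100).astype('float32')  # (D, T) features
    normalized = (feats - mean) * istd  # same operation as FeatureNormalizer.apply()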