refactor cmvn, test

pull/578/head
Hui Zhang 5 years ago
parent 5659bd2386
commit 907c93392f

@ -414,73 +414,4 @@ if not hasattr(paddle.nn, 'ConstantPad2d'):
if not hasattr(paddle.jit, 'export'): if not hasattr(paddle.jit, 'export'):
logger.warn("register user export to paddle.jit, remove this when fixed!") logger.warn("register user export to paddle.jit, remove this when fixed!")
setattr(paddle.jit, 'export', paddle.jit.to_static) setattr(paddle.jit, 'export', paddle.jit.to_static)
########### hack paddle.nn.utils #############
def pad_sequence(sequences: List[paddle.Tensor],
                 batch_first: bool=False,
                 padding_value: float=0.0) -> paddle.Tensor:
    r"""Stack variable-length tensors into one tensor filled with ``padding_value``.

    Every tensor in ``sequences`` is expected to have size ``L x *``; only the
    leading length ``L`` may differ between them. The trailing dimensions and
    the fill tensor are taken from ``sequences[0]``.

    `B` is the number of tensors in ``sequences``.
    `T` is the largest leading length found in ``sequences``.
    `*` is any number of trailing dimensions, including none.

    Note:
        The tensors must provide ``size()`` / ``size(0)`` and ``new_full``.
        NOTE(review): presumably these are registered on ``paddle.Tensor``
        elsewhere in the project -- confirm before relying on this helper.

    Args:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): output is ``B x T x *`` if True,
            ``T x B x *`` otherwise.
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise
    """
    # The first sequence supplies the trailing dims and the output template.
    template = sequences[0]
    tail_dims = template.size()[1:]
    longest = max([seq.size(0) for seq in sequences])
    n_seqs = len(sequences)

    lead_dims = (n_seqs, longest) if batch_first else (longest, n_seqs)
    padded = template.new_full(lead_dims + tail_dims, padding_value)

    for idx, seq in enumerate(sequences):
        n_valid = seq.size(0)
        # Write through index notation so the source tensor is copied into
        # its slot rather than referenced.
        if batch_first:
            padded[idx, :n_valid, ...] = seq
        else:
            padded[:n_valid, idx, ...] = seq
    return padded
# Register the local pad_sequence on paddle.nn.utils when it is missing.
# NOTE(review): hasattr/setattr take the name literally, so this checks for and
# creates an attribute whose name is the single string 'rnn.pad_sequence'
# (dot included). It does not create paddle.nn.utils.rnn.pad_sequence, so
# `from paddle.nn.utils.rnn import pad_sequence` is not satisfied by this.
# NOTE(review): if `logger` is a stdlib logging.Logger, `warn` is a deprecated
# alias of `warning`.
if not hasattr(paddle.nn.utils, 'rnn.pad_sequence'):
logger.warn(
"register user rnn.pad_sequence to paddle.nn.utils, remove this when fixed!"
)
setattr(paddle.nn.utils, 'rnn.pad_sequence', pad_sequence)

@ -57,17 +57,17 @@ class FeatureNormalizer(object):
else: else:
self._read_mean_std_from_file(mean_std_filepath) self._read_mean_std_from_file(mean_std_filepath)
def apply(self, features):
    """Normalize features to be of zero mean and unit stddev.

    :param features: Input features to be normalized.
    :type features: ndarray, shape (D, T)
    :return: Normalized features, ``(features - mean) * istd``.
    :rtype: ndarray
    """
    # `_istd` is assigned when the statistics are read from a file; the
    # in-process computation shown in this class only assigns `_std`.
    # Derive the inverse stddev in that case instead of failing with
    # AttributeError. The 1e-20 floor mirrors the default used when loading
    # statistics from a file.
    istd = getattr(self, "_istd", None)
    if istd is None:
        istd = 1.0 / np.clip(self._std, 1e-20, None)
    return (features - self._mean) * istd
def write_to_file(self, filepath): def write_to_file(self, filepath):
"""Write the mean and stddev to the file. """Write the mean and stddev to the file.
@ -77,11 +77,13 @@ class FeatureNormalizer(object):
""" """
np.savez(filepath, mean=self._mean, std=self._std) np.savez(filepath, mean=self._mean, std=self._std)
def _read_mean_std_from_file(self, filepath, eps=1e-20):
    """Load mean and std from file.

    :param filepath: Path of an ``.npz`` archive holding ``mean`` and ``std``.
    :type filepath: str
    :param eps: Lower bound applied to std before it is inverted, so that a
                zero stddev does not produce a division by zero.
    :type eps: float
    """
    # The context manager closes the archive's file handle once both
    # arrays have been read.
    with np.load(filepath) as npzfile:
        self._mean = npzfile["mean"]
        std = npzfile["std"]
    # Keep the raw stddev as well: write_to_file saves `self._std`.
    self._std = std
    self._istd = 1.0 / np.clip(std, eps, None)
def _compute_mean_std(self, manifest_path, featurize_func, num_samples): def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
"""Compute mean and std from randomly sampled instances.""" """Compute mean and std from randomly sampled instances."""
@ -92,6 +94,6 @@ class FeatureNormalizer(object):
features.append( features.append(
featurize_func( featurize_func(
AudioSegment.from_file(instance["audio_filepath"]))) AudioSegment.from_file(instance["audio_filepath"])))
features = np.hstack(features) features = np.hstack(features) #(D, T)
self._mean = np.mean(features, axis=1).reshape([-1, 1]) self._mean = np.mean(features, axis=1).reshape([-1, 1]) #(D, 1)
self._std = np.std(features, axis=1).reshape([-1, 1]) self._std = np.std(features, axis=1).reshape([-1, 1]) #(D, 1)

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains data helper functions.""" """Contains data helper functions."""
import numpy as np import numpy as np
import math import math
import json import json
@ -20,11 +19,19 @@ import codecs
import os import os
import tarfile import tarfile
import time import time
import logging
from threading import Thread from threading import Thread
from multiprocessing import Process, Manager, Value from multiprocessing import Process, Manager, Value
from paddle.dataset.common import md5file from paddle.dataset.common import md5file
logger = logging.getLogger(__name__)
__all__ = [
"load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs",
"mean_dbfs", "gain_db_to_ratio", "normalize_audio"
]
def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0): def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
"""Load and parse manifest file. """Load and parse manifest file.
@ -134,3 +141,102 @@ def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
return np.maximum( return np.maximum(
np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)), np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)),
1.0), -1.0) 1.0), -1.0)
def _load_json_cmvn(json_cmvn_file):
    """ Load the json format cmvn stats file and calculate cmvn

    The file is read for three keys: ``mean_stat`` and ``var_stat`` (one
    value per feature dimension, used here as the accumulated sum and the
    accumulated sum of squares) and ``frame_num`` (the divisor).

    Args:
        json_cmvn_file: cmvn stats file in json format
    Returns:
        a numpy array of shape (2, D): row 0 holds the means and row 1 holds
        the inverse standard deviation ``1 / sqrt(variance)``, with the
        variance floored at 1.0e-20.
    Raises:
        ValueError: ``mean_stat`` and ``var_stat`` have different lengths.
    """
    with open(json_cmvn_file, encoding='utf-8') as f:
        cmvn_stats = json.load(f)

    mean_stat = cmvn_stats['mean_stat']
    var_stat = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    if len(mean_stat) != len(var_stat):
        raise ValueError(
            f"mean_stat has {len(mean_stat)} entries but var_stat has "
            f"{len(var_stat)}: {json_cmvn_file}")

    means = [total / count for total in mean_stat]
    # E[x^2] - E[x]^2, floored so the square root and division stay finite.
    istd = [
        1.0 / math.sqrt(max(sq_total / count - mean * mean, 1.0e-20))
        for sq_total, mean in zip(var_stat, means)
    ]
    return np.array([means, istd])
def _load_kaldi_cmvn(kaldi_cmvn_file):
    """ Load the kaldi format cmvn stats file and calculate cmvn

    The text file is split on whitespace and read as
    ``[ <D sums> <count> <D sums of squares> 0 ]``.

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
           is generated by:
           compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
    Returns:
        a numpy array of shape (2, D): row 0 holds the means and row 1 holds
        the inverse standard deviation ``1 / sqrt(variance)``, with the
        variance floored at 1.0e-20.
    Raises:
        SystemExit: the file starts with the kaldi binary marker.
        AssertionError: the token stream is not framed as ``[ ... 0 ]``.
    """
    # Imported here because the binary-file branch calls sys.exit and the
    # module's visible import block does not bring `sys` into scope.
    import sys

    with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary file start with '\0B'
        if fid.read(2) == '\0B':
            logger.error('kaldi cmvn binary file is not supported, please '
                         'recompute it by: compute-cmvn-stats --binary=false '
                         ' scp:feats.scp global_cmvn')
            sys.exit(1)
        fid.seek(0)
        arr = fid.read().split()

    assert arr[0] == '[', "kaldi cmvn text must start with '['"
    assert arr[-2] == '0', "kaldi cmvn second row must end with 0"
    assert arr[-1] == ']', "kaldi cmvn text must end with ']'"

    # Two bracket tokens, the count and the trailing 0 are not features.
    feat_dim = int((len(arr) - 2 - 2) / 2)
    mean_stat = [float(tok) for tok in arr[1:feat_dim + 1]]
    count = float(arr[feat_dim + 1])
    var_stat = [float(tok) for tok in arr[feat_dim + 2:2 * feat_dim + 2]]

    means = [total / count for total in mean_stat]
    # E[x^2] - E[x]^2, floored so the square root and division stay finite.
    istd = [
        1.0 / math.sqrt(max(sq_total / count - mean * mean, 1.0e-20))
        for sq_total, mean in zip(var_stat, means)
    ]
    return np.array([means, istd])
def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):
    """ Load mean/std from an npz archive and calculate cmvn

    Args:
        npz_cmvn_file: npz file holding the arrays ``mean`` and ``std``
        eps: lower bound applied to ``std`` before it is inverted
    Returns:
        a numpy array stacking [means, 1 / clipped std] along a new
        leading axis
    """
    # The context manager closes the archive's file handle once both
    # arrays have been read.
    with np.load(npz_cmvn_file) as npzfile:
        means = npzfile["mean"]
        std = npzfile["std"]
    istd = 1.0 / np.clip(std, eps, None)
    return np.array([means, istd])
def load_cmvn(cmvn_file: str, filetype: str):
    """load cmvn from file.

    Args:
        cmvn_file (str): cmvn path.
        filetype (str): file type, optional[npz, json, kaldi]. Matched
            case-insensitively.

    Raises:
        ValueError: file type not support.

    Returns:
        Tuple[np.ndarray, np.ndarray]: mean, istd
    """
    # Normalize the case before dispatching; an unknown type falls through
    # to the ValueError below rather than tripping an assert first.
    filetype = filetype.lower()
    if filetype == "json":
        cmvn = _load_json_cmvn(cmvn_file)
    elif filetype == "kaldi":
        cmvn = _load_kaldi_cmvn(cmvn_file)
    elif filetype == "npz":
        cmvn = _load_npz_cmvn(cmvn_file)
    else:
        raise ValueError(f"cmvn file type no support: {filetype}")
    return cmvn[0], cmvn[1]

@ -29,7 +29,6 @@ from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from paddle.nn import initializer as I from paddle.nn import initializer as I
from paddle.nn.utils.rnn import pad_sequence
from deepspeech.modules.mask import make_pad_mask from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import mask_finished_preds from deepspeech.modules.mask import mask_finished_preds
from deepspeech.modules.mask import mask_finished_scores from deepspeech.modules.mask import mask_finished_scores
@ -42,13 +41,15 @@ from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder from deepspeech.modules.decoder import TransformerDecoder
from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
from deepspeech.frontend.utility import load_cmvn
from deepspeech.utils import checkpoint from deepspeech.utils import checkpoint
from deepspeech.utils import layer_tools from deepspeech.utils import layer_tools
from deepspeech.utils.cmvn import load_cmvn
from deepspeech.utils.utility import log_add from deepspeech.utils.utility import log_add
from deepspeech.utils.tensor_utils import IGNORE_ID from deepspeech.utils.tensor_utils import IGNORE_ID
from deepspeech.utils.tensor_utils import add_sos_eos from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.ctc_utils import remove_duplicates_and_blank from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -635,7 +636,7 @@ class U2TransformerModel(U2Model):
def __init__(configs: dict): def __init__(configs: dict):
if configs['cmvn_file'] is not None: if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'], mean, istd = load_cmvn(configs['cmvn_file'],
configs['is_json_cmvn']) configs['cmvn_file_type'])
global_cmvn = GlobalCMVN( global_cmvn = GlobalCMVN(
paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float()) paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
else: else:
@ -666,7 +667,7 @@ class U2ConformerModel(U2Model):
def __init__(configs: dict): def __init__(configs: dict):
if configs['cmvn_file'] is not None: if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'], mean, istd = load_cmvn(configs['cmvn_file'],
configs['is_json_cmvn']) configs['cmvn_file_type'])
global_cmvn = GlobalCMVN( global_cmvn = GlobalCMVN(
paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float()) paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
else: else:

@ -21,7 +21,11 @@ from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = ['sequence_mask'] __all__ = [
'sequence_mask', "make_pad_mask", "make_non_pad_mask", "subsequent_mask",
"subsequent_chunk_mask", "add_optional_chunk_mask", "mask_finished_scores",
"mask_finished_preds"
]
def sequence_mask(x_len, max_len=None, dtype='float32'): def sequence_mask(x_len, max_len=None, dtype='float32'):
@ -93,7 +97,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
[1, 1, 1, 0, 0], [1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]] [1, 1, 0, 0, 0]]
""" """
return ~make_pad_mask(lengths) return make_pad_mask(lengths).logical_not()
def subsequent_mask(size: int) -> paddle.Tensor: def subsequent_mask(size: int) -> paddle.Tensor:

@ -1,93 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import logging
import numpy as np
logger = logging.getLogger(__name__)
__all__ = ['load_cmvn']
def _load_json_cmvn(json_cmvn_file):
    """ Load the json format cmvn stats file and calculate cmvn

    The file is read for three keys: ``mean_stat`` and ``var_stat`` (one
    value per feature dimension, used here as the accumulated sum and the
    accumulated sum of squares) and ``frame_num`` (the divisor).

    Args:
        json_cmvn_file: cmvn stats file in json format
    Returns:
        a numpy array of shape (2, D): row 0 holds the means and row 1 holds
        the inverse standard deviation ``1 / sqrt(variance)``, with the
        variance floored at 1.0e-20.
    Raises:
        ValueError: ``mean_stat`` and ``var_stat`` have different lengths.
    """
    with open(json_cmvn_file, encoding='utf-8') as f:
        cmvn_stats = json.load(f)

    mean_stat = cmvn_stats['mean_stat']
    var_stat = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    if len(mean_stat) != len(var_stat):
        raise ValueError(
            f"mean_stat has {len(mean_stat)} entries but var_stat has "
            f"{len(var_stat)}: {json_cmvn_file}")

    means = [total / count for total in mean_stat]
    # E[x^2] - E[x]^2, floored so the square root and division stay finite.
    istd = [
        1.0 / math.sqrt(max(sq_total / count - mean * mean, 1.0e-20))
        for sq_total, mean in zip(var_stat, means)
    ]
    return np.array([means, istd])
def _load_kaldi_cmvn(kaldi_cmvn_file):
    """ Load the kaldi format cmvn stats file and calculate cmvn

    The text file is split on whitespace and read as
    ``[ <D sums> <count> <D sums of squares> 0 ]``.

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
           is generated by:
           compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
    Returns:
        a numpy array of shape (2, D): row 0 holds the means and row 1 holds
        the inverse standard deviation ``1 / sqrt(variance)``, with the
        variance floored at 1.0e-20.
    Raises:
        SystemExit: the file starts with the kaldi binary marker.
        AssertionError: the token stream is not framed as ``[ ... 0 ]``.
    """
    # Imported here because the binary-file branch calls sys.exit and this
    # module's import block does not bring `sys` into scope.
    import sys

    with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary file start with '\0B'
        if fid.read(2) == '\0B':
            logger.error('kaldi cmvn binary file is not supported, please '
                         'recompute it by: compute-cmvn-stats --binary=false '
                         ' scp:feats.scp global_cmvn')
            sys.exit(1)
        fid.seek(0)
        arr = fid.read().split()

    assert arr[0] == '[', "kaldi cmvn text must start with '['"
    assert arr[-2] == '0', "kaldi cmvn second row must end with 0"
    assert arr[-1] == ']', "kaldi cmvn text must end with ']'"

    # Two bracket tokens, the count and the trailing 0 are not features.
    feat_dim = int((len(arr) - 2 - 2) / 2)
    mean_stat = [float(tok) for tok in arr[1:feat_dim + 1]]
    count = float(arr[feat_dim + 1])
    var_stat = [float(tok) for tok in arr[feat_dim + 2:2 * feat_dim + 2]]

    means = [total / count for total in mean_stat]
    # E[x^2] - E[x]^2, floored so the square root and division stay finite.
    istd = [
        1.0 / math.sqrt(max(sq_total / count - mean * mean, 1.0e-20))
        for sq_total, mean in zip(var_stat, means)
    ]
    return np.array([means, istd])
def load_cmvn(cmvn_file, is_json):
    """Load cmvn from a json or kaldi text stats file.

    Args:
        cmvn_file: path of the cmvn stats file.
        is_json: truthy to parse the file as json, otherwise as kaldi text.
    Returns:
        the first and second rows of the array produced by the loader.
    """
    loader = _load_json_cmvn if is_json else _load_kaldi_cmvn
    stats = loader(cmvn_file)
    return stats[0], stats[1]

@ -20,35 +20,70 @@ import paddle
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = ["pad_list", "add_sos_eos", "th_accuracy"] __all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"]
IGNORE_ID = -1 IGNORE_ID = -1
def pad_list(xs: List[paddle.Tensor], pad_value: int): def pad_sequence(sequences: List[paddle.Tensor],
"""Perform padding for the list of tensors. batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor:
r"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
and pads them to equal length. For example, if the input is list of
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
otherwise.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Example:
>>> from paddle.nn.utils.rnn import pad_sequence
>>> a = paddle.ones(25, 300)
>>> b = paddle.ones(22, 300)
>>> c = paddle.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
paddle.Tensor([25, 3, 300])
Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are same.
Args: Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. sequences (list[Tensor]): list of variable length sequences.
pad_value (float): Value for padding. batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements. Default: 0.
Returns: Returns:
Tensor: Padded tensor (B, Tmax, `*`). Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
Examples: Tensor of size ``B x T x *`` otherwise
>>> x = [paddle.ones(4), paddle.ones(2), paddle.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
""" """
n_batch = len(xs)
max_len = max([x.size(0) for x in xs]) # assuming trailing dimensions and type of all the Tensors
pad = paddle.zeros(n_batch, max_len, dtype=xs[0].dtype) # in sequences are same and fetching those from sequences[0]
pad = pad.fill_(pad_value) max_size = sequences[0].size()
for i in range(n_batch): trailing_dims = max_size[1:]
pad[i, :xs[i].size(0)] = xs[i] max_len = max([s.size(0) for s in sequences])
if batch_first:
return pad out_dims = (len(sequences), max_len) + trailing_dims
else:
out_dims = (max_len, len(sequences)) + trailing_dims
out_tensor = sequences[0].new_full(out_dims, padding_value)
for i, tensor in enumerate(sequences):
length = tensor.size(0)
# use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
else:
out_tensor[:length, i, ...] = tensor
return out_tensor
def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,

@ -1 +1,2 @@
* s0 for deepspeech2 * s0 for deepspeech2
* s1 for U2

@ -0,0 +1,105 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
from deepspeech.models.deepspeech2 import DeepSpeech2Model
class TestDeepSpeech2Model(unittest.TestCase):
    """Smoke tests: build DeepSpeech2Model variants and check the output size."""

    def setUp(self):
        paddle.set_device('cpu')

        self.batch_size = 2
        self.feat_dim = 161
        max_len = 64

        #(B, T, D)
        feats = np.random.randn(self.batch_size, max_len, self.feat_dim)
        feat_lens = np.random.randint(
            max_len, size=self.batch_size, dtype='int32')
        # Make the last item span the full padded length.
        feat_lens[-1] = max_len
        #(B, U)
        tokens = np.array([[1, 2], [1, 2]], dtype='int32')
        token_lens = np.array([2] * self.batch_size, dtype='int32')

        self.audio = paddle.to_tensor(feats, dtype='float32')
        self.audio_len = paddle.to_tensor(feat_lens, dtype='int64')
        self.text = paddle.to_tensor(tokens, dtype='int32')
        self.text_len = paddle.to_tensor(token_lens, dtype='int64')

    def _assert_single_element_output(self, use_gru, share_rnn_weights):
        """Build one model variant, run it, and expect a one-element result."""
        model = DeepSpeech2Model(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights, )
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_1(self):
        self._assert_single_element_output(False, False)

    def test_ds2_2(self):
        self._assert_single_element_output(True, False)

    def test_ds2_3(self):
        self._assert_single_element_output(False, True)

    def test_ds2_4(self):
        self._assert_single_element_output(True, True)

    def test_ds2_5(self):
        # Same configuration as test_ds2_1.
        self._assert_single_element_output(False, False)


if __name__ == '__main__':
    unittest.main()

@ -0,0 +1,39 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
from deepspeech.modules.mask import sequence_mask
from deepspeech.modules.mask import make_non_pad_mask
class TestU2Model(unittest.TestCase):
    """Check the mask helpers against a hand-written mask for lengths 5, 3, 2."""

    def setUp(self):
        paddle.set_device('cpu')
        self.lengths = paddle.to_tensor([5, 3, 2])
        # One row per length: ones for valid positions, zeros for padding.
        self.masks = np.array([
            [1, 1, 1, 1, 1],
            [1, 1, 1, 0, 0],
            [1, 1, 0, 0, 0],
        ])

    def _assert_matches_expected(self, mask_fn):
        """Run ``mask_fn`` on the fixture lengths and compare as nested lists."""
        actual = mask_fn(self.lengths).numpy().tolist()
        expected = self.masks.tolist()
        self.assertSequenceEqual(actual, expected)

    def test_sequence_mask(self):
        self._assert_matches_expected(sequence_mask)

    def test_make_non_pad_mask(self):
        self._assert_matches_expected(make_non_pad_mask)


if __name__ == '__main__':
    unittest.main()

@ -1,99 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from deepspeech.models.deepspeech2 import DeepSpeech2Model
if __name__ == '__main__':
    # Manual smoke run: build five DeepSpeech2Model variants on random input
    # and print the shape of the probabilities each one returns.
    batch_size = 2
    feat_dim = 161
    max_len = 100

    audio = np.random.randn(batch_size, feat_dim, max_len)
    audio_len = np.random.randint(100, size=batch_size, dtype='int32')
    audio_len[-1] = 100
    text = np.array([[1, 2], [1, 2]], dtype='int32')
    text_len = np.array([2] * batch_size, dtype='int32')

    audio = paddle.to_tensor(audio, dtype='float32')
    audio_len = paddle.to_tensor(audio_len, dtype='int64')
    text = paddle.to_tensor(text, dtype='int32')
    text_len = paddle.to_tensor(text_len, dtype='int64')

    for tensor in (audio, audio_len, text, text_len):
        print(tensor.shape)
    print("-----------------")

    # (use_gru, share_rnn_weights) per run; the last repeats the first.
    variants = [
        (False, False),
        (True, False),
        (False, True),
        (True, True),
        (False, False),
    ]
    for use_gru, share_rnn_weights in variants:
        model = DeepSpeech2Model(
            feat_size=feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights, )
        logits, probs, logits_len = model(audio, audio_len, text, text_len)
        print('probs.shape', probs.shape)
        print("-----------------")

@ -0,0 +1,60 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
from deepspeech.models.u2 import U2TransformerModel
from deepspeech.models.u2 import U2ConformerModel
class TestU2Model(unittest.TestCase):
    """Smoke test that builds a model on random input and prints output shapes.

    NOTE(review): the file imports the U2 models but the only test builds a
    DeepSpeech2Model -- presumably a placeholder copied from the DeepSpeech2
    test; confirm which model this file is meant to cover.
    """

    def setUp(self):
        batch_size = 2
        # Kept on the instance because the test method needs it to build
        # the model.
        self.feat_dim = 161
        max_len = 100

        audio = np.random.randn(batch_size, self.feat_dim, max_len)
        audio_len = np.random.randint(100, size=batch_size, dtype='int32')
        audio_len[-1] = 100
        text = np.array([[1, 2], [1, 2]], dtype='int32')
        text_len = np.array([2] * batch_size, dtype='int32')

        self.audio = paddle.to_tensor(audio, dtype='float32')
        self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
        self.text = paddle.to_tensor(text, dtype='int32')
        self.text_len = paddle.to_tensor(text_len, dtype='int64')

        print(audio.shape)
        print(audio_len.shape)
        print(text.shape)
        print(text_len.shape)
        print("-----------------")

    def test_ds2_1(self):
        # Imported here because the module-level imports only bring in the
        # U2 model classes.
        from deepspeech.models.deepspeech2 import DeepSpeech2Model

        model = DeepSpeech2Model(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            use_gru=False,
            share_rnn_weights=False, )
        # NOTE(review): assumes the model call returns a 3-tuple -- confirm
        # against the current DeepSpeech2Model forward signature.
        logits, probs, logits_len = model(self.audio, self.audio_len, self.text,
                                          self.text_len)
        print('probs.shape', probs.shape)
        print("-----------------")


if __name__ == '__main__':
    unittest.main()
Loading…
Cancel
Save