refactor cmvn, test

pull/578/head
Hui Zhang 5 years ago
parent 5659bd2386
commit 907c93392f

@ -414,73 +414,4 @@ if not hasattr(paddle.nn, 'ConstantPad2d'):
if not hasattr(paddle.jit, 'export'):
logger.warn("register user export to paddle.jit, remove this when fixed!")
setattr(paddle.jit, 'export', paddle.jit.to_static)
########### hack paddle.nn.utils #############
def pad_sequence(sequences: List[paddle.Tensor],
                 batch_first: bool=False,
                 padding_value: float=0.0) -> paddle.Tensor:
    r"""Stack variable-length tensors into one tensor filled out with ``padding_value``.

    Each element of ``sequences`` is indexed along its first dimension; the
    remaining (trailing) dimensions are read from ``sequences[0]`` only and are
    assumed to be the same for every element. With ``B = len(sequences)`` and
    ``T`` the largest first-dimension size, the result has shape
    ``B x T x *`` when ``batch_first`` is true and ``T x B x *`` otherwise.

    The tensors must provide ``size()`` and ``new_full()``; the output is
    created through ``sequences[0].new_full``.

    Args:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if True,
            or in ``T x B x *`` otherwise.
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise.
    """
    # Trailing dimensions come from the first sequence only.
    trailing_dims = sequences[0].size()[1:]
    max_len = max([s.size(0) for s in sequences])
    batch = len(sequences)
    leading_dims = (batch, max_len) if batch_first else (max_len, batch)
    out_tensor = sequences[0].new_full(leading_dims + trailing_dims,
                                       padding_value)
    for idx, seq in enumerate(sequences):
        steps = seq.size(0)
        # Write through index notation so each sequence is copied into its
        # own slot of the pre-filled output.
        if batch_first:
            out_tensor[idx, :steps, ...] = seq
        else:
            out_tensor[:steps, idx, ...] = seq
    return out_tensor
# Register the local pad_sequence on paddle.nn.utils when it is not there yet.
# NOTE(review): the attribute name is the literal string 'rnn.pad_sequence',
# so hasattr/setattr read and write ONE attribute whose name contains a dot.
# This does not create paddle.nn.utils.rnn.pad_sequence; confirm that callers
# do not rely on `from paddle.nn.utils.rnn import pad_sequence` working.
if not hasattr(paddle.nn.utils, 'rnn.pad_sequence'):
logger.warn(
"register user rnn.pad_sequence to paddle.nn.utils, remove this when fixed!"
)
setattr(paddle.nn.utils, 'rnn.pad_sequence', pad_sequence)
setattr(paddle.jit, 'export', paddle.jit.to_static)

@ -57,17 +57,17 @@ class FeatureNormalizer(object):
else:
self._read_mean_std_from_file(mean_std_filepath)
def apply(self, features, eps=1e-14):
def apply(self, features):
"""Normalize features to be of zero mean and unit stddev.
:param features: Input features to be normalized.
:type features: ndarray
:type features: ndarray, shape (D, T)
:param eps: added to stddev to provide numerical stability.
:type eps: float
:return: Normalized features.
:rtype: ndarray
"""
return (features - self._mean) / (self._std + eps)
return (features - self._mean) * self._istd
def write_to_file(self, filepath):
"""Write the mean and stddev to the file.
@ -77,11 +77,13 @@ class FeatureNormalizer(object):
"""
np.savez(filepath, mean=self._mean, std=self._std)
def _read_mean_std_from_file(self, filepath):
def _read_mean_std_from_file(self, filepath, eps=1e-20):
"""Load mean and std from file."""
npzfile = np.load(filepath)
self._mean = npzfile["mean"]
self._std = npzfile["std"]
std = npzfile["std"]
std = np.clip(std, eps, None)
self._istd = 1.0 / std
def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
"""Compute mean and std from randomly sampled instances."""
@ -92,6 +94,6 @@ class FeatureNormalizer(object):
features.append(
featurize_func(
AudioSegment.from_file(instance["audio_filepath"])))
features = np.hstack(features)
self._mean = np.mean(features, axis=1).reshape([-1, 1])
self._std = np.std(features, axis=1).reshape([-1, 1])
features = np.hstack(features) #(D, T)
self._mean = np.mean(features, axis=1).reshape([-1, 1]) #(D, 1)
self._std = np.std(features, axis=1).reshape([-1, 1]) #(D, 1)

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data helper functions."""
import numpy as np
import math
import json
@ -20,11 +19,19 @@ import codecs
import os
import tarfile
import time
import logging
from threading import Thread
from multiprocessing import Process, Manager, Value
from paddle.dataset.common import md5file
logger = logging.getLogger(__name__)
__all__ = [
"load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs",
"mean_dbfs", "gain_db_to_ratio", "normalize_audio"
]
def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
"""Load and parse manifest file.
@ -134,3 +141,102 @@ def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
return np.maximum(
np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)),
1.0), -1.0)
def _load_json_cmvn(json_cmvn_file):
    """Load accumulated CMVN statistics from a JSON file and normalize them.

    The file must hold ``mean_stat`` (per-dimension sums), ``var_stat``
    (per-dimension sums of squares) and ``frame_num`` (frame count).

    Args:
        json_cmvn_file: cmvn stats file in json format

    Returns:
        a numpy array of [means, inverse standard deviations]
    """
    with open(json_cmvn_file) as fin:
        stats = json.load(fin)

    mean_acc = stats['mean_stat']
    sq_acc = stats['var_stat']
    frame_num = stats['frame_num']
    floor = 1.0e-20
    for idx in range(len(mean_acc)):
        mean_acc[idx] = mean_acc[idx] / frame_num
        # E[x^2] - E[x]^2, floored so the square root below stays finite.
        var = sq_acc[idx] / frame_num - mean_acc[idx] * mean_acc[idx]
        sq_acc[idx] = 1.0 / math.sqrt(floor if var < floor else var)
    return np.array([mean_acc, sq_acc])
def _load_kaldi_cmvn(kaldi_cmvn_file):
    """Load a Kaldi text-format global CMVN stats file and normalize it.

    The expected token layout is ``[ sum_1 .. sum_D count sq_1 .. sq_D 0 ]``.

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array of [means, inverse standard deviations]

    Raises:
        SystemExit: if the file starts with the Kaldi binary marker.
        AssertionError: if the token layout is not the one described above.
    """
    # Local import so the error path below never fails on a missing
    # module-level `sys` import.
    import sys

    # Read bytes first: a binary stats file may not decode as text, which
    # would raise before the header check could report the real problem.
    with open(kaldi_cmvn_file, 'rb') as fid:
        raw = fid.read()
    # A Kaldi binary file starts with '\0B'.
    if raw[:2] == b'\0B':
        logger.error('kaldi cmvn binary file is not supported, please '
                     'recompute it by: compute-cmvn-stats --binary=false '
                     ' scp:feats.scp global_cmvn')
        sys.exit(1)

    arr = raw.decode().split()
    assert arr[0] == '[', "kaldi cmvn file must start with '['"
    assert arr[-2] == '0', "kaldi cmvn file must end with '0 ]'"
    assert arr[-1] == ']', "kaldi cmvn file must end with ']'"
    # Strip '[', count, '0' and ']' — the rest is two rows of feat_dim values.
    feat_dim = int((len(arr) - 2 - 2) / 2)
    count = float(arr[feat_dim + 1])
    means = [float(tok) / count for tok in arr[1:feat_dim + 1]]

    istd = []
    for mean, tok in zip(means, arr[feat_dim + 2:2 * feat_dim + 2]):
        # E[x^2] - E[x]^2, floored so the square root stays finite.
        var = float(tok) / count - mean * mean
        if var < 1.0e-20:
            var = 1.0e-20
        istd.append(1.0 / math.sqrt(var))
    return np.array([means, istd])
def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):
    """Load mean/std from an ``.npz`` archive and return mean and inverse std.

    Args:
        npz_cmvn_file: path to an npz archive with ``mean`` and ``std`` arrays.
        eps (float, optional): lower bound applied to ``std`` before it is
            inverted, so a zero std does not divide by zero. Default: 1e-20.

    Returns:
        a numpy array of [means, inverse standard deviations]
    """
    # np.load on an .npz returns a lazy NpzFile that keeps the file open;
    # the context manager closes it once both arrays have been read.
    with np.load(npz_cmvn_file) as npzfile:
        means = npzfile["mean"]
        std = npzfile["std"]
    istd = 1.0 / np.clip(std, eps, None)
    return np.array([means, istd])
def load_cmvn(cmvn_file: str, filetype: str):
    """Load CMVN statistics from a file.

    Args:
        cmvn_file (str): cmvn path.
        filetype (str): file type, one of ``npz``, ``json`` or ``kaldi``
            (matched case-insensitively).

    Raises:
        ValueError: file type not supported.

    Returns:
        Tuple[np.ndarray, np.ndarray]: the first and second rows produced by
        the format-specific loader (mean, istd).
    """
    # Normalize before dispatching so "JSON" / "Kaldi" are accepted, and
    # report unknown types with the documented ValueError rather than an
    # assert that disappears under `python -O`.
    filetype = filetype.lower()
    if filetype == "json":
        cmvn = _load_json_cmvn(cmvn_file)
    elif filetype == "kaldi":
        cmvn = _load_kaldi_cmvn(cmvn_file)
    elif filetype == "npz":
        cmvn = _load_npz_cmvn(cmvn_file)
    else:
        raise ValueError(f"cmvn file type no support: {filetype}")
    return cmvn[0], cmvn[1]

@ -29,7 +29,6 @@ from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from paddle.nn.utils.rnn import pad_sequence
from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import mask_finished_preds
from deepspeech.modules.mask import mask_finished_scores
@ -42,13 +41,15 @@ from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder
from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
from deepspeech.frontend.utility import load_cmvn
from deepspeech.utils import checkpoint
from deepspeech.utils import layer_tools
from deepspeech.utils.cmvn import load_cmvn
from deepspeech.utils.utility import log_add
from deepspeech.utils.tensor_utils import IGNORE_ID
from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
logger = logging.getLogger(__name__)
@ -635,7 +636,7 @@ class U2TransformerModel(U2Model):
def __init__(configs: dict):
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'],
configs['is_json_cmvn'])
configs['cmvn_file_type'])
global_cmvn = GlobalCMVN(
paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
else:
@ -666,7 +667,7 @@ class U2ConformerModel(U2Model):
def __init__(configs: dict):
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'],
configs['is_json_cmvn'])
configs['cmvn_file_type'])
global_cmvn = GlobalCMVN(
paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
else:

@ -21,7 +21,11 @@ from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
__all__ = ['sequence_mask']
__all__ = [
'sequence_mask', "make_pad_mask", "make_non_pad_mask", "subsequent_mask",
"subsequent_chunk_mask", "add_optional_chunk_mask", "mask_finished_scores",
"mask_finished_preds"
]
def sequence_mask(x_len, max_len=None, dtype='float32'):
@ -93,7 +97,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
"""
return ~make_pad_mask(lengths)
return make_pad_mask(lengths).logical_not()
def subsequent_mask(size: int) -> paddle.Tensor:

@ -1,93 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import logging
import numpy as np
logger = logging.getLogger(__name__)
__all__ = ['load_cmvn']
def _load_json_cmvn(json_cmvn_file):
    """Read JSON CMVN accumulators and turn them into mean / inverse std.

    Reads the keys ``mean_stat``, ``var_stat`` and ``frame_num``.

    Args:
        json_cmvn_file: cmvn stats file in json format

    Returns:
        a numpy array of [means, inverse standard deviations]
    """
    with open(json_cmvn_file) as handle:
        cmvn_stats = json.load(handle)

    means = cmvn_stats['mean_stat']
    inv_std = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    for i, mean_sum in enumerate(means):
        means[i] = mean_sum / count
        # Variance from the sum of squares, clamped away from zero.
        var = max(inv_std[i] / count - means[i] * means[i], 1.0e-20)
        inv_std[i] = 1.0 / math.sqrt(var)
    return np.array([means, inv_std])
def _load_kaldi_cmvn(kaldi_cmvn_file):
    """Load a Kaldi text-format global CMVN stats file and normalize it.

    The expected token layout is ``[ sum_1 .. sum_D count sq_1 .. sq_D 0 ]``.

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array of [means, inverse standard deviations]

    Raises:
        SystemExit: if the file starts with the Kaldi binary marker.
        AssertionError: if the token layout is not the one described above.
    """
    # This module's imports do not include `sys`, so bring it in here;
    # otherwise the error path raises NameError instead of exiting.
    import sys

    # Inspect raw bytes: decoding a binary file as text can fail before
    # the '\0B' marker is ever compared.
    with open(kaldi_cmvn_file, 'rb') as fid:
        payload = fid.read()
    if payload.startswith(b'\0B'):
        logger.error('kaldi cmvn binary file is not supported, please '
                     'recompute it by: compute-cmvn-stats --binary=false '
                     ' scp:feats.scp global_cmvn')
        sys.exit(1)

    arr = payload.decode().split()
    assert arr[0] == '[', "kaldi cmvn file must start with '['"
    assert arr[-2] == '0', "kaldi cmvn file must end with '0 ]'"
    assert arr[-1] == ']', "kaldi cmvn file must end with ']'"
    # Remove '[', count, '0' and ']' and split what is left into two rows.
    feat_dim = int((len(arr) - 2 - 2) / 2)
    count = float(arr[feat_dim + 1])
    means = [float(tok) / count for tok in arr[1:feat_dim + 1]]

    inv_std = []
    for mean, tok in zip(means, arr[feat_dim + 2:2 * feat_dim + 2]):
        var = float(tok) / count - mean * mean
        # Clamp so a constant feature dimension does not divide by zero.
        if var < 1.0e-20:
            var = 1.0e-20
        inv_std.append(1.0 / math.sqrt(var))
    return np.array([means, inv_std])
def load_cmvn(cmvn_file, is_json):
    """Load CMVN statistics from a JSON or Kaldi text stats file.

    Args:
        cmvn_file: path of the cmvn stats file.
        is_json: truthy to parse the file as JSON, falsy to parse it as a
            Kaldi text file.

    Returns:
        The first and second rows of the array built by the chosen loader.
    """
    loader = _load_json_cmvn if is_json else _load_kaldi_cmvn
    cmvn = loader(cmvn_file)
    return cmvn[0], cmvn[1]

@ -20,35 +20,70 @@ import paddle
logger = logging.getLogger(__name__)
__all__ = ["pad_list", "add_sos_eos", "th_accuracy"]
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"]
IGNORE_ID = -1
def pad_list(xs: List[paddle.Tensor], pad_value: int):
"""Perform padding for the list of tensors.
def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor:
r"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
and pads them to equal length. For example, if the input is list of
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
otherwise.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Example:
>>> from paddle.nn.utils.rnn import pad_sequence
>>> a = paddle.ones(25, 300)
>>> b = paddle.ones(22, 300)
>>> c = paddle.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
paddle.Tensor([25, 3, 300])
Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are same.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
sequences (list[Tensor]): list of variable length sequences.
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements. Default: 0.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [paddle.ones(4), paddle.ones(2), paddle.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
Tensor of size ``B x T x *`` otherwise
"""
n_batch = len(xs)
max_len = max([x.size(0) for x in xs])
pad = paddle.zeros(n_batch, max_len, dtype=xs[0].dtype)
pad = pad.fill_(pad_value)
for i in range(n_batch):
pad[i, :xs[i].size(0)] = xs[i]
return pad
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size = sequences[0].size()
trailing_dims = max_size[1:]
max_len = max([s.size(0) for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
out_dims = (max_len, len(sequences)) + trailing_dims
out_tensor = sequences[0].new_full(out_dims, padding_value)
for i, tensor in enumerate(sequences):
length = tensor.size(0)
# use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
else:
out_tensor[:length, i, ...] = tensor
return out_tensor
def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,

@ -1 +1,2 @@
* s0 for deepspeech2
* s1 for U2

@ -0,0 +1,105 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
from deepspeech.models.deepspeech2 import DeepSpeech2Model
class TestDeepSpeech2Model(unittest.TestCase):
    """Forward-pass smoke tests for DeepSpeech2Model on random CPU inputs.

    Each test builds the model with one combination of ``use_gru`` and
    ``share_rnn_weights`` and checks that the returned loss has one element.
    """

    def setUp(self):
        paddle.set_device('cpu')

        self.batch_size = 2
        self.feat_dim = 161
        max_len = 64

        # (B, T, D)
        audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
        audio_len = np.random.randint(
            max_len, size=self.batch_size, dtype='int32')
        # The last utterance always spans the full padded length.
        audio_len[-1] = max_len
        # (B, U)
        text = np.array([[1, 2], [1, 2]], dtype='int32')
        text_len = np.array([2] * self.batch_size, dtype='int32')

        self.audio = paddle.to_tensor(audio, dtype='float32')
        self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
        self.text = paddle.to_tensor(text, dtype='int32')
        self.text_len = paddle.to_tensor(text_len, dtype='int64')

    def _check_single_loss(self, use_gru, share_rnn_weights):
        """Build the model with the given RNN options and run one forward pass."""
        model = DeepSpeech2Model(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights, )
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)

    def test_ds2_1(self):
        self._check_single_loss(use_gru=False, share_rnn_weights=False)

    def test_ds2_2(self):
        self._check_single_loss(use_gru=True, share_rnn_weights=False)

    def test_ds2_3(self):
        self._check_single_loss(use_gru=False, share_rnn_weights=True)

    def test_ds2_4(self):
        self._check_single_loss(use_gru=True, share_rnn_weights=True)

    def test_ds2_5(self):
        # Same configuration as test_ds2_1.
        self._check_single_loss(use_gru=False, share_rnn_weights=False)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,39 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
from deepspeech.modules.mask import sequence_mask
from deepspeech.modules.mask import make_non_pad_mask
class TestU2Model(unittest.TestCase):
    """Checks ``sequence_mask`` and ``make_non_pad_mask`` against a fixed mask."""

    def setUp(self):
        paddle.set_device('cpu')
        self.lengths = paddle.to_tensor([5, 3, 2])
        # Row i has lengths[i] leading ones followed by zeros.
        self.masks = np.array([
            [1, 1, 1, 1, 1],
            [1, 1, 1, 0, 0],
            [1, 1, 0, 0, 0],
        ])

    def test_sequence_mask(self):
        expected = self.masks.tolist()
        actual = sequence_mask(self.lengths).numpy().tolist()
        self.assertSequenceEqual(actual, expected)

    def test_make_non_pad_mask(self):
        expected = self.masks.tolist()
        actual = make_non_pad_mask(self.lengths).numpy().tolist()
        self.assertSequenceEqual(actual, expected)
if __name__ == '__main__':
unittest.main()

@ -1,99 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from deepspeech.models.deepspeech2 import DeepSpeech2Model
if __name__ == '__main__':
    # Manual smoke run: build random inputs, then push them through
    # DeepSpeech2Model once per RNN configuration and print the output shape.
    batch_size = 2
    feat_dim = 161
    max_len = 100

    audio = np.random.randn(batch_size, feat_dim, max_len)
    audio_len = np.random.randint(100, size=batch_size, dtype='int32')
    audio_len[-1] = 100
    text = np.array([[1, 2], [1, 2]], dtype='int32')
    text_len = np.array([2] * batch_size, dtype='int32')

    audio = paddle.to_tensor(audio, dtype='float32')
    audio_len = paddle.to_tensor(audio_len, dtype='int64')
    text = paddle.to_tensor(text, dtype='int32')
    text_len = paddle.to_tensor(text_len, dtype='int64')

    for tensor in (audio, audio_len, text, text_len):
        print(tensor.shape)
    print("-----------------")

    # (use_gru, share_rnn_weights) for each run, in the original order.
    rnn_options = [
        (False, False),
        (True, False),
        (False, True),
        (True, True),
        (False, False),
    ]
    for use_gru, share_rnn_weights in rnn_options:
        model = DeepSpeech2Model(
            feat_size=feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights, )
        logits, probs, logits_len = model(audio, audio_len, text, text_len)
        print('probs.shape', probs.shape)
        print("-----------------")

@ -0,0 +1,60 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
from deepspeech.models.u2 import U2TransformerModel
from deepspeech.models.u2 import U2ConformerModel
class TestU2Model(unittest.TestCase):
    """Forward-pass smoke test on random inputs.

    NOTE(review): despite the class name, the only test here builds a
    ``DeepSpeech2Model``; the U2 models imported by this file are not
    exercised yet.
    """

    def setUp(self):
        batch_size = 2
        # Stored on the instance: the test method needs it to build the model.
        self.feat_dim = 161
        max_len = 100

        # NOTE(review): (B, T, D) layout and a single returned loss are
        # assumed from the DeepSpeech2 model test — confirm against the model.
        audio = np.random.randn(batch_size, max_len, self.feat_dim)
        audio_len = np.random.randint(100, size=batch_size, dtype='int32')
        audio_len[-1] = 100
        text = np.array([[1, 2], [1, 2]], dtype='int32')
        text_len = np.array([2] * batch_size, dtype='int32')

        self.audio = paddle.to_tensor(audio, dtype='float32')
        self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
        self.text = paddle.to_tensor(text, dtype='int32')
        self.text_len = paddle.to_tensor(text_len, dtype='int64')

        print(audio.shape)
        print(audio_len.shape)
        print(text.shape)
        print(text_len.shape)
        print("-----------------")

    def test_ds2_1(self):
        # Imported here because this file only imports the U2 model classes.
        from deepspeech.models.deepspeech2 import DeepSpeech2Model

        model = DeepSpeech2Model(
            feat_size=self.feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            use_gru=False,
            share_rnn_weights=False, )
        loss = model(self.audio, self.audio_len, self.text, self.text_len)
        self.assertEqual(loss.numel(), 1)
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save