Complete the modification according to the review comments

Unit tests are still needed.
pull/735/head
huangyuxin 3 years ago
parent e4ef8ed31e
commit 2537221b61

@@ -29,8 +29,6 @@ from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2 import DeepSpeech2InferModel
from deepspeech.models.ds2 import DeepSpeech2Model
#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate
@@ -38,6 +36,8 @@ from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils.log import Autolog
from deepspeech.utils.log import Log
#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
logger = Log(__name__).getlog()
@@ -128,9 +128,7 @@ class DeepSpeech2Trainer(Trainer):
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
use_gru=config.model.use_gru)
if self.parallel:
model = paddle.DataParallel(model)
@@ -376,8 +374,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
use_gru=config.model.use_gru)
self.model = model
logger.info("Setup model!")

@@ -19,12 +19,8 @@ from deepspeech.modules.subsampling import Conv2dSubsampling4
class Conv2dSubsampling4Online(Conv2dSubsampling4):
def __init__(self,
idim: int,
odim: int,
dropout_rate: float,
pos_enc_class: nn.Layer=PositionalEncoding):
super().__init__(idim, odim, dropout_rate, pos_enc_class)
def __init__(self, idim: int, odim: int, dropout_rate: float):
super().__init__(idim, odim, dropout_rate, None)
self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim
def forward(self, x: paddle.Tensor,

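For reference, the new output_dim formula corresponds to the two kernel-3, stride-2 convolutions inside Conv2dSubsampling4: each maps a feature dimension d to (d - 1) // 2, and the result is then flattened together with the odim output channels. A quick arithmetic check with the sizes used by CRNNEncoder below (161-dim linear features, 32 channels; the concrete numbers are only illustrative):

    idim, odim = 161, 32
    after_conv1 = (idim - 1) // 2            # 80, after the first kernel-3 / stride-2 conv
    after_conv2 = (after_conv1 - 1) // 2     # 39, after the second conv
    output_dim = after_conv2 * odim          # 1248, i.e. ((idim - 1) // 2 - 1) // 2 * odim
    print(output_dim)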
@@ -36,16 +36,17 @@ class CRNNEncoder(nn.Layer):
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
share_rnn_weights=True):
use_gru=False):
super().__init__()
self.rnn_size = rnn_size
self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size
self.num_rnn_layers = num_rnn_layers
self.num_fc_layers = num_fc_layers
self.rnn_direction = rnn_direction
self.fc_layers_size_list = fc_layers_size_list
self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)
@@ -54,7 +55,6 @@ class CRNNEncoder(nn.Layer):
self.rnn = nn.LayerList()
self.layernorm_list = nn.LayerList()
self.fc_layers_list = nn.LayerList()
rnn_direction = 'forward'
layernorm_size = rnn_size
if use_gru == True:
@@ -99,21 +99,18 @@ class CRNNEncoder(nn.Layer):
def output_size(self):
return self.fc_layers_size_list[-1]
def forward(self, audio, audio_len):
def forward(self, x, x_lens):
"""Compute Encoder outputs
Args:
audio (Tensor): [B, Tmax, D]
text (Tensor): [B, Umax]
audio_len (Tensor): [B]
text_len (Tensor): [B]
x (Tensor): [B, T_input, D]
x_lens (Tensor): [B]
Returns:
x (Tensor): encoder outputs, [B, T, D]
x (Tensor): encoder outputs, [B, T_output, D]
x_lens (Tensor): encoder length, [B]
rnn_final_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers
"""
# [B, T, D]
x = audio
x_lens = audio_len
# convolution group
x, x_lens = self.conv(x, x_lens)
# convert data from convolution feature map to sequence of vectors
@@ -123,16 +120,47 @@ class CRNNEncoder(nn.Layer):
#x = x.reshape([0, 0, -1]) #[B, T, C*D]
# remove padding part
x, output_state = self.rnn[0](x, None, x_lens)
init_state = None
rnn_final_state_list = []
x, final_state = self.rnn[0](x, init_state, x_lens)
rnn_final_state_list.append(final_state)
x = self.layernorm_list[0](x)
for i in range(1, self.num_rnn_layers):
x, output_state = self.rnn[i](x, output_state, x_lens) #[B, T, D]
x, final_state = self.rnn[i](x, init_state, x_lens) #[B, T, D]
rnn_final_state_list.append(final_state)
x = self.layernorm_list[i](x)
for i in range(self.num_fc_layers):
x = self.fc_layers_list[i](x)
x = F.relu(x)
return x, x_lens
return x, x_lens, rnn_final_state_list
def forward(self, x, x_lens, init_state_list):
"""Compute Encoder outputs
Args:
x (Tensor): [B, feature_chunk_size, D]
x_lens (Tensor): [B]
init_state_list (list of Tensors): [num_directions, batch_size, hidden_size] * num_rnn_layers
Returns:
x (Tensor): encoder outputs, [B, chunk_size, D]
x_lens (Tensor): encoder length, [B]
rnn_final_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers
"""
rnn_final_state_list = []
x, final_state = self.rnn[0](x, init_state_list[0], x_lens)
rnn_final_state_list.append(final_state)
x = self.layernorm_list[0](x)
for i in range(1, self.num_rnn_layers):
x, final_state = self.rnn[i](x, init_state_list[i],
x_lens) #[B, T, D]
rnn_final_state_list.append(final_state)
x = self.layernorm_list[i](x)
for i in range(self.num_fc_layers):
x = self.fc_layers_list[i](x)
x = F.relu(x)
return x, x_lens, rnn_final_state_list
class DeepSpeech2ModelOnline(nn.Layer):
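The chunk-wise forward added above follows the usual stateful streaming pattern: run each chunk with the previous chunk's final RNN states as its initial states, then concatenate the per-chunk outputs. A self-contained sketch of that pattern with a plain paddle.nn.GRU, independent of this model (all names and sizes here are illustrative, not part of the patch):

    import paddle
    import paddle.nn as nn

    gru = nn.GRU(input_size=80, hidden_size=64, num_layers=1)
    x = paddle.randn([1, 20, 80])                 # [B, T, D], one 20-frame utterance

    full_out, _ = gru(x)                          # one pass over the whole utterance

    state = None                                  # no history before the first chunk
    chunk_outs = []
    for chunk in paddle.split(x, num_or_sections=4, axis=1):   # four 5-frame chunks
        out, state = gru(chunk, state)            # final state feeds the next chunk
        chunk_outs.append(out)
    streamed_out = paddle.concat(chunk_outs, axis=1)

    # streaming and full-utterance outputs match up to floating-point error
    print(float(paddle.abs(full_out - streamed_out).max()))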
@@ -156,9 +184,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
:type rnn_size: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward direction RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: A tuple of an output unnormalized log probability layer (
before softmax) and a ctc cost layer.
@@ -175,7 +200,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
))
if config is not None:
config.merge_from_other_cfg(default)
@@ -187,21 +211,21 @@ class DeepSpeech2ModelOnline(nn.Layer):
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
share_rnn_weights=True):
use_gru=False):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
use_gru=use_gru)
assert (self.encoder.output_size == fc_layers_size_list[-1])
self.decoder = CTCDecoder(
@@ -224,7 +248,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
Returns:
loss (Tensor): [1]
"""
eouts, eouts_len = self.encoder(audio, audio_len)
eouts, eouts_len, rnn_final_state_list = self.encoder(audio, audio_len)
loss = self.decoder(eouts, eouts_len, text, text_len)
return loss
@@ -271,10 +295,10 @@ class DeepSpeech2ModelOnline(nn.Layer):
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
use_gru=config.model.use_gru)
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
@@ -289,20 +313,20 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
share_rnn_weights=True):
use_gru=False):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
use_gru=use_gru)
def forward(self, audio, audio_len):
"""export model function
@@ -314,6 +338,26 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
Returns:
probs: probs after softmax
"""
eouts, eouts_len = self.encoder(audio, audio_len)
eouts, eouts_len, rnn_final_state_list = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return probs
def forward(self, eouts_chunk_prefix, eouts_chunk_lens_prefix, audio_chunk,
audio_chunk_len, init_state_list):
"""export model function
Args:
audio_chunk (Tensor): [B, T, D]
audio_chunk_len (Tensor): [B]
Returns:
probs: probs after softmax
"""
eouts_chunk, eouts_chunk_lens, rnn_final_state_list = self.encoder(
audio_chunk, audio_chunk_len, init_state_list)
eouts_chunk_new_prefix = paddle.concat(
[eouts_chunk_prefix, eouts_chunk], axis=1)
eouts_chunk_lens_new_prefix = paddle.add(eouts_chunk_lens_prefix,
eouts_chunk_lens)
probs_chunk = self.decoder.softmax(eouts_chunk_new_prefix)
return probs_chunk, eouts_chunk_new_prefix, eouts_chunk_lens_new_prefix, rnn_final_state_list
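The prefix bookkeeping in this export forward only grows the running encoder output along the time axis, adds the chunk lengths, and re-applies softmax over everything seen so far. A minimal sketch of those tensor operations on dummy data (shapes are illustrative, and paddle.nn.functional.softmax stands in for the CTC decoder's softmax):

    import paddle

    eouts_prefix = paddle.randn([1, 12, 256])                  # [B, T_prefix, D] accumulated so far
    eouts_lens_prefix = paddle.to_tensor([12], dtype='int64')
    eouts_chunk = paddle.randn([1, 4, 256])                    # [B, T_chunk, D] from the newest chunk
    eouts_chunk_lens = paddle.to_tensor([4], dtype='int64')

    new_prefix = paddle.concat([eouts_prefix, eouts_chunk], axis=1)    # [1, 16, 256]
    new_lens = paddle.add(eouts_lens_prefix, eouts_chunk_lens)         # [16]
    probs = paddle.nn.functional.softmax(new_prefix, axis=-1)          # softmax over the whole prefix
    print(new_prefix.shape, int(new_lens), probs.shape)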

@@ -1,314 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['RNNStack']
class RNNCell(nn.RNNCellBase):
r"""
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
computes the outputs and updates states.
The formula used is as follows:
.. math::
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`act` is for :attr:`activation`.
"""
def __init__(self,
hidden_size: int,
activation="tanh",
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
if activation not in ["tanh", "relu", "brelu"]:
raise ValueError(
"activation for SimpleRNNCell should be tanh or relu, "
"but get {}".format(activation))
self.activation = activation
self._activation_fn = paddle.tanh \
if activation == "tanh" \
else F.relu
if activation == 'brelu':
self._activation_fn = brelu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_h = states
i2h = inputs
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self._activation_fn(i2h + h2h)
return h, h
@property
def state_shape(self):
return (self.hidden_size, )
class GRUCell(nn.RNNCellBase):
r"""
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
it computes the outputs and updates states.
The formula for GRU used is as follows:
.. math::
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid function, and * is the elementwise
multiplication operator.
"""
def __init__(self,
input_size: int,
hidden_size: int,
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(3 * hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(3 * hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
self.input_size = input_size
self._gate_activation = F.sigmoid
self._activation = paddle.tanh
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_hidden = states
x_gates = inputs
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h_gates = h_gates + self.bias_hh
x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)
r = self._gate_activation(x_r + h_r)
z = self._gate_activation(x_z + h_z)
c = self._activation(x_c + r * h_c) # apply reset gate after mm
h = (pre_hidden - c) * z + c
# https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
return h, h
@property
def state_shape(self):
r"""
The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
size would be automatically inserted into shape). The shape corresponds
to the shape of :math:`h_{t-1}`.
"""
return (self.hidden_size, )
class BiRNNWithBN(nn.Layer):
"""Bidirectonal simple rnn layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param size: Dimension of RNN cells.
:type size: int
:param share_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
:type share_weights: bool
:return: Bidirectional simple rnn layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int, share_weights: bool):
super().__init__()
self.share_weights = share_weights
if self.share_weights:
#input-hidden weights shared between bi-directional rnn.
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
# batch norm is only performed on input-state projection
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = self.fw_fc
self.bw_bn = self.fw_bn
else:
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D], reverse direction uses the backward cell
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class BiGRUWithBN(nn.Layer):
"""Bidirectonal gru layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param name: Name of the layer.
:type name: string
:param input: Input layer.
:type input: Variable
:param size: Dimension of GRU cells.
:type size: int
:param act: Activation type.
:type act: string
:return: Bidirectional GRU layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int):
super().__init__()
hidden_size = h_size * 3
self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D], reverse direction uses the backward cell
def forward(self, x, x_len):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class RNNStack(nn.Layer):
"""RNN group with stacked bidirectional simple RNN or GRU layers.
:param input: Input layer.
:type input: Variable
:param size: Dimension of RNN cells in each layer.
:type size: int
:param num_stacks: Number of stacked rnn layers.
:type num_stacks: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: Output layer of the RNN group.
:rtype: Variable
"""
def __init__(self,
i_size: int,
h_size: int,
num_stacks: int,
use_gru: bool,
share_rnn_weights: bool):
super().__init__()
rnn_stacks = []
for i in range(num_stacks):
if use_gru:
#default:GRU using tanh
rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
else:
rnn_stacks.append(
BiRNNWithBN(
i_size=i_size,
h_size=h_size,
share_weights=share_rnn_weights))
i_size = h_size * 2
self.rnn_stacks = nn.LayerList(rnn_stacks)
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
"""
x: shape [B, T, D]
x_len: shape [B]
"""
for i, rnn in enumerate(self.rnn_stacks):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
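A side note on the removed GRUCell: its update h = (pre_hidden - c) * z + c is algebraically the same as the docstring's h_t = z_t * h_{t-1} + (1 - z_t) * \tilde{h}_t. A quick numeric check with made-up values (purely illustrative):

    import paddle

    h_prev = paddle.to_tensor([0.2, -0.5, 0.9])   # previous hidden state h_{t-1}
    c = paddle.to_tensor([0.1, 0.4, -0.3])        # candidate state \tilde{h}_t
    z = paddle.to_tensor([0.7, 0.2, 0.5])         # update gate z_t

    lhs = (h_prev - c) * z + c                    # form used in GRUCell.forward
    rhs = z * h_prev + (1.0 - z) * c              # form in the docstring
    print(float(paddle.abs(lhs - rhs).max()))     # ~0.0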

@@ -92,7 +92,7 @@ class Conv2dSubsampling4(BaseSubsampling):
dropout_rate: float,
pos_enc_class: nn.Layer=PositionalEncoding):
"""Construct an Conv2dSubsampling4 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
@@ -143,7 +143,7 @@ class Conv2dSubsampling6(BaseSubsampling):
dropout_rate: float,
pos_enc_class: nn.Layer=PositionalEncoding):
"""Construct an Conv2dSubsampling6 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
@@ -196,7 +196,7 @@ class Conv2dSubsampling8(BaseSubsampling):
dropout_rate: float,
pos_enc_class: nn.Layer=PositionalEncoding):
"""Construct an Conv2dSubsampling8 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
