diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 7d0a26d7..8b47892a 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -29,8 +29,6 @@ from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.ds2 import DeepSpeech2InferModel
 from deepspeech.models.ds2 import DeepSpeech2Model
-#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
-#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import error_rate
@@ -38,6 +36,8 @@ from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Autolog
 from deepspeech.utils.log import Log
+#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
+#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
 
 logger = Log(__name__).getlog()
 
@@ -128,9 +128,7 @@ class DeepSpeech2Trainer(Trainer):
             num_conv_layers=config.model.num_conv_layers,
             num_rnn_layers=config.model.num_rnn_layers,
             rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
-
+            use_gru=config.model.use_gru)
         if self.parallel:
             model = paddle.DataParallel(model)
 
@@ -376,8 +374,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             num_conv_layers=config.model.num_conv_layers,
             num_rnn_layers=config.model.num_rnn_layers,
             rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
+            use_gru=config.model.use_gru)
         self.model = model
         logger.info("Setup model!")
 
diff --git a/deepspeech/models/ds2_online/conv.py b/deepspeech/models/ds2_online/conv.py
index 13c3d330..13c35ef2 100644
--- a/deepspeech/models/ds2_online/conv.py
+++ b/deepspeech/models/ds2_online/conv.py
@@ -19,12 +19,8 @@ from deepspeech.modules.subsampling import Conv2dSubsampling4
 
 
 class Conv2dSubsampling4Online(Conv2dSubsampling4):
-    def __init__(self,
-                 idim: int,
-                 odim: int,
-                 dropout_rate: float,
-                 pos_enc_class: nn.Layer=PositionalEncoding):
-        super().__init__(idim, odim, dropout_rate, pos_enc_class)
+    def __init__(self, idim: int, odim: int, dropout_rate: float):
+        super().__init__(idim, odim, dropout_rate, None)
         self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim
 
     def forward(self, x: paddle.Tensor,
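A quick arithmetic check of the `output_dim` line above (plain Python; `idim=161` matches the "161 for linear" feature size and `odim=32` the channel count the encoder passes when constructing this layer):

    # Two kernel-3, stride-2 convolutions shrink the feature axis as
    # D -> (D - 1) // 2 each time; the result is flattened across odim channels.
    idim, odim = 161, 32
    after_first = (idim - 1) // 2           # 80
    after_second = (after_first - 1) // 2   # 39
    print(after_second * odim)              # 1248 == ((idim - 1) // 2 - 1) // 2 * odim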
diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py
index 4fa6da0d..e9e81d5d 100644
--- a/deepspeech/models/ds2_online/deepspeech2.py
+++ b/deepspeech/models/ds2_online/deepspeech2.py
@@ -36,16 +36,17 @@ class CRNNEncoder(nn.Layer):
                  num_conv_layers=2,
                  num_rnn_layers=4,
                  rnn_size=1024,
+                 rnn_direction='forward',
                  num_fc_layers=2,
                  fc_layers_size_list=[512, 256],
-                 use_gru=False,
-                 share_rnn_weights=True):
+                 use_gru=False):
         super().__init__()
         self.rnn_size = rnn_size
         self.feat_size = feat_size  # 161 for linear
         self.dict_size = dict_size
         self.num_rnn_layers = num_rnn_layers
         self.num_fc_layers = num_fc_layers
+        self.rnn_direction = rnn_direction
         self.fc_layers_size_list = fc_layers_size_list
         self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)
 
@@ -54,7 +55,6 @@
         self.rnn = nn.LayerList()
         self.layernorm_list = nn.LayerList()
         self.fc_layers_list = nn.LayerList()
-        rnn_direction = 'forward'
         layernorm_size = rnn_size
         if use_gru == True:
@@ -99,21 +99,18 @@ class CRNNEncoder(nn.Layer):
     def output_size(self):
         return self.fc_layers_size_list[-1]
 
-    def forward(self, audio, audio_len):
+    def forward(self, x, x_lens):
         """Compute Encoder outputs
 
         Args:
-            audio (Tensor): [B, Tmax, D]
-            text (Tensor): [B, Umax]
-            audio_len (Tensor): [B]
-            text_len (Tensor): [B]
+            x (Tensor): [B, T_input, D]
+            x_lens (Tensor): [B]
         Returns:
-            x (Tensor): encoder outputs, [B, T, D]
+            x (Tensor): encoder outputs, [B, T_output, D]
             x_lens (Tensor): encoder length, [B]
+            rnn_final_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers
         """
         # [B, T, D]
-        x = audio
-        x_lens = audio_len
         # convolution group
         x, x_lens = self.conv(x, x_lens)
         # convert data from convolution feature map to sequence of vectors
@@ -123,16 +120,47 @@
         #x = x.reshape([0, 0, -1])  #[B, T, C*D]
 
         # remove padding part
-        x, output_state = self.rnn[0](x, None, x_lens)
+        init_state = None
+        rnn_final_state_list = []
+        x, final_state = self.rnn[0](x, init_state, x_lens)
+        rnn_final_state_list.append(final_state)
         x = self.layernorm_list[0](x)
         for i in range(1, self.num_rnn_layers):
-            x, output_state = self.rnn[i](x, output_state, x_lens)  #[B, T, D]
+            x, final_state = self.rnn[i](x, init_state, x_lens)  #[B, T, D]
+            rnn_final_state_list.append(final_state)
             x = self.layernorm_list[i](x)
 
         for i in range(self.num_fc_layers):
             x = self.fc_layers_list[i](x)
             x = F.relu(x)
-        return x, x_lens
+        return x, x_lens, rnn_final_state_list
+
+    def forward_chunk(self, x, x_lens, init_state_list):
+        """Compute encoder outputs for one chunk, resuming the per-layer
+        RNN states produced by the previous chunk.
+
+        Args:
+            x (Tensor): [B, feature_chunk_size, D]
+            x_lens (Tensor): [B]
+            init_state_list (list of Tensors): [num_directions, batch_size, hidden_size] * num_rnn_layers
+        Returns:
+            x (Tensor): encoder outputs, [B, chunk_size, D]
+            x_lens (Tensor): encoder length, [B]
+            rnn_final_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers
+        """
+        rnn_final_state_list = []
+        x, final_state = self.rnn[0](x, init_state_list[0], x_lens)
+        rnn_final_state_list.append(final_state)
+        x = self.layernorm_list[0](x)
+        for i in range(1, self.num_rnn_layers):
+            x, final_state = self.rnn[i](x, init_state_list[i],
+                                         x_lens)  #[B, T, D]
+            rnn_final_state_list.append(final_state)
+            x = self.layernorm_list[i](x)
+
+        for i in range(self.num_fc_layers):
+            x = self.fc_layers_list[i](x)
+            x = F.relu(x)
+        return x, x_lens, rnn_final_state_list
 
 
 class DeepSpeech2ModelOnline(nn.Layer):
@@ -156,9 +184,6 @@
     :type rnn_size: int
     :param use_gru: Use gru if set True. Use simple rnn if set False.
     :type use_gru: bool
-    :param share_rnn_weights: Whether to share input-hidden weights between
-                              forward and backward direction RNNs.
-                              It is only available when use_gru=False.
-    :type share_weights: bool
     :return: A tuple of an output unnormalized log probability layer (
              before softmax) and a ctc cost layer.
@@ -175,7 +200,6 @@
                 num_fc_layers=2,
                 fc_layers_size_list=[512, 256],
                 use_gru=True,  #Use gru if set True. Use simple rnn if set False.
-                share_rnn_weights=True  #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
             ))
         if config is not None:
             config.merge_from_other_cfg(default)
@@ -187,21 +211,21 @@
                  num_conv_layers=2,
                  num_rnn_layers=3,
                  rnn_size=1024,
+                 rnn_direction='forward',
                  num_fc_layers=2,
                  fc_layers_size_list=[512, 256],
-                 use_gru=False,
-                 share_rnn_weights=True):
+                 use_gru=False):
         super().__init__()
         self.encoder = CRNNEncoder(
             feat_size=feat_size,
             dict_size=dict_size,
             num_conv_layers=num_conv_layers,
             num_rnn_layers=num_rnn_layers,
+            rnn_direction=rnn_direction,
             num_fc_layers=num_fc_layers,
             fc_layers_size_list=fc_layers_size_list,
             rnn_size=rnn_size,
-            use_gru=use_gru,
-            share_rnn_weights=share_rnn_weights)
+            use_gru=use_gru)
         assert (self.encoder.output_size == fc_layers_size_list[-1])
 
         self.decoder = CTCDecoder(
@@ -224,7 +248,7 @@
         Returns:
             loss (Tenosr): [1]
         """
-        eouts, eouts_len = self.encoder(audio, audio_len)
+        eouts, eouts_len, rnn_final_state_list = self.encoder(audio, audio_len)
         loss = self.decoder(eouts, eouts_len, text, text_len)
         return loss
 
@@ -271,10 +295,10 @@
             num_conv_layers=config.model.num_conv_layers,
             num_rnn_layers=config.model.num_rnn_layers,
             rnn_size=config.model.rnn_layer_size,
+            rnn_direction=config.model.rnn_direction,
             num_fc_layers=config.model.num_fc_layers,
             fc_layers_size_list=config.model.fc_layers_size_list,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
+            use_gru=config.model.use_gru)
         infos = Checkpoint().load_parameters(
             model, checkpoint_path=checkpoint_path)
         logger.info(f"checkpoint info: {infos}")
 
@@ -289,20 +313,20 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
                  num_conv_layers=2,
                  num_rnn_layers=3,
                  rnn_size=1024,
+                 rnn_direction='forward',
                  num_fc_layers=2,
                  fc_layers_size_list=[512, 256],
-                 use_gru=False,
-                 share_rnn_weights=True):
+                 use_gru=False):
         super().__init__(
             feat_size=feat_size,
             dict_size=dict_size,
             num_conv_layers=num_conv_layers,
             num_rnn_layers=num_rnn_layers,
             rnn_size=rnn_size,
+            rnn_direction=rnn_direction,
             num_fc_layers=num_fc_layers,
             fc_layers_size_list=fc_layers_size_list,
-            use_gru=use_gru,
-            share_rnn_weights=share_rnn_weights)
+            use_gru=use_gru)
 
     def forward(self, audio, audio_len):
         """export model function
@@ -314,6 +338,26 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
 
         Returns:
             probs: probs after softmax
         """
-        eouts, eouts_len = self.encoder(audio, audio_len)
+        eouts, eouts_len, rnn_final_state_list = self.encoder(audio, audio_len)
         probs = self.decoder.softmax(eouts)
         return probs
+
+    def forward_chunk(self, eouts_chunk_prefix, eouts_chunk_lens_prefix,
+                      audio_chunk, audio_chunk_len, init_state_list):
+        """export model function for streaming: decode one chunk and append
+        its encoder outputs to the accumulated prefix.
+
+        Args:
+            eouts_chunk_prefix (Tensor): accumulated encoder outputs, [B, T_prefix, D]
+            eouts_chunk_lens_prefix (Tensor): [B]
+            audio_chunk (Tensor): [B, T_chunk, D]
+            audio_chunk_len (Tensor): [B]
+            init_state_list (list of Tensors): per-layer RNN states from the previous chunk
+
+        Returns:
+            probs_chunk: probs after softmax over the extended prefix
+        """
+        eouts_chunk, eouts_chunk_lens, rnn_final_state_list = self.encoder.forward_chunk(
+            audio_chunk, audio_chunk_len, init_state_list)
+        eouts_chunk_new_prefix = paddle.concat(
+            [eouts_chunk_prefix, eouts_chunk], axis=1)
+        eouts_chunk_lens_new_prefix = paddle.add(eouts_chunk_lens_prefix,
+                                                 eouts_chunk_lens)
+        probs_chunk = self.decoder.softmax(eouts_chunk_new_prefix)
+        return probs_chunk, eouts_chunk_new_prefix, eouts_chunk_lens_new_prefix, rnn_final_state_list
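A hypothetical driver loop for the chunked interface above. Only the `forward_chunk` names and their return tuples come from this diff; the chunk size, the empty prefix, and the None-initialized state list are illustrative assumptions. Note that `CRNNEncoder.forward_chunk` in this revision does not run the conv frontend, so `feats` here must already match the RNN input dimension:

    import paddle

    def stream_decode(model, feats, feats_len, chunk_size=64):
        """Feed feats [B, T, D] through model.forward_chunk chunk by chunk."""
        B = feats.shape[0]
        eouts_prefix = paddle.zeros([B, 0, model.encoder.output_size])
        eouts_lens_prefix = paddle.zeros([B], dtype=feats_len.dtype)
        # None makes each RNN layer zero-initialize its state on the first chunk
        state_list = [None] * model.encoder.num_rnn_layers
        probs = None
        for start in range(0, feats.shape[1], chunk_size):
            chunk = feats[:, start:start + chunk_size, :]
            chunk_len = paddle.clip(feats_len - start, min=0, max=chunk_size)
            probs, eouts_prefix, eouts_lens_prefix, state_list = model.forward_chunk(
                eouts_prefix, eouts_lens_prefix, chunk, chunk_len, state_list)
        return probs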
diff --git a/deepspeech/models/ds2_online/rnn.py b/deepspeech/models/ds2_online/rnn.py
deleted file mode 100644
index 01b55c4a..00000000
--- a/deepspeech/models/ds2_online/rnn.py
+++ /dev/null
@@ -1,314 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-
-import paddle
-from paddle import nn
-from paddle.nn import functional as F
-from paddle.nn import initializer as I
-
-from deepspeech.modules.activation import brelu
-from deepspeech.modules.mask import make_non_pad_mask
-from deepspeech.utils.log import Log
-
-logger = Log(__name__).getlog()
-
-__all__ = ['RNNStack']
-
-
-class RNNCell(nn.RNNCellBase):
-    r"""
-    Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
-    computes the outputs and updates states.
-    The formula used is as follows:
-    .. math::
-        h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
-        y_{t} & = h_{t}
-
-    where :math:`act` is for :attr:`activation`.
-    """
-
-    def __init__(self,
-                 hidden_size: int,
-                 activation="tanh",
-                 weight_ih_attr=None,
-                 weight_hh_attr=None,
-                 bias_ih_attr=None,
-                 bias_hh_attr=None,
-                 name=None):
-        super().__init__()
-        std = 1.0 / math.sqrt(hidden_size)
-        self.weight_hh = self.create_parameter(
-            (hidden_size, hidden_size),
-            weight_hh_attr,
-            default_initializer=I.Uniform(-std, std))
-        self.bias_ih = None
-        self.bias_hh = self.create_parameter(
-            (hidden_size, ),
-            bias_hh_attr,
-            is_bias=True,
-            default_initializer=I.Uniform(-std, std))
-
-        self.hidden_size = hidden_size
-        if activation not in ["tanh", "relu", "brelu"]:
-            raise ValueError(
-                "activation for SimpleRNNCell should be tanh or relu, "
-                "but get {}".format(activation))
-        self.activation = activation
-        self._activation_fn = paddle.tanh \
-            if activation == "tanh" \
-            else F.relu
-        if activation == 'brelu':
-            self._activation_fn = brelu
-
-    def forward(self, inputs, states=None):
-        if states is None:
-            states = self.get_initial_states(inputs, self.state_shape)
-        pre_h = states
-        i2h = inputs
-        if self.bias_ih is not None:
-            i2h += self.bias_ih
-        h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
-        if self.bias_hh is not None:
-            h2h += self.bias_hh
-        h = self._activation_fn(i2h + h2h)
-        return h, h
-
-    @property
-    def state_shape(self):
-        return (self.hidden_size, )
-
-
-class GRUCell(nn.RNNCellBase):
-    r"""
-    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
-    it computes the outputs and updates states.
-    The formula for GRU used is as follows:
-    .. math::
-        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
-        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
-        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
-        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
-        y_{t} & = h_{t}
-
-    where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise
-    multiplication operator.
-    """
-
-    def __init__(self,
-                 input_size: int,
-                 hidden_size: int,
-                 weight_ih_attr=None,
-                 weight_hh_attr=None,
-                 bias_ih_attr=None,
-                 bias_hh_attr=None,
-                 name=None):
-        super().__init__()
-        std = 1.0 / math.sqrt(hidden_size)
-        self.weight_hh = self.create_parameter(
-            (3 * hidden_size, hidden_size),
-            weight_hh_attr,
-            default_initializer=I.Uniform(-std, std))
-        self.bias_ih = None
-        self.bias_hh = self.create_parameter(
-            (3 * hidden_size, ),
-            bias_hh_attr,
-            is_bias=True,
-            default_initializer=I.Uniform(-std, std))
-
-        self.hidden_size = hidden_size
-        self.input_size = input_size
-        self._gate_activation = F.sigmoid
-        self._activation = paddle.tanh
-
-    def forward(self, inputs, states=None):
-        if states is None:
-            states = self.get_initial_states(inputs, self.state_shape)
-
-        pre_hidden = states
-        x_gates = inputs
-        if self.bias_ih is not None:
-            x_gates = x_gates + self.bias_ih
-        h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
-        if self.bias_hh is not None:
-            h_gates = h_gates + self.bias_hh
-
-        x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
-        h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)
-
-        r = self._gate_activation(x_r + h_r)
-        z = self._gate_activation(x_z + h_z)
-        c = self._activation(x_c + r * h_c)  # apply reset gate after mm
-        h = (pre_hidden - c) * z + c
-        # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
-
-        return h, h
-
-    @property
-    def state_shape(self):
-        r"""
-        The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
-        size would be automatically inserted into shape). The shape corresponds
-        to the shape of :math:`h_{t-1}`.
-        """
-        return (self.hidden_size, )
-
-
-class BiRNNWithBN(nn.Layer):
-    """Bidirectonal simple rnn layer with sequence-wise batch normalization.
-    The batch normalization is only performed on input-state weights.
-
-    :param size: Dimension of RNN cells.
-    :type size: int
-    :param share_weights: Whether to share input-hidden weights between
-                          forward and backward directional RNNs.
-    :type share_weights: bool
-    :return: Bidirectional simple rnn layer.
-    :rtype: Variable
-    """
-
-    def __init__(self, i_size: int, h_size: int, share_weights: bool):
-        super().__init__()
-        self.share_weights = share_weights
-        if self.share_weights:
-            #input-hidden weights shared between bi-directional rnn.
-            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
-            # batch norm is only performed on input-state projection
-            self.fw_bn = nn.BatchNorm1D(
-                h_size, bias_attr=None, data_format='NLC')
-            self.bw_fc = self.fw_fc
-            self.bw_bn = self.fw_bn
-        else:
-            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
-            self.fw_bn = nn.BatchNorm1D(
-                h_size, bias_attr=None, data_format='NLC')
-            self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
-            self.bw_bn = nn.BatchNorm1D(
-                h_size, bias_attr=None, data_format='NLC')
-
-        self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
-        self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
-        self.fw_rnn = nn.RNN(
-            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
-        self.bw_rnn = nn.RNN(
-            self.fw_cell, is_reverse=True, time_major=False)  #[B, T, D]
-
-    def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
-        # x, shape [B, T, D]
-        fw_x = self.fw_bn(self.fw_fc(x))
-        bw_x = self.bw_bn(self.bw_fc(x))
-        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
-        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
-        x = paddle.concat([fw_x, bw_x], axis=-1)
-        return x, x_len
-
-
-class BiGRUWithBN(nn.Layer):
-    """Bidirectonal gru layer with sequence-wise batch normalization.
-    The batch normalization is only performed on input-state weights.
-
-    :param name: Name of the layer.
-    :type name: string
-    :param input: Input layer.
-    :type input: Variable
-    :param size: Dimension of GRU cells.
-    :type size: int
-    :param act: Activation type.
-    :type act: string
-    :return: Bidirectional GRU layer.
-    :rtype: Variable
-    """
-
-    def __init__(self, i_size: int, h_size: int):
-        super().__init__()
-        hidden_size = h_size * 3
-
-        self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
-        self.fw_bn = nn.BatchNorm1D(
-            hidden_size, bias_attr=None, data_format='NLC')
-        self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
-        self.bw_bn = nn.BatchNorm1D(
-            hidden_size, bias_attr=None, data_format='NLC')
-
-        self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
-        self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
-        self.fw_rnn = nn.RNN(
-            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
-        self.bw_rnn = nn.RNN(
-            self.fw_cell, is_reverse=True, time_major=False)  #[B, T, D]
-
-    def forward(self, x, x_len):
-        # x, shape [B, T, D]
-        fw_x = self.fw_bn(self.fw_fc(x))
-        bw_x = self.bw_bn(self.bw_fc(x))
-        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
-        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
-        x = paddle.concat([fw_x, bw_x], axis=-1)
-        return x, x_len
-
-
-class RNNStack(nn.Layer):
-    """RNN group with stacked bidirectional simple RNN or GRU layers.
-
-    :param input: Input layer.
-    :type input: Variable
-    :param size: Dimension of RNN cells in each layer.
-    :type size: int
-    :param num_stacks: Number of stacked rnn layers.
-    :type num_stacks: int
-    :param use_gru: Use gru if set True. Use simple rnn if set False.
-    :type use_gru: bool
-    :param share_rnn_weights: Whether to share input-hidden weights between
-                              forward and backward directional RNNs.
-                              It is only available when use_gru=False.
-    :type share_weights: bool
-    :return: Output layer of the RNN group.
-    :rtype: Variable
-    """
-
-    def __init__(self,
-                 i_size: int,
-                 h_size: int,
-                 num_stacks: int,
-                 use_gru: bool,
-                 share_rnn_weights: bool):
-        super().__init__()
-        rnn_stacks = []
-        for i in range(num_stacks):
-            if use_gru:
-                #default:GRU using tanh
-                rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
-            else:
-                rnn_stacks.append(
-                    BiRNNWithBN(
-                        i_size=i_size,
-                        h_size=h_size,
-                        share_weights=share_rnn_weights))
-            i_size = h_size * 2
-
-        self.rnn_stacks = nn.ModuleList(rnn_stacks)
-
-    def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
-        """
-        x: shape [B, T, D]
-        x_len: shpae [B]
-        """
-        for i, rnn in enumerate(self.rnn_stacks):
-            x, x_len = rnn(x, x_len)
-            masks = make_non_pad_mask(x_len)  #[B, T]
-            masks = masks.unsqueeze(-1)  # [B, T, 1]
-            # TODO(Hui Zhang): not support bool multiply
-            masks = masks.astype(x.dtype)
-            x = x.multiply(masks)
-        return x, x_len
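For context on the state plumbing that replaces the deleted bidirectional stack: a backward pass needs the whole utterance, so the online encoder keeps forward-direction recurrence whose final states can seed the next chunk. Below is a minimal sketch of that call pattern with Paddle's built-in `nn.GRU`; the sizes are assumptions, since the actual layer construction in `CRNNEncoder` falls outside the hunks shown:

    import paddle
    import paddle.nn as nn

    # Forward-direction GRU: final_states has shape
    # [num_layers * num_directions, B, hidden_size] and can be fed back in.
    gru = nn.GRU(input_size=1248, hidden_size=1024, direction='forward',
                 time_major=False)

    x = paddle.randn([4, 50, 1248])              # [B, T, D], D as in the conv check above
    x_lens = paddle.to_tensor([50, 42, 30, 17])  # [B]

    out, state = gru(x, None, sequence_length=x_lens)     # zero initial state
    out2, state2 = gru(x, state, sequence_length=x_lens)  # resume from `state`
    print(state.shape)                                    # [1, 4, 1024]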
diff --git a/deepspeech/modules/subsampling.py b/deepspeech/modules/subsampling.py
index 5aa2fd8e..40fa7b00 100644
--- a/deepspeech/modules/subsampling.py
+++ b/deepspeech/modules/subsampling.py
@@ -92,7 +92,7 @@ class Conv2dSubsampling4(BaseSubsampling):
                  dropout_rate: float,
                  pos_enc_class: nn.Layer=PositionalEncoding):
         """Construct an Conv2dSubsampling4 object.
-        
+
         Args:
             idim (int): Input dimension.
             odim (int): Output dimension.
@@ -143,7 +143,7 @@ class Conv2dSubsampling6(BaseSubsampling):
                  dropout_rate: float,
                  pos_enc_class: nn.Layer=PositionalEncoding):
         """Construct an Conv2dSubsampling6 object.
-        
+
         Args:
             idim (int): Input dimension.
             odim (int): Output dimension.
@@ -196,7 +196,7 @@ class Conv2dSubsampling8(BaseSubsampling):
                  dropout_rate: float,
                  pos_enc_class: nn.Layer=PositionalEncoding):
         """Construct an Conv2dSubsampling8 object.
-        
+
         Args:
             idim (int): Input dimension.
             odim (int): Output dimension.
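The `Conv2dSubsampling*` hunks above are whitespace-only, but `Conv2dSubsampling4` is the source of the 4x time-axis reduction mirrored by `Conv2dSubsampling4Online`. A small worked check of the length bookkeeping (assuming the usual two kernel-3, stride-2 convolutions in this class):

    # Each kernel-3, stride-2, unpadded Conv2D maps T -> (T - 1) // 2,
    # so applying it twice gives the encoder's output length.
    def subsampled_len(T: int) -> int:
        return ((T - 1) // 2 - 1) // 2

    for T in (16, 100, 200):
        print(T, '->', subsampled_len(T))  # 16 -> 3, 100 -> 24, 200 -> 49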