complete the encoder of ds_online

pull/735/head
huangyuxin 3 years ago
parent 269eecb3be
commit 4f392e28b1

@ -29,6 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2 import DeepSpeech2InferModel
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate
@ -120,14 +122,24 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self):
config = self.config
model = DeepSpeech2Model(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
if (config.model.apply_online == False):
model = DeepSpeech2Model(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
if self.parallel:
model = paddle.DataParallel(model)
@ -329,8 +341,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
exit(-1)
def export(self):
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
if self.config.model.apply_online == False:
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
else:
infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size
static_model = paddle.jit.to_static(
@ -367,14 +384,25 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def setup_model(self):
config = self.config
model = DeepSpeech2Model(
feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
if config.model.apply_online == False:
model = DeepSpeech2Model(
feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
self.model = model
logger.info("Setup model!")

@ -1,7 +0,0 @@
from .deepspeech2 import DeepSpeech2Model
from .deepspeech2 import DeepSpeech2InferModel
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']

@ -1,172 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
from paddle.nn import functional as F
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['ConvStack', "conv_output_size"]
def conv_output_size(I, F, P, S):
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# Output size after Conv:
# By noting I the length of the input volume size,
# F the length of the filter,
# P the amount of zero padding,
# S the stride,
# then the output size O of the feature map along that dimension is given by:
# O = (I - F + Pstart + Pend) // S + 1
# When Pstart == Pend == P, we can replace Pstart + Pend by 2P.
# When Pstart == Pend == 0
# O = (I - F - S) // S
# https://iq.opengenus.org/output-size-of-convolution/
# Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
# Output width = (Output width + padding width right + padding width left - kernel width) / (stride width) + 1
return (I - F + 2 * P - S) // S
# receptive field calculator
# https://fomoro.com/research/article/receptive-field-calculator
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# https://distill.pub/2019/computing-receptive-fields/
# Rl-1 = Sl * Rl + (Kl - Sl)
class ConvBn(nn.Layer):
"""Convolution layer with batch normalization.
:param kernel_size: The x dimension of a filter kernel. Or input a tuple for
two image dimension.
:type kernel_size: int|tuple|list
:param num_channels_in: Number of input channels.
:type num_channels_in: int
:param num_channels_out: Number of output channels.
:type num_channels_out: int
:param stride: The x dimension of the stride. Or input a tuple for two
image dimension.
:type stride: int|tuple|list
:param padding: The x dimension of the padding. Or input a tuple for two
image dimension.
:type padding: int|tuple|list
:param act: Activation type, relu|brelu
:type act: string
:return: Batch norm layer after convolution layer.
:rtype: Variable
"""
def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
padding, act):
super().__init__()
assert len(kernel_size) == 2
assert len(stride) == 2
assert len(padding) == 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.conv = nn.Conv2D(
num_channels_in,
num_channels_out,
kernel_size=kernel_size,
stride=stride,
padding=padding,
weight_attr=None,
bias_attr=False,
data_format='NCHW')
self.bn = nn.BatchNorm2D(
num_channels_out,
weight_attr=None,
bias_attr=None,
data_format='NCHW')
self.act = F.relu if act == 'relu' else brelu
def forward(self, x, x_len):
"""
x(Tensor): audio, shape [B, C, D, T]
"""
x = self.conv(x)
x = self.bn(x)
x = self.act(x)
x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
) // self.stride[1] + 1
# reset padding part to 0
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# TODO(Hui Zhang): not support bool multiply
# masks = masks.type_as(x)
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
class ConvStack(nn.Layer):
"""Convolution group with stacked convolution layers.
:param feat_size: audio feature dim.
:type feat_size: int
:param num_stacks: Number of stacked convolution layers.
:type num_stacks: int
"""
def __init__(self, feat_size, num_stacks):
super().__init__()
self.feat_size = feat_size # D
self.num_stacks = num_stacks
self.conv_in = ConvBn(
num_channels_in=1,
num_channels_out=32,
kernel_size=(41, 11), #[D, T]
stride=(2, 3),
padding=(20, 5),
act='brelu')
out_channel = 32
convs = [
ConvBn(
num_channels_in=32,
num_channels_out=out_channel,
kernel_size=(21, 11),
stride=(2, 1),
padding=(10, 5),
act='brelu') for i in range(num_stacks - 1)
]
self.conv_stack = nn.LayerList(convs)
# conv output feat_dim
output_height = (feat_size - 1) // 2 + 1
for i in range(self.num_stacks - 1):
output_height = (output_height - 1) // 2 + 1
self.output_height = out_channel * output_height
def forward(self, x, x_len):
"""
x: shape [B, C, D, T]
x_len : shape [B]
"""
x, x_len = self.conv_in(x, x_len)
for i, conv in enumerate(self.conv_stack):
x, x_len = conv(x, x_len)
return x, x_len

@ -1,314 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['RNNStack']
class RNNCell(nn.RNNCellBase):
r"""
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
computes the outputs and updates states.
The formula used is as follows:
.. math::
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`act` is for :attr:`activation`.
"""
def __init__(self,
hidden_size: int,
activation="tanh",
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
if activation not in ["tanh", "relu", "brelu"]:
raise ValueError(
"activation for SimpleRNNCell should be tanh or relu, "
"but get {}".format(activation))
self.activation = activation
self._activation_fn = paddle.tanh \
if activation == "tanh" \
else F.relu
if activation == 'brelu':
self._activation_fn = brelu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_h = states
i2h = inputs
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self._activation_fn(i2h + h2h)
return h, h
@property
def state_shape(self):
return (self.hidden_size, )
class GRUCell(nn.RNNCellBase):
r"""
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
it computes the outputs and updates states.
The formula for GRU used is as follows:
.. math::
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise
multiplication operator.
"""
def __init__(self,
input_size: int,
hidden_size: int,
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(3 * hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(3 * hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
self.input_size = input_size
self._gate_activation = F.sigmoid
self._activation = paddle.tanh
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_hidden = states
x_gates = inputs
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h_gates = h_gates + self.bias_hh
x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)
r = self._gate_activation(x_r + h_r)
z = self._gate_activation(x_z + h_z)
c = self._activation(x_c + r * h_c) # apply reset gate after mm
h = (pre_hidden - c) * z + c
# https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
return h, h
@property
def state_shape(self):
r"""
The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
size would be automatically inserted into shape). The shape corresponds
to the shape of :math:`h_{t-1}`.
"""
return (self.hidden_size, )
class BiRNNWithBN(nn.Layer):
"""Bidirectonal simple rnn layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param size: Dimension of RNN cells.
:type size: int
:param share_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
:type share_weights: bool
:return: Bidirectional simple rnn layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int, share_weights: bool):
super().__init__()
self.share_weights = share_weights
if self.share_weights:
#input-hidden weights shared between bi-directional rnn.
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
# batch norm is only performed on input-state projection
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = self.fw_fc
self.bw_bn = self.fw_bn
else:
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class BiGRUWithBN(nn.Layer):
"""Bidirectonal gru layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param name: Name of the layer.
:type name: string
:param input: Input layer.
:type input: Variable
:param size: Dimension of GRU cells.
:type size: int
:param act: Activation type.
:type act: string
:return: Bidirectional GRU layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int):
super().__init__()
hidden_size = h_size * 3
self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x, x_len):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class RNNStack(nn.Layer):
"""RNN group with stacked bidirectional simple RNN or GRU layers.
:param input: Input layer.
:type input: Variable
:param size: Dimension of RNN cells in each layer.
:type size: int
:param num_stacks: Number of stacked rnn layers.
:type num_stacks: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: Output layer of the RNN group.
:rtype: Variable
"""
def __init__(self,
i_size: int,
h_size: int,
num_stacks: int,
use_gru: bool,
share_rnn_weights: bool):
super().__init__()
rnn_stacks = []
for i in range(num_stacks):
if use_gru:
#default:GRU using tanh
rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
else:
rnn_stacks.append(
BiRNNWithBN(
i_size=i_size,
h_size=h_size,
share_weights=share_rnn_weights))
i_size = h_size * 2
self.rnn_stacks = nn.ModuleList(rnn_stacks)
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
"""
x: shape [B, T, D]
x_len: shpae [B]
"""
for i, rnn in enumerate(self.rnn_stacks):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len

@ -1,7 +1,7 @@
from .deepspeech2 import DeepSpeech2Model
from .deepspeech2 import DeepSpeech2InferModel
from .deepspeech2 import DeepSpeech2ModelOnline
from .deepspeech2 import DeepSpeech2InferModelOnline
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']

@ -16,23 +16,27 @@ from typing import Optional
import paddle
from paddle import nn
import paddle.nn.functional as F
from yacs.config import CfgNode
from deepspeech.models.ds2.conv import ConvStack
from deepspeech.models.ds2_online.conv import ConvStack
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.models.ds2.rnn import RNNStack
from deepspeech.models.ds2_online.rnn import RNNStack
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
from paddle.nn import LSTM, GRU
from paddle.nn import LSTM, GRU, Linear
from paddle.nn import LayerNorm
from paddle.nn import LayerList
from paddle.fluid.layers import fc
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferMode']
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline']
class CRNNEncoder(nn.Layer):
@ -40,31 +44,28 @@ class CRNNEncoder(nn.Layer):
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
num_rnn_layers=4,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
share_rnn_weights=True,
apply_online=True):
share_rnn_weights=True):
super().__init__()
self.rnn_size = rnn_size
self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size
self.num_rnn_layers = num_rnn_layers
self.apply_online = apply_online
self.num_fc_layers = num_fc_layers
self.fc_layers_size_list = fc_layers_size_list
self.conv = ConvStack(feat_size, num_conv_layers)
i_size = self.conv.output_height # H after conv stack
self.rnn = LayerList()
self.layernorm_list = LayerList()
if (apply_online == True):
rnn_direction = 'forward'
layernorm_size = rnn_size
else:
rnn_direction = 'bidirect'
layernorm_size = 2 * rnn_size
self.fc_layers_list = LayerList()
rnn_direction = 'forward'
layernorm_size = rnn_size
if use_gru == True:
self.rnn.append(GRU(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
@ -78,20 +79,14 @@ class CRNNEncoder(nn.Layer):
for i in range(1, num_rnn_layers):
self.rnn.append(LSTM(input_size=layernorm_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
self.layernorm_list.append(LayerNorm(layernorm_size))
"""
self.rnn = RNNStack(
i_size=i_size,
h_size=rnn_size,
num_stacks=num_rnn_layers,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
"""
fc_input_size = layernorm_size
for i in range(self.num_fc_layers):
self.fc_layers_list.append(nn.Linear(fc_input_size, fc_layers_size_list[i]))
fc_input_size = fc_layers_size_list[i]
@property
def output_size(self):
if (self.apply_online == True):
return self.rnn_size
else:
return 2 * self.rnn_size
return self.fc_layers_size_list[-1]
def forward(self, audio, audio_len):
"""Compute Encoder outputs
@ -126,14 +121,15 @@ class CRNNEncoder(nn.Layer):
for i in range(1, self.num_rnn_layers):
x, output_state = self.rnn[i](x, output_state, x_lens) #[B, T, D]
x = self.layernorm_list[i](x)
"""
x, x_lens = self.rnn(x, x_lens)
"""
for i in range(self.num_fc_layers):
x = self.fc_layers_list[i](x)
x = F.relu(x)
return x, x_lens
class DeepSpeech2Model(nn.Layer):
"""The DeepSpeech2 network structure.
class DeepSpeech2ModelOnline(nn.Layer):
"""The DeepSpeech2 network structure for online.
:param audio_data: Audio spectrogram data layer.
:type audio_data: Variable
@ -167,8 +163,10 @@ class DeepSpeech2Model(nn.Layer):
default = CfgNode(
dict(
num_conv_layers=2, #Number of stacking convolution layers.
num_rnn_layers=3, #Number of stacking RNN layers.
num_rnn_layers=4, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
num_fc_layers=2,
fc_layers_size_list = [512,256],
use_gru=True, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
))
@ -182,23 +180,22 @@ class DeepSpeech2Model(nn.Layer):
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
share_rnn_weights=True,
apply_online = True):
share_rnn_weights=True):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights,
apply_online=apply_online)
if (apply_online == True):
assert (self.encoder.output_size == rnn_size)
else:
assert (self.encoder.output_size == 2 * rnn_size)
share_rnn_weights=share_rnn_weights)
assert (self.encoder.output_size == fc_layers_size_list[-1])
self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab
@ -267,9 +264,10 @@ class DeepSpeech2Model(nn.Layer):
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights,
apply_online=config.model.apply_online)
share_rnn_weights=config.model.share_rnn_weights)
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
@ -277,25 +275,27 @@ class DeepSpeech2Model(nn.Layer):
return model
class DeepSpeech2InferModel(DeepSpeech2Model):
class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
share_rnn_weights=True,
apply_online = True):
share_rnn_weights=True):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights,
apply_online=apply_online)
share_rnn_weights=share_rnn_weights)
def forward(self, audio, audio_len):
"""export model function

@ -2,7 +2,7 @@
set -e
source path.sh
gpus=0,1,2,3
gpus=2,3,4,5
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml

@ -16,8 +16,8 @@ import unittest
import numpy as np
import paddle
from deepspeech.models.deepspeech2 import DeepSpeech2Model
#from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline as DeepSpeech2Model
class TestDeepSpeech2Model(unittest.TestCase):
def setUp(self):

Loading…
Cancel
Save