fix model and ctc

pull/522/head
Hui Zhang 5 years ago
parent a94fc3f6ed
commit 54b13722f5

@@ -26,6 +26,8 @@ from decoders.swig_wrapper import Scorer
 from decoders.swig_wrapper import ctc_greedy_decoder
 from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 
+logger = logging.getLogger(__name__)
+
 __all__ = ['DeepSpeech2', 'DeepSpeech2Loss']
@@ -36,9 +38,9 @@ def ctc_loss(log_probs,
              blank=0,
              reduction='mean',
              norm_by_times=True):
-    #print("my ctc loss with norm by times")
-    loss_out = paddle.fluid.layers.warpctc(log_probs, labels, blank, norm_by_times,
-                                           input_lengths, label_lengths)
+    #logger.info("my ctc loss with norm by times")
+    loss_out = paddle.fluid.layers.warpctc(
+        log_probs, labels, blank, norm_by_times, input_lengths, label_lengths)
     loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
     assert reduction in ['mean', 'sum', 'none']
@@ -48,6 +50,7 @@ def ctc_loss(log_probs,
         loss_out = paddle.sum(loss_out)
     return loss_out
 
+F.ctc_loss = ctc_loss
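
The hunk above monkey-patches paddle.nn.functional's ctc_loss with a warpctc-backed version that can normalize by sequence length. A minimal usage sketch, assuming the patched function keeps the argument order of Paddle's stock F.ctc_loss (log_probs shaped [max_T, batch, classes], integer length tensors); the shapes and sizes here are illustrative, not from this commit:

import paddle
import paddle.nn.functional as F  # after the patch, F.ctc_loss is the warpctc wrapper

log_probs = paddle.rand([50, 4, 29])           # [max_T, batch, num_classes + blank]
labels = paddle.randint(1, 29, [4, 10])        # blank id 0 is excluded from labels
input_lengths = paddle.full([4], 50, 'int64')  # frames per utterance
label_lengths = paddle.full([4], 10, 'int64')  # label tokens per utterance

loss = F.ctc_loss(log_probs, labels, input_lengths, label_lengths,
                  blank=0, reduction='mean', norm_by_times=True)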
@@ -216,11 +219,12 @@ class RNNCell(nn.RNNCellBase):
             (hidden_size, hidden_size),
             weight_hh_attr,
             default_initializer=I.Uniform(-std, std))
-        self.bias_ih = self.create_parameter(
-            (hidden_size, ),
-            bias_ih_attr,
-            is_bias=True,
-            default_initializer=I.Uniform(-std, std))
+        # self.bias_ih = self.create_parameter(
+        #     (hidden_size, ),
+        #     bias_ih_attr,
+        #     is_bias=True,
+        #     default_initializer=I.Uniform(-std, std))
+        self.bias_ih = None
         self.bias_hh = self.create_parameter(
             (hidden_size, ),
             bias_hh_attr,
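
Setting bias_ih to None here (and in GRUCellShare below) drops the input-hidden bias: in this model the input projection is computed outside the cell by a bias-free Linear feeding into BatchNorm, and BN's learnable shift (beta) already supplies the additive term, so a second bias is redundant. A sketch of that redundancy, with hypothetical layer sizes:

import paddle
import paddle.nn as nn

proj = nn.Linear(160, 1024, bias_attr=False)    # bias dropped from the projection...
norm = nn.BatchNorm1D(1024, data_format='NLC')  # ...BN's beta provides the shift instead

x = paddle.rand([8, 50, 160])  # [B, T, D]
y = norm(proj(x))              # same expressive power as Linear-with-bias + BN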
@@ -287,11 +291,12 @@ class GRUCellShare(nn.RNNCellBase):
             (3 * hidden_size, hidden_size),
             weight_hh_attr,
             default_initializer=I.Uniform(-std, std))
-        self.bias_ih = self.create_parameter(
-            (3 * hidden_size, ),
-            bias_ih_attr,
-            is_bias=True,
-            default_initializer=I.Uniform(-std, std))
+        # self.bias_ih = self.create_parameter(
+        #     (3 * hidden_size, ),
+        #     bias_ih_attr,
+        #     is_bias=True,
+        #     default_initializer=I.Uniform(-std, std))
+        self.bias_ih = None
         self.bias_hh = self.create_parameter(
             (3 * hidden_size, ),
             bias_hh_attr,
@@ -301,7 +306,8 @@ class GRUCellShare(nn.RNNCellBase):
         self.hidden_size = hidden_size
         self.input_size = input_size
         self._gate_activation = F.sigmoid
-        self._activation = paddle.tanh
+        #self._activation = paddle.tanh
+        self._activation = paddle.relu
 
     def forward(self, inputs, states=None):
         if states is None:
@@ -322,6 +328,8 @@ class GRUCellShare(nn.RNNCellBase):
         z = self._gate_activation(x_z + h_z)
         c = self._activation(x_c + r * h_c)  # apply reset gate after mm
         h = (pre_hidden - c) * z + c
+        # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
+        #h = (1-z) * pre_hidden + z * c
 
         return h, h
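
The kept update h = (pre_hidden - c) * z + c expands to z * pre_hidden + (1 - z) * c, i.e. the same convex combination as the commented dynamic_gru form but with the roles of z and (1 - z) swapped. A quick numpy check of that identity:

import numpy as np

rng = np.random.default_rng(0)
z = rng.random(5)           # update gate values in (0, 1)
pre_hidden = rng.random(5)  # previous hidden state
c = rng.random(5)           # candidate state

kept = (pre_hidden - c) * z + c
expanded = z * pre_hidden + (1 - z) * c
assert np.allclose(kept, expanded)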
@@ -353,24 +361,24 @@ class BiRNNWithBN(nn.Layer):
     def __init__(self, i_size, h_size, share_weights):
         super().__init__()
         self.share_weights = share_weights
         self.pad_value = paddle.to_tensor(np.array([0.0], dtype=np.float32))
 
         if self.share_weights:
             #input-hidden weights shared between bi-directional rnn.
-            self.fw_fc = nn.Linear(i_size, h_size)
+            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
             # batch norm is only performed on input-state projection
-            self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
+            self.fw_bn = nn.BatchNorm1D(
+                h_size, bias_attr=None, data_format='NLC')
             self.bw_fc = self.fw_fc
             self.bw_bn = self.fw_bn
         else:
-            self.fw_fc = nn.Linear(i_size, h_size)
-            self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
-            self.bw_fc = nn.Linear(i_size, h_size)
-            self.bw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
-        self.fw_cell = RNNCell(hidden_size=h_size, activation='relu')
-        self.bw_cell = RNNCell(
-            hidden_size=h_size,
-            activation='relu', )
+            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
+            self.fw_bn = nn.BatchNorm1D(
+                h_size, bias_attr=None, data_format='NLC')
+            self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
+            self.bw_bn = nn.BatchNorm1D(
+                h_size, bias_attr=None, data_format='NLC')
+        self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
+        self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
 
         self.fw_rnn = nn.RNN(
             self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
         self.bw_rnn = nn.RNN(
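
The cells also switch from 'relu' to 'brelu', Paddle's bounded ReLU, which clips the activation into [t_min, t_max] as min(max(x, t_min), t_max); the upper bound keeps recurrent activations from growing without limit over long sequences. A numpy sketch with illustrative bounds (fluid's brelu defaults to t_min=0, t_max=24; the bounds this commit actually uses are not shown in the diff):

import numpy as np

def brelu(x, t_min=0.0, t_max=24.0):
    # bounded ReLU: identity on [t_min, t_max], clipped outside
    return np.minimum(np.maximum(x, t_min), t_max)

x = np.array([-3.0, 0.5, 10.0, 30.0])
print(brelu(x))  # [ 0.   0.5 10.  24. ]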
@@ -405,10 +413,12 @@ class BiGRUWithBN(nn.Layer):
     def __init__(self, i_size, h_size, act):
         super().__init__()
         hidden_size = h_size * 3
-        self.fw_fc = nn.Linear(i_size, hidden_size)
-        self.fw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
-        self.bw_fc = nn.Linear(i_size, hidden_size)
-        self.bw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
+        self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
+        self.fw_bn = nn.BatchNorm1D(
+            hidden_size, bias_attr=None, data_format='NLC')
+        self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
+        self.bw_bn = nn.BatchNorm1D(
+            hidden_size, bias_attr=None, data_format='NLC')
 
         self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
         self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
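
The projection width here is h_size * 3 because one shared matmul produces the update, reset, and candidate pre-activations together; GRUCellShare then splits that fused output into three h_size chunks (the x_z / x_c names in its forward pass above). A hypothetical sketch of the split (names and sizes are illustrative, not from this commit):

import paddle

h_size = 4
x_proj = paddle.rand([8, 3 * h_size])             # fused input-projection output
x_z, x_r, x_c = paddle.split(x_proj, 3, axis=-1)  # update / reset / candidate parts
assert x_z.shape == [8, h_size]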
