diff --git a/model_utils/network.py b/model_utils/network.py index 3a4f1dc3..dbbf75e7 100644 --- a/model_utils/network.py +++ b/model_utils/network.py @@ -59,50 +59,53 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride, return padding_reset -def simple_rnn(input, size, param_attr=None, bias_attr=None, is_reverse=False): - '''A simple rnn layer. - :param input: input layer. - :type input: Variable - :param size: Dimension of RNN cells. - :type size: int +class RNNCell(fluid.layers.RNNCell): + '''A simple rnn cell. + :param hidden_size: Dimension of RNN cells. + :type hidden_size: int :param param_attr: Parameter properties of hidden layer weights that can be learned :type param_attr: ParamAttr :param bias_attr: Bias properties of hidden layer weights that can be learned :type bias_attr: ParamAttr - :param is_reverse: Whether to calculate the inverse RNN - :type is_reverse: bool - :return: A simple RNN layer. - :rtype: Variable + :param hidden_activation: Activation for hidden cell + :type hidden_activation: Activation + :param activation: Activation for output + :type activation: Activation + :param name: Name of cell + :type name: string ''' - if is_reverse: - input = fluid.layers.sequence_reverse(x=input) - - pad_value = fluid.layers.assign(input=np.array([0.0], dtype=np.float32)) - input, length = fluid.layers.sequence_pad(input, pad_value) - rnn = fluid.layers.StaticRNN() - input = fluid.layers.transpose(input, [1, 0, 2]) - with rnn.step(): - in_ = rnn.step_input(input) - mem = rnn.memory(shape=[-1, size], batch_ref=in_) - out = fluid.layers.fc( - input=mem, - size=size, - act=None, - param_attr=param_attr, - bias_attr=bias_attr) - out = fluid.layers.elementwise_add(out, in_) - out = fluid.layers.brelu(out) - rnn.update_memory(mem, out) - rnn.output(out) - - out = rnn() - out = fluid.layers.transpose(out, [1, 0, 2]) - out = fluid.layers.sequence_unpad(x=out, length=length) - - if is_reverse: - out = fluid.layers.sequence_reverse(x=out) - return out + + def __init__(self, + hidden_size, + param_attr=None, + bias_attr=None, + hidden_activation=None, + activation=None, + dtype="float32", + name="RNNCell"): + self.hidden_size = hidden_size + self.param_attr = param_attr + self.bias_attr = bias_attr + self.hidden_activation = hidden_activation + self.activation = activation or fluid.layers.brelu + self.name = name + + def call(self, inputs, states): + new_hidden = fluid.layers.fc( + input=states, + size=self.hidden_size, + act=self.hidden_activation, + param_attr=self.param_attr, + bias_attr=self.bias_attr) + new_hidden = fluid.layers.elementwise_add(new_hidden, inputs) + new_hidden = self.activation(new_hidden) + + return new_hidden, new_hidden + + @property + def state_shape(self): + return [self.hidden_size] def bidirectional_simple_rnn_bn_layer(name, input, size, share_weights): @@ -137,20 +140,32 @@ def bidirectional_simple_rnn_bn_layer(name, input, size, share_weights): bias_attr=fluid.ParamAttr(name=name + '_batch_norm_bias'), moving_mean_name=name + '_batch_norm_moving_mean', moving_variance_name=name + '_batch_norm_moving_variance') - #forward and backword in time - forward_rnn = simple_rnn( - input=input_proj_bn, - size=size, + #forward and backword in time + forward_cell = RNNCell( + hidden_size=size, + activation=fluid.layers.brelu, param_attr=fluid.ParamAttr(name=name + '_forward_rnn_weight'), - bias_attr=fluid.ParamAttr(name=name + '_forward_rnn_bias'), - is_reverse=False) + bias_attr=fluid.ParamAttr(name=name + '_forward_rnn_bias')) - reverse_rnn = simple_rnn( - input=input_proj_bn, - size=size, + pad_value = fluid.layers.assign(input=np.array([0.0], dtype=np.float32)) + input, length = fluid.layers.sequence_pad(input_proj_bn, pad_value) + forward_rnn, _ = fluid.layers.rnn( + cell=forward_cell, inputs=input, time_major=False, is_reverse=False) + forward_rnn = fluid.layers.sequence_unpad(x=forward_rnn, length=length) + + reverse_cell = RNNCell( + hidden_size=size, + activation=fluid.layers.brelu, param_attr=fluid.ParamAttr(name=name + '_reverse_rnn_weight'), - bias_attr=fluid.ParamAttr(name=name + '_reverse_rnn_bias'), + bias_attr=fluid.ParamAttr(name=name + '_reverse_rnn_bias')) + input, length = fluid.layers.sequence_pad(input_proj_bn, pad_value) + reverse_rnn, _ = fluid.layers.rnn( + cell=reverse_cell, + inputs=input, + sequence_length=length, + time_major=False, is_reverse=True) + reverse_rnn = fluid.layers.sequence_unpad(x=reverse_rnn, length=length) else: input_proj_forward = fluid.layers.fc( @@ -183,18 +198,32 @@ def bidirectional_simple_rnn_bn_layer(name, input, size, share_weights): moving_mean_name=name + '_reverse_batch_norm_moving_mean', moving_variance_name=name + '_reverse_batch_norm_moving_variance') # forward and backward in time - forward_rnn = simple_rnn( - input=input_proj_bn_forward, - size=size, + forward_cell = RNNCell( + hidden_size=size, + activation=fluid.layers.brelu, param_attr=fluid.ParamAttr(name=name + '_forward_rnn_weight'), - bias_attr=fluid.ParamAttr(name=name + '_forward_rnn_bias'), - is_reverse=False) - reverse_rnn = simple_rnn( - input=input_proj_bn_backward, - size=size, + bias_attr=fluid.ParamAttr(name=name + '_forward_rnn_bias')) + + pad_value = fluid.layers.assign(input=np.array([0.0], dtype=np.float32)) + input, length = fluid.layers.sequence_pad(input_proj_bn, pad_value) + forward_rnn, _ = fluid.layers.rnn( + cell=forward_cell, inputs=input, time_major=False, is_reverse=False) + forward_rnn = fluid.layers.sequence_unpad(x=forward_rnn, length=length) + + reverse_cell = RNNCell( + hidden_size=size, + activation=fluid.layers.brelu, param_attr=fluid.ParamAttr(name=name + '_reverse_rnn_weight'), - bias_attr=fluid.ParamAttr(name=name + '_reverse_rnn_bias'), + bias_attr=fluid.ParamAttr(name=name + '_reverse_rnn_bias')) + input, length = fluid.layers.sequence_pad(input_proj_bn, pad_value) + reverse_rnn, _ = fluid.layers.rnn( + cell=reverse_cell, + inputs=input, + sequence_length=length, + time_major=False, is_reverse=True) + reverse_rnn = fluid.layers.sequence_unpad(x=reverse_rnn, length=length) + out = fluid.layers.concat(input=[forward_rnn, reverse_rnn], axis=1) return out