From c329c5dea1e9f5f6e8759ad4be0c9b5e62c6a288 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sun, 7 Feb 2021 03:41:23 +0000 Subject: [PATCH] model test pass --- model_utils/network2.py | 114 ++++++++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 46 deletions(-) diff --git a/model_utils/network2.py b/model_utils/network2.py index 4e1eb55f7..8cbbbf818 100644 --- a/model_utils/network2.py +++ b/model_utils/network2.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import collections import numpy as np import paddle @@ -19,13 +20,17 @@ from paddle import nn from paddle.nn import functional as F from paddle.nn import initializer as I +__all__ = ['DeepSpeech2'] + def brelu(x, t_min=0.0, t_max=24.0, name=None): - return paddle.min(paddle.max(x, t_min), t_max) + t_min = paddle.to_tensor(t_min) + t_max = paddle.to_tensor(t_max) + return x.maximum(t_min).minimum(t_max) def sequence_mask(x_len, max_len=None, dtype='float32'): - max_len = (max_len or paddle.max(x)) + max_len = max_len or x_len.max() x_len = paddle.unsqueeze(x_len, -1) row_vector = paddle.arange(max_len) mask = row_vector < x_len @@ -76,14 +81,13 @@ class ConvBn(nn.Layer): padding=padding, weight_attr=None, bias_attr=None, - data_format='NCHW', ) + data_format='NCHW') + self.bn = nn.BatchNorm2D( - num_channels=num_channels_out, - param_attr=None, + num_channels_out, + weight_attr=None, bias_attr=None, - moving_mean_name=None, - moving_variance_name=None, - data_format='NCHW', ) + data_format='NCHW') self.act = paddle.relu if act == 'relu' else brelu def forward(self, x, x_len): @@ -94,13 +98,14 @@ class ConvBn(nn.Layer): x = self.bn(x) x = self.act(x) + x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1] + ) // self.stride[1] + 1 + # reset padding part to 0 masks = sequence_mask(x_len) #[B, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] x = x.multiply(masks) - x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1] - ) // self.stride[1] + 1 return x, x_len @@ -128,10 +133,12 @@ class ConvStack(nn.Layer): stride=self.stride, padding=self.padding, act='brelu', ) + + out_channel = 32 self.conv_stack = nn.LayerList([ ConvBn( num_channels_in=32, - num_channels_out=32, + num_channels_out=out_channel, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5), @@ -142,7 +149,7 @@ class ConvStack(nn.Layer): output_height = (feat_size - 1) // 2 + 1 for i in range(self.num_stacks - 1): output_height = (output_height - 1) // 2 + 1 - self.output_height = output_height + self.output_height = out_channel * output_height def forward(self, x, x_len): """ @@ -239,13 +246,14 @@ class GRUCellShare(nn.RNNCellBase): """ def __init__(self, + input_size, hidden_size, weight_ih_attr=None, weight_hh_attr=None, bias_ih_attr=None, bias_hh_attr=None, name=None): - super(GRUCell, self).__init__() + super().__init__() std = 1.0 / math.sqrt(hidden_size) self.weight_hh = self.create_parameter( (3 * hidden_size, hidden_size), @@ -316,7 +324,6 @@ class BiRNNWithBN(nn.Layer): def __init__(self, i_size, h_size, share_weights): super().__init__() - self.share_weights = share_weights self.pad_value = paddle.to_tensor(np.array([0.0], dtype=np.float32)) if self.share_weights: @@ -344,7 +351,7 @@ class BiRNNWithBN(nn.Layer): def forward(self, x, x_len): # x, shape [B, T, D] fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_bn(x)) + bw_x = self.bw_bn(self.bw_fc(x)) fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) x = paddle.concat([fw_x, bw_x], axis=-1) @@ -367,16 +374,16 @@ class BiGRUWithBN(nn.Layer): :rtype: Variable """ - def __init__(self, i_size, act): + def __init__(self, i_size, h_size, act): super().__init__() - hidden_size = i_size * 3 + hidden_size = h_size * 3 self.fw_fc = nn.Linear(i_size, hidden_size) self.fw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC') self.bw_fc = nn.Linear(i_size, hidden_size) self.bw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC') - self.fw_cell = GRUCellShare(hidden_size) - self.bw_cell = GRUCellShare(hidden_size) + self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size) + self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size) self.fw_rnn = nn.RNN( self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] self.bw_rnn = nn.RNN( @@ -385,7 +392,7 @@ class BiGRUWithBN(nn.Layer): def forward(self, x, x_len): # x, shape [B, T, D] fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_bn(x)) + bw_x = self.bw_bn(self.bw_fc(x)) fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) x = paddle.concat([fw_x, bw_x], axis=-1) @@ -412,17 +419,20 @@ class RNNStack(nn.Layer): """ def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights): + super().__init__() self.rnn_stacks = nn.LayerList() for i in range(num_stacks): if use_gru: #default:GRU using tanh - self.rnn_stacks.append(BiGRUWithBN(size=i_size, act="relu")) + self.rnn_stacks.append( + BiGRUWithBN(i_size=i_size, h_size=h_size, act="relu")) else: self.rnn_stacks.append( BiRNNWithBN( i_size=i_size, - size=h_size, - share_weights=share_rnn_weights, )) + h_size=h_size, + share_weights=share_rnn_weights)) + i_size = h_size * 2 def forward(self, x, x_len): """ @@ -471,30 +481,25 @@ class DeepSpeech2(nn.Layer): num_rnn_layers=3, rnn_size=256, use_gru=False, - share_rnn_weight=True): + share_rnn_weights=True): super().__init__() self.feat_size = feat_size # 161 for linear self.dict_size = dict_size - self.conv = ConvStack(num_conv_layers) + self.conv = ConvStack(feat_size, num_conv_layers) - i_size = self.conv.output_height(feat_size) # H after conv stack + i_size = self.conv.output_height # H after conv stack self.rnn = RNNStack( i_size=i_size, h_size=rnn_size, num_stacks=num_rnn_layers, use_gru=use_gru, - share_rnn_weights=share_rnn_weights, ) - self.fc = nn.Linaer(rnn_size * 2, dict_size + 1) + share_rnn_weights=share_rnn_weights) + self.fc = nn.Linear(rnn_size * 2, dict_size + 1) + self.loss = nn.CTCLoss(blank=dict_size, reduction='none') - def forward(self, audio, text, audio_len, text_len): - """ - audio: shape [B, D, T] - text: shape [B, T] - audio_len: shape [B] - text_len: shape [B] - """ + def predict(self, audio, audio_len): # [B, D, T] -> [B, C=1, D, T] audio = audio.unsqueeze(1) @@ -504,7 +509,7 @@ class DeepSpeech2(nn.Layer): # convert data from convolution feature map to sequence of vectors B, C, D, T = paddle.shape(x) x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] - x = x.reshape([0, -1, C * D]) #[B, T, C*D] + x = x.reshape([B, T, C * D]) #[B, T, C*D] # remove padding part x, audio_len = self.rnn(x, audio_len) #[B, T, D] @@ -512,14 +517,31 @@ class DeepSpeech2(nn.Layer): logits = self.fc(x) #[B, T, V + 1] #ctcdecoder need probs, not log_probs - probs = F.log_softmax(logits) + probs = F.softmax(logits) - if not text: - return probs, None - else: - # warp-ctc do softmax on activations - # warp-ctc need activation with shape [T, B, V + 1] - logits = logits.transpose([1, 0, 2]) - ctc_loss = self.loss(logits, text, audio_len, text_len) - ctc_loss = paddle.reduce_sum(ctc_loss) - return probs, ctc_loss + return logits, probs + + @paddle.no_grad() + def infer(self, audio, audio_len): + _, probs = self.predict(audio, audio_len) + return probs + + def forward(self, audio, text, audio_len, text_len): + """ + audio: shape [B, D, T] + text: shape [B, T] + audio_len: shape [B] + text_len: shape [B] + """ + logits, probs = self.predict(audio, audio_len) + # warp-ctc do softmax on activations + # warp-ctc need activation with shape [T, B, V + 1] + logits = logits.transpose([1, 0, 2]) + print(logits.shape) + print(text.shape) + print(audio_len.shape) + print(text_len.shape) + ctc_loss = self.loss(logits, text, audio_len, text_len) + ctc_loss /= text_len # norm_by_times + ctc_loss = ctc_loss.sum() + return probs, ctc_loss