@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import collections
 import numpy as np
 import paddle
@@ -19,13 +20,17 @@ from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
 
+__all__ = ['DeepSpeech2']
+
 
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
-    return paddle.min(paddle.max(x, t_min), t_max)
+    t_min = paddle.to_tensor(t_min)
+    t_max = paddle.to_tensor(t_max)
+    return x.maximum(t_min).minimum(t_max)
 
 
 def sequence_mask(x_len, max_len=None, dtype='float32'):
-    max_len = (max_len or paddle.max(x))
+    max_len = max_len or x_len.max()
     x_len = paddle.unsqueeze(x_len, -1)
    row_vector = paddle.arange(max_len)
    mask = row_vector < x_len
@@ -76,14 +81,13 @@ class ConvBn(nn.Layer):
             padding=padding,
             weight_attr=None,
             bias_attr=None,
-            data_format='NCHW', )
+            data_format='NCHW')
+
         self.bn = nn.BatchNorm2D(
-            num_channels=num_channels_out,
-            param_attr=None,
+            num_channels_out,
+            weight_attr=None,
             bias_attr=None,
-            moving_mean_name=None,
-            moving_variance_name=None,
-            data_format='NCHW', )
+            data_format='NCHW')
         self.act = paddle.relu if act == 'relu' else brelu
 
     def forward(self, x, x_len):
@@ -94,13 +98,14 @@ class ConvBn(nn.Layer):
         x = self.bn(x)
         x = self.act(x)
 
+        x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
+                 ) // self.stride[1] + 1
+
         # reset padding part to 0
         masks = sequence_mask(x_len) #[B, T]
         masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
         x = x.multiply(masks)
 
-        x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
-                 ) // self.stride[1] + 1
         return x, x_len
@@ -128,10 +133,12 @@ class ConvStack(nn.Layer):
             stride=self.stride,
             padding=self.padding,
             act='brelu', )
+
+        out_channel = 32
         self.conv_stack = nn.LayerList([
             ConvBn(
                 num_channels_in=32,
-                num_channels_out=32,
+                num_channels_out=out_channel,
                 kernel_size=(21, 11),
                 stride=(2, 1),
                 padding=(10, 5),
@@ -142,7 +149,7 @@ class ConvStack(nn.Layer):
         output_height = (feat_size - 1) // 2 + 1
         for i in range(self.num_stacks - 1):
             output_height = (output_height - 1) // 2 + 1
-        self.output_height = output_height
+        self.output_height = out_channel * output_height
 
     def forward(self, x, x_len):
         """
@@ -239,13 +246,14 @@ class GRUCellShare(nn.RNNCellBase):
     """
 
     def __init__(self,
+                 input_size,
                  hidden_size,
                  weight_ih_attr=None,
                  weight_hh_attr=None,
                  bias_ih_attr=None,
                  bias_hh_attr=None,
                  name=None):
-        super(GRUCell, self).__init__()
+        super().__init__()
         std = 1.0 / math.sqrt(hidden_size)
         self.weight_hh = self.create_parameter(
             (3 * hidden_size, hidden_size),
@@ -316,7 +324,6 @@ class BiRNNWithBN(nn.Layer):
 
     def __init__(self, i_size, h_size, share_weights):
         super().__init__()
-
         self.share_weights = share_weights
         self.pad_value = paddle.to_tensor(np.array([0.0], dtype=np.float32))
         if self.share_weights:
@@ -344,7 +351,7 @@ class BiRNNWithBN(nn.Layer):
     def forward(self, x, x_len):
         # x, shape [B, T, D]
         fw_x = self.fw_bn(self.fw_fc(x))
-        bw_x = self.bw_bn(self.bw_bn(x))
+        bw_x = self.bw_bn(self.bw_fc(x))
         fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
         bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
         x = paddle.concat([fw_x, bw_x], axis=-1)
@@ -367,16 +374,16 @@ class BiGRUWithBN(nn.Layer):
     :rtype: Variable
     """
 
-    def __init__(self, i_size, act):
+    def __init__(self, i_size, h_size, act):
         super().__init__()
-        hidden_size = i_size * 3
+        hidden_size = h_size * 3
         self.fw_fc = nn.Linear(i_size, hidden_size)
         self.fw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
         self.bw_fc = nn.Linear(i_size, hidden_size)
         self.bw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
 
-        self.fw_cell = GRUCellShare(hidden_size)
-        self.bw_cell = GRUCellShare(hidden_size)
+        self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
+        self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
         self.fw_rnn = nn.RNN(
             self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
         self.bw_rnn = nn.RNN(
@@ -385,7 +392,7 @@ class BiGRUWithBN(nn.Layer):
     def forward(self, x, x_len):
         # x, shape [B, T, D]
         fw_x = self.fw_bn(self.fw_fc(x))
-        bw_x = self.bw_bn(self.bw_bn(x))
+        bw_x = self.bw_bn(self.bw_fc(x))
         fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
         bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
         x = paddle.concat([fw_x, bw_x], axis=-1)
@@ -412,17 +419,20 @@ class RNNStack(nn.Layer):
     """
 
     def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):
+        super().__init__()
         self.rnn_stacks = nn.LayerList()
         for i in range(num_stacks):
             if use_gru:
                 #default:GRU using tanh
-                self.rnn_stacks.append(BiGRUWithBN(size=i_size, act="relu"))
+                self.rnn_stacks.append(
+                    BiGRUWithBN(i_size=i_size, h_size=h_size, act="relu"))
             else:
                 self.rnn_stacks.append(
                     BiRNNWithBN(
                         i_size=i_size,
-                        size=h_size,
-                        share_weights=share_rnn_weights, ))
+                        h_size=h_size,
+                        share_weights=share_rnn_weights))
+            i_size = h_size * 2
 
     def forward(self, x, x_len):
         """
@@ -471,30 +481,25 @@ class DeepSpeech2(nn.Layer):
                  num_rnn_layers=3,
                  rnn_size=256,
                  use_gru=False,
-                 share_rnn_weight=True):
+                 share_rnn_weights=True):
         super().__init__()
         self.feat_size = feat_size # 161 for linear
         self.dict_size = dict_size
 
-        self.conv = ConvStack(num_conv_layers)
+        self.conv = ConvStack(feat_size, num_conv_layers)
 
-        i_size = self.conv.output_height(feat_size) # H after conv stack
+        i_size = self.conv.output_height # H after conv stack
         self.rnn = RNNStack(
             i_size=i_size,
             h_size=rnn_size,
             num_stacks=num_rnn_layers,
             use_gru=use_gru,
-            share_rnn_weights=share_rnn_weights, )
-        self.fc = nn.Linaer(rnn_size * 2, dict_size + 1)
+            share_rnn_weights=share_rnn_weights)
+        self.fc = nn.Linear(rnn_size * 2, dict_size + 1)
 
         self.loss = nn.CTCLoss(blank=dict_size, reduction='none')
 
-    def forward(self, audio, text, audio_len, text_len):
-        """
-        audio: shape [B, D, T]
-        text: shape [B, T]
-        audio_len: shape [B]
-        text_len: shape [B]
-        """
+    def predict(self, audio, audio_len):
         # [B, D, T] -> [B, C=1, D, T]
         audio = audio.unsqueeze(1)
@@ -504,7 +509,7 @@ class DeepSpeech2(nn.Layer):
         # convert data from convolution feature map to sequence of vectors
         B, C, D, T = paddle.shape(x)
         x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]
-        x = x.reshape([0, -1, C * D]) #[B, T, C*D]
+        x = x.reshape([B, T, C * D]) #[B, T, C*D]
 
         # remove padding part
         x, audio_len = self.rnn(x, audio_len) #[B, T, D]
@@ -512,14 +517,31 @@ class DeepSpeech2(nn.Layer):
         logits = self.fc(x) #[B, T, V + 1]
 
         #ctcdecoder need probs, not log_probs
-        probs = F.log_softmax(logits)
+        probs = F.softmax(logits)
 
-        if not text:
-            return probs, None
-        else:
-            # warp-ctc do softmax on activations
-            # warp-ctc need activation with shape [T, B, V + 1]
-            logits = logits.transpose([1, 0, 2])
-            ctc_loss = self.loss(logits, text, audio_len, text_len)
-            ctc_loss = paddle.reduce_sum(ctc_loss)
-            return probs, ctc_loss
+        return logits, probs
+
+    @paddle.no_grad()
+    def infer(self, audio, audio_len):
+        _, probs = self.predict(audio, audio_len)
+        return probs
+
+    def forward(self, audio, text, audio_len, text_len):
+        """
+        audio: shape [B, D, T]
+        text: shape [B, T]
+        audio_len: shape [B]
+        text_len: shape [B]
+        """
+        logits, probs = self.predict(audio, audio_len)
+        # warp-ctc do softmax on activations
+        # warp-ctc need activation with shape [T, B, V + 1]
+        logits = logits.transpose([1, 0, 2])
+        print(logits.shape)
+        print(text.shape)
+        print(audio_len.shape)
+        print(text_len.shape)
+        ctc_loss = self.loss(logits, text, audio_len, text_len)
+        ctc_loss /= text_len # norm_by_times
+        ctc_loss = ctc_loss.sum()
+        return probs, ctc_loss