ds2_online alignment, include prob_chunk_forward, prob_chunk_by_chunk_forward

pull/735/head
huangyuxin 3 years ago
parent 18eb2cb5ed
commit 4b5cbe9a12

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False

@ -56,40 +56,28 @@ class CRNNEncoder(nn.Layer):
self.layernorm_list = nn.LayerList()
self.fc_layers_list = nn.LayerList()
layernorm_size = rnn_size
if use_gru == True:
self.rnn.append(
nn.GRU(
input_size=i_size,
hidden_size=rnn_size,
num_layers=1,
direction=rnn_direction))
self.layernorm_list.append(nn.LayerNorm(layernorm_size))
for i in range(1, num_rnn_layers):
for i in range(0, num_rnn_layers):
if i == 0:
rnn_input_size = i_size
else:
rnn_input_size = rnn_size
if (use_gru == True):
self.rnn.append(
nn.GRU(
input_size=layernorm_size,
input_size=rnn_input_size,
hidden_size=rnn_size,
num_layers=1,
direction=rnn_direction))
self.layernorm_list.append(nn.LayerNorm(layernorm_size))
else:
self.rnn.append(
nn.LSTM(
input_size=i_size,
hidden_size=rnn_size,
num_layers=1,
direction=rnn_direction))
self.layernorm_list.append(nn.LayerNorm(layernorm_size))
for i in range(1, num_rnn_layers):
else:
self.rnn.append(
nn.LSTM(
input_size=layernorm_size,
input_size=rnn_input_size,
hidden_size=rnn_size,
num_layers=1,
direction=rnn_direction))
self.layernorm_list.append(nn.LayerNorm(layernorm_size))
fc_input_size = layernorm_size
self.layernorm_list.append(nn.LayerNorm(layernorm_size))
fc_input_size = rnn_size
for i in range(self.num_fc_layers):
self.fc_layers_list.append(
nn.Linear(fc_input_size, fc_layers_size_list[i]))
@ -122,10 +110,7 @@ class CRNNEncoder(nn.Layer):
# remove padding part
init_state = None
final_state_list = []
x, final_state = self.rnn[0](x, init_state, x_lens)
final_state_list.append(final_state)
x = self.layernorm_list[0](x)
for i in range(1, self.num_rnn_layers):
for i in range(0, self.num_rnn_layers):
x, final_state = self.rnn[i](x, init_state, x_lens) #[B, T, D]
final_state_list.append(final_state)
x = self.layernorm_list[i](x)
@ -149,10 +134,7 @@ class CRNNEncoder(nn.Layer):
"""
x, x_lens = self.conv(x, x_lens)
chunk_final_state_list = []
x, final_state = self.rnn[0](x, init_state_list[0], x_lens)
chunk_final_state_list.append(final_state)
x = self.layernorm_list[0](x)
for i in range(1, self.num_rnn_layers):
for i in range(0, self.num_rnn_layers):
x, final_state = self.rnn[i](x, init_state_list[i],
x_lens) #[B, T, D]
chunk_final_state_list.append(final_state)
@ -177,27 +159,32 @@ class CRNNEncoder(nn.Layer):
padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
padding = paddle.zeros((x.shape[0], padding_len, x.shape[2]))
x_padded = paddle.concat([x, padding], axis=1)
padded_x = paddle.concat([x, padding], axis=1)
num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1
num_chunk = int(num_chunk)
chunk_init_state_list = [None] * self.num_rnn_layers
chunk_state_list = [None] * self.num_rnn_layers
for i in range(0, num_chunk):
start = i * chunk_stride
end = start + chunk_size
x_chunk = x_padded[:, start:end, :]
x_len_left = x_lens - i * chunk_stride
# end = min(start + chunk_size, max_len)
# if (end - start < receptive_field_length):
# break
x_chunk = padded_x[:, start:end, :]
x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
paddle.zeros_like(x_lens),
x_lens - i * chunk_stride)
x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size
x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
x_len_left, x_chunk_len_tmp)
eouts_chunk, eouts_chunk_lens, chunk_final_state_list = self.forward_chunk(
x_chunk, x_chunk_lens, chunk_init_state_list)
eouts_chunk, eouts_chunk_lens, chunk_state_list = self.forward_chunk(
x_chunk, x_chunk_lens, chunk_state_list)
chunk_init_state_list = chunk_final_state_list
eouts_chunk_list.append(eouts_chunk)
eouts_chunk_lens_list.append(eouts_chunk_lens)
return eouts_chunk_list, eouts_chunk_lens_list, chunk_final_state_list
return eouts_chunk_list, eouts_chunk_lens_list, chunk_state_list
class DeepSpeech2ModelOnline(nn.Layer):
@ -309,6 +296,35 @@ class DeepSpeech2ModelOnline(nn.Layer):
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
@paddle.no_grad()
def decode_by_chunk(self, eouts_prefix, eouts_len_prefix, chunk_state_list,
                    audio_chunk, audio_len_chunk, vocab_list,
                    decoding_method, lang_model_path, beam_alpha, beam_beta,
                    beam_size, cutoff_prob, cutoff_top_n, num_processes):
    """Encode one audio chunk, extend the running encoder output, decode it.

    Args:
        eouts_prefix: encoder outputs accumulated from earlier chunks, or
            None when this is the first chunk.
        eouts_len_prefix: lengths matching ``eouts_prefix``; only read when
            ``eouts_prefix`` is not None.
        chunk_state_list: RNN states handed to ``encoder.forward_chunk``.
        audio_chunk: features of the current chunk.
        audio_len_chunk: lengths of the current chunk.
        vocab_list, decoding_method, lang_model_path, beam_alpha, beam_beta,
        beam_size, cutoff_prob, cutoff_top_n, num_processes: forwarded
            unchanged to the decoder.

    Returns:
        tuple: ``(decoded, eouts, eouts_len, final_state_list)`` where
        ``decoded`` is whatever ``decoder.decode_probs`` returns, ``eouts``
        and ``eouts_len`` are the updated prefix, and ``final_state_list``
        is the state returned by the encoder for this chunk.
    """
    # Decoders only accept strings encoded in utf-8.
    # NOTE(review): the original comment says "init once", yet this call is
    # made on every invocation -- confirm init_decode tolerates repeats.
    self.decoder.init_decode(
        beam_alpha=beam_alpha,
        beam_beta=beam_beta,
        lang_model_path=lang_model_path,
        vocab_list=vocab_list,
        decoding_method=decoding_method)

    chunk_out, chunk_out_len, final_state_list = self.encoder.forward_chunk(
        audio_chunk, audio_len_chunk, chunk_state_list)

    if eouts_prefix is None:
        # First chunk: nothing to prepend.
        eouts, eouts_len = chunk_out, chunk_out_len
    else:
        # Append the new chunk along axis 1 and sum the lengths.
        eouts = paddle.concat([eouts_prefix, chunk_out], axis=1)
        eouts_len = paddle.add_n([eouts_len_prefix, chunk_out_len])

    probs = self.decoder.softmax(eouts)
    decoded = self.decoder.decode_probs(
        probs.numpy(), eouts_len, vocab_list, decoding_method,
        lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
        cutoff_top_n, num_processes)
    return decoded, eouts, eouts_len, final_state_list
@paddle.no_grad()
def decode_chunk_by_chunk(self, audio, audio_len, vocab_list,
decoding_method, lang_model_path, beam_alpha,
@ -334,6 +350,13 @@ class DeepSpeech2ModelOnline(nn.Layer):
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
"""
decode_prob,
decode_prob_chunk_by_chunk and
decode_prob_by_chunk
are only used for testing
"""
@paddle.no_grad()
def decode_prob(self, audio, audio_len):
eouts, eouts_len, final_state_list = self.encoder(audio, audio_len)
@ -341,15 +364,28 @@ class DeepSpeech2ModelOnline(nn.Layer):
return probs, eouts, eouts_len, final_state_list
@paddle.no_grad()
def decode_prob_chunk_by_chunk(self, audio, audio_len):
def decode_prob_chunk_by_chunk(self, audio, audio_len, decoder_chunk_size):
eouts_chunk_list, eouts_chunk_len_list, final_state_list = self.encoder.forward_chunk_by_chunk(
audio, audio_len)
audio, audio_len, decoder_chunk_size)
eouts = paddle.concat(eouts_chunk_list, axis=1)
eouts_len = paddle.add_n(eouts_chunk_len_list)
probs = self.decoder.softmax(eouts)
return probs, eouts, eouts_len, final_state_list
@paddle.no_grad()
def decode_prob_by_chunk(self, audio, audio_len, eouts_prefix,
                         eouts_lens_prefix, chunk_state_list):
    """Encode one chunk and return softmax probs over the extended prefix.

    Args:
        audio: features of the current chunk.
        audio_len: lengths of the current chunk.
        eouts_prefix: encoder outputs accumulated so far, or None for the
            first chunk.
        eouts_lens_prefix: lengths matching ``eouts_prefix``; only read
            when ``eouts_prefix`` is not None.
        chunk_state_list: RNN states handed to ``encoder.forward_chunk``.

    Returns:
        tuple: ``(probs, eouts, eouts_lens, final_state_list)``; ``eouts``
        and ``eouts_lens`` are the updated prefix to pass to the next call.
    """
    chunk_out, chunk_out_lens, final_state_list = self.encoder.forward_chunk(
        audio, audio_len, chunk_state_list)

    if eouts_prefix is None:
        # First chunk: the prefix is just this chunk's output.
        eouts, eouts_lens = chunk_out, chunk_out_lens
    else:
        # Append along axis 1 and accumulate the lengths.
        eouts = paddle.concat([eouts_prefix, chunk_out], axis=1)
        eouts_lens = paddle.add_n([eouts_lens_prefix, chunk_out_lens])

    probs = self.decoder.softmax(eouts)
    return probs, eouts, eouts_lens, final_state_list
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
@ -420,15 +456,14 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
probs = self.decoder.softmax(eouts)
return probs
def forward_chunk_by_chunk(self, audio, audio_len):
eouts_chunk_list, eouts_chunk_lens_list, final_state_list = self.encoder.forward_chunk_by_chunk(
audio_chunk, audio_chunk_len)
eouts = paddle.concat(eouts_chunk_list, axis=1)
def forward_chunk(self, audio_chunk, audio_chunk_lens, chunk_state_list=None):
    """Run the encoder on a single chunk and return softmax probabilities.

    Args:
        audio_chunk: features of the chunk to encode.
        audio_chunk_lens: lengths of the chunk.
        chunk_state_list: per-layer RNN states to start from. When omitted,
            one ``None`` state per RNN layer is used, which is how the
            chunk-by-chunk code in this file seeds the first chunk.

    Returns:
        The decoder's softmax output for this chunk's encoder output.
    """
    if chunk_state_list is None:
        chunk_state_list = [None] * self.encoder.num_rnn_layers
    # The original bound the result to a misspelled name (`eouts_chunkt`)
    # and then read an undefined `eouts`, raising NameError on every call.
    eouts_chunk, _, _ = self.encoder.forward_chunk(
        audio_chunk, audio_chunk_lens, chunk_state_list)
    probs = self.decoder.softmax(eouts_chunk)
    return probs
def forward(self, eouts_chunk_prefix, eouts_chunk_lens_prefix, audio_chunk,
audio_chunk_len, init_state_list):
audio_chunk_lens, chunk_state_list):
"""export model function
Args:
@ -438,8 +473,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
Returns:
probs: probs after softmax
"""
eouts_chunk, eouts_chunk_lens, final_state_list = self.encoder(
audio_chunk, audio_chunk_len, init_state_list)
eouts_chunk, eouts_chunk_lens, final_state_list = self.encoder.forward_chunk(
audio_chunk, audio_chunk_lens, chunk_state_list)
eouts_chunk_new_prefix = paddle.concat(
[eouts_chunk_prefix, eouts_chunk], axis=1)
eouts_chunk_lens_new_prefix = paddle.add(eouts_chunk_lens_prefix,

@ -25,7 +25,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.batch_size = 2
self.feat_dim = 161
max_len = 64
max_len = 210
# (B, T, D)
audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
@ -105,29 +105,116 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def split_into_chunk(self, x, x_lens, decoder_chunk_size, subsampling_rate,
                     receptive_field_length):
    """Cut a padded batch into overlapping fixed-size chunks.

    The input is zero-padded along axis 1 so that the last window is full,
    then sliced into windows of ``chunk_size`` taken every ``chunk_stride``
    steps. For each window, the per-sample valid length is the remaining
    length clamped to ``[0, chunk_size]``.

    Args:
        x: batch indexed as ``x[:, t, :]`` (axis 1 is sliced).
        x_lens: per-sample valid lengths along axis 1.
        decoder_chunk_size: chunk size counted in decoder steps.
        subsampling_rate: multiplier from decoder steps to input steps.
        receptive_field_length: input steps needed for one decoder step.

    Returns:
        tuple: ``(x_chunk_list, x_chunk_lens_list)``, one entry per chunk.
    """
    chunk_stride = subsampling_rate * decoder_chunk_size
    chunk_size = ((decoder_chunk_size - 1) * subsampling_rate +
                  receptive_field_length)
    max_len = x.shape[1]
    assert (chunk_size <= max_len)

    # Pad so (padded_len - chunk_size) is a multiple of the stride; when it
    # already is, a full extra stride of zeros is appended.
    padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
    tail = paddle.zeros((x.shape[0], padding_len, x.shape[2]))
    padded_x = paddle.concat([x, tail], axis=1)
    num_chunk = int((max_len + padding_len - chunk_size) / chunk_stride + 1)

    x_chunk_list = []
    x_chunk_lens_list = []
    for chunk_idx in range(num_chunk):
        start = chunk_idx * chunk_stride
        x_chunk_list.append(padded_x[:, start:start + chunk_size, :])

        # Frames still left for each sample from this window's start,
        # floored at zero ...
        remaining = paddle.where(x_lens - start < 0,
                                 paddle.zeros_like(x_lens), x_lens - start)
        # ... and capped at the window size.
        window_cap = paddle.ones_like(x_lens) * chunk_size
        x_chunk_lens_list.append(
            paddle.where(remaining < window_cap, remaining, window_cap))
    return x_chunk_list, x_chunk_lens_list
def test_ds2_6(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
num_rnn_layers=1,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
use_gru=True)
model.eval()
probs, eouts, eouts_len, final_state_list = model.decode_prob(
paddle.device.set_device("cpu")
de_ch_size = 9
audio_chunk_list, audio_chunk_lens_list = self.split_into_chunk(
self.audio, self.audio_len, de_ch_size,
model.encoder.conv.subsampling_rate,
model.encoder.conv.receptive_field_length)
eouts_prefix = None
eouts_lens_prefix = None
chunk_state_list = [None] * model.encoder.num_rnn_layers
for i, audio_chunk in enumerate(audio_chunk_list):
audio_chunk_lens = audio_chunk_lens_list[i]
probs_pre_chunks, eouts_prefix, eouts_lens_prefix, chunk_state_list = model.decode_prob_by_chunk(
audio_chunk, audio_chunk_lens, eouts_prefix, eouts_lens_prefix,
chunk_state_list)
# print (i, probs_pre_chunks.shape)
probs, eouts, eouts_lens, final_state_list = model.decode_prob(
self.audio, self.audio_len)
probs_chk, eouts_chk, eouts_len_chk, final_state_list_chk = model.decode_prob_chunk_by_chunk(
decode_max_len = probs.shape[1]
probs_pre_chunks = probs_pre_chunks[:, :decode_max_len, :]
self.assertEqual(paddle.allclose(probs, probs_pre_chunks), True)
def test_ds2_7(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=1,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True)
model.eval()
paddle.device.set_device("cpu")
de_ch_size = 9
probs, eouts, eouts_lens, final_state_list = model.decode_prob(
self.audio, self.audio_len)
for i in range(len(final_state_list)):
for j in range(2):
self.assertEqual(
np.sum(
np.abs(final_state_list[i][j].numpy() -
final_state_list_chk[i][j].numpy())), 0)
probs_by_chk, eouts_by_chk, eouts_lens_by_chk, final_state_list_by_chk = model.decode_prob_chunk_by_chunk(
self.audio, self.audio_len, de_ch_size)
decode_max_len = probs.shape[1]
probs_by_chk = probs_by_chk[:, :decode_max_len, :]
eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
self.assertEqual(
paddle.sum(
paddle.abs(paddle.subtract(eouts_lens, eouts_lens_by_chk))), 0)
self.assertEqual(
paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))), 0)
self.assertEqual(
paddle.sum(
paddle.abs(paddle.subtract(probs, probs_by_chk))).numpy(), 0)
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(paddle.allclose(probs_by_chk, probs), True)
"""
print ("conv_x", conv_x)
print ("conv_x_by_chk", conv_x_by_chk)
print ("final_state_list", final_state_list)
#print ("final_state_list_by_chk", final_state_list_by_chk)
print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,:de_ch_size,:], eouts_by_chk[:,:de_ch_size,:]))))
print (paddle.allclose(eouts[:,:de_ch_size,:], eouts_by_chk[:,:de_ch_size,:]))
print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,de_ch_size:de_ch_size*2,:], eouts_by_chk[:,de_ch_size:de_ch_size*2,:]))))
print (paddle.allclose(eouts[:,de_ch_size:de_ch_size*2,:], eouts_by_chk[:,de_ch_size:de_ch_size*2,:]))
print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,de_ch_size*2:de_ch_size*3,:], eouts_by_chk[:,de_ch_size*2:de_ch_size*3,:]))))
print (paddle.allclose(eouts[:,de_ch_size*2:de_ch_size*3,:], eouts_by_chk[:,de_ch_size*2:de_ch_size*3,:]))
print (paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))))
print (paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))))
print (paddle.allclose(eouts[:,:,:], eouts_by_chk[:,:,:]))
"""
if __name__ == '__main__':

Loading…
Cancel
Save