|
|
|
@ -410,30 +410,42 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
|
|
|
|
|
|
|
|
|
|
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
|
|
|
|
|
if self.args.model_type == "online":
|
|
|
|
|
output_probs_branch, output_lens_branch = self.static_forward_online(
|
|
|
|
|
audio, audio_len)
|
|
|
|
|
output_probs, output_lens = self.static_forward_online(audio,
|
|
|
|
|
audio_len)
|
|
|
|
|
elif self.args.model_type == "offline":
|
|
|
|
|
output_probs_branch, output_lens_branch = self.static_forward_offline(
|
|
|
|
|
audio, audio_len)
|
|
|
|
|
output_probs, output_lens = self.static_forward_offline(audio,
|
|
|
|
|
audio_len)
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("wrong model type")
|
|
|
|
|
|
|
|
|
|
self.predictor.clear_intermediate_tensor()
|
|
|
|
|
self.predictor.try_shrink_memory()
|
|
|
|
|
|
|
|
|
|
self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
|
|
|
|
|
vocab_list, cfg.decoding_method)
|
|
|
|
|
|
|
|
|
|
result_transcripts = self.model.decoder.decode_probs(
|
|
|
|
|
output_probs_branch.numpy(), output_lens_branch, vocab_list,
|
|
|
|
|
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
|
|
|
|
|
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
|
|
|
|
|
cfg.num_proc_bsearch)
|
|
|
|
|
output_probs, output_lens, vocab_list, cfg.decoding_method,
|
|
|
|
|
cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size,
|
|
|
|
|
cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch)
|
|
|
|
|
|
|
|
|
|
return result_transcripts
|
|
|
|
|
|
|
|
|
|
def static_forward_online(self, audio, audio_len):
|
|
|
|
|
def static_forward_online(self, audio, audio_len,
|
|
|
|
|
decoder_chunk_size: int=1):
|
|
|
|
|
"""
|
|
|
|
|
Parameters
|
|
|
|
|
----------
|
|
|
|
|
audio (Tensor): shape[B, T, D]
|
|
|
|
|
audio_len (Tensor): shape[B]
|
|
|
|
|
decoder_chunk_size(int)
|
|
|
|
|
Returns
|
|
|
|
|
-------
|
|
|
|
|
output_probs(numpy.array): shape[B, T, vocab_size]
|
|
|
|
|
output_lens(numpy.array): shape[B]
|
|
|
|
|
"""
|
|
|
|
|
output_probs_list = []
|
|
|
|
|
output_lens_list = []
|
|
|
|
|
decoder_chunk_size = 1
|
|
|
|
|
subsampling_rate = self.model.encoder.conv.subsampling_rate
|
|
|
|
|
receptive_field_length = self.model.encoder.conv.receptive_field_length
|
|
|
|
|
chunk_stride = subsampling_rate * decoder_chunk_size
|
|
|
|
@ -441,41 +453,42 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
|
|
|
|
|
) * subsampling_rate + receptive_field_length
|
|
|
|
|
|
|
|
|
|
x_batch = audio.numpy()
|
|
|
|
|
batch_size = x_batch.shape[0]
|
|
|
|
|
batch_size, Tmax, x_dim = x_batch.shape
|
|
|
|
|
x_len_batch = audio_len.numpy().astype(np.int64)
|
|
|
|
|
max_len_batch = x_batch.shape[1]
|
|
|
|
|
batch_padding_len = chunk_stride - (
|
|
|
|
|
max_len_batch - chunk_size
|
|
|
|
|
|
|
|
|
|
padding_len_batch = chunk_stride - (
|
|
|
|
|
Tmax - chunk_size
|
|
|
|
|
) % chunk_stride # The length of padding for the batch
|
|
|
|
|
x_list = np.split(x_batch, batch_size, axis=0)
|
|
|
|
|
x_len_list = np.split(x_len_batch, x_batch.shape[0], axis=0)
|
|
|
|
|
x_len_list = np.split(x_len_batch, batch_size, axis=0)
|
|
|
|
|
|
|
|
|
|
for x, x_len in zip(x_list, x_len_list):
|
|
|
|
|
self.autolog.times.start()
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
assert (chunk_size <= x_len[0])
|
|
|
|
|
x_len = x_len[0]
|
|
|
|
|
assert (chunk_size <= x_len)
|
|
|
|
|
|
|
|
|
|
eouts_chunk_list = []
|
|
|
|
|
eouts_chunk_lens_list = []
|
|
|
|
|
if (x_len - chunk_size) % chunk_stride != 0:
|
|
|
|
|
padding_len_x = chunk_stride - (x_len - chunk_size
|
|
|
|
|
) % chunk_stride
|
|
|
|
|
else:
|
|
|
|
|
padding_len_x = 0
|
|
|
|
|
|
|
|
|
|
padding_len_x = chunk_stride - (x_len[0] - chunk_size
|
|
|
|
|
) % chunk_stride
|
|
|
|
|
padding = np.zeros(
|
|
|
|
|
(x.shape[0], padding_len_x, x.shape[2]), dtype=np.float32)
|
|
|
|
|
(x.shape[0], padding_len_x, x.shape[2]), dtype=x.dtype)
|
|
|
|
|
padded_x = np.concatenate([x, padding], axis=1)
|
|
|
|
|
|
|
|
|
|
num_chunk = (x_len[0] + padding_len_x - chunk_size
|
|
|
|
|
) / chunk_stride + 1
|
|
|
|
|
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
|
|
|
|
|
num_chunk = int(num_chunk)
|
|
|
|
|
|
|
|
|
|
chunk_state_h_box = np.zeros(
|
|
|
|
|
(self.config.model.num_rnn_layers, 1,
|
|
|
|
|
self.config.model.rnn_layer_size),
|
|
|
|
|
dtype=np.float32)
|
|
|
|
|
dtype=x.dtype)
|
|
|
|
|
chunk_state_c_box = np.zeros(
|
|
|
|
|
(self.config.model.num_rnn_layers, 1,
|
|
|
|
|
self.config.model.rnn_layer_size),
|
|
|
|
|
dtype=np.float32)
|
|
|
|
|
dtype=x.dtype)
|
|
|
|
|
|
|
|
|
|
input_names = self.predictor.get_input_names()
|
|
|
|
|
audio_handle = self.predictor.get_input_handle(input_names[0])
|
|
|
|
@ -489,16 +502,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
|
|
|
|
|
start = i * chunk_stride
|
|
|
|
|
end = start + chunk_size
|
|
|
|
|
x_chunk = padded_x[:, start:end, :]
|
|
|
|
|
x_len_left = np.where(x_len - i * chunk_stride < 0,
|
|
|
|
|
np.zeros_like(x_len, dtype=np.int64),
|
|
|
|
|
x_len - i * chunk_stride)
|
|
|
|
|
x_chunk_len_tmp = np.ones_like(
|
|
|
|
|
x_len, dtype=np.int64) * chunk_size
|
|
|
|
|
x_chunk_lens = np.where(x_len_left < x_chunk_len_tmp,
|
|
|
|
|
x_len_left, x_chunk_len_tmp)
|
|
|
|
|
if (x_chunk_lens[0] <
|
|
|
|
|
if x_len < i * chunk_stride:
|
|
|
|
|
x_chunk_lens = 0
|
|
|
|
|
else:
|
|
|
|
|
x_chunk_lens = min(x_len - i * chunk_stride, chunk_size)
|
|
|
|
|
|
|
|
|
|
if (x_chunk_lens <
|
|
|
|
|
receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob
|
|
|
|
|
break
|
|
|
|
|
x_chunk_lens = np.array([x_chunk_lens])
|
|
|
|
|
audio_handle.reshape(x_chunk.shape)
|
|
|
|
|
audio_handle.copy_from_cpu(x_chunk)
|
|
|
|
|
|
|
|
|
@ -530,11 +542,13 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
|
|
|
|
|
probs_chunk_lens_list.append(output_chunk_lens)
|
|
|
|
|
output_probs = np.concatenate(probs_chunk_list, axis=1)
|
|
|
|
|
output_lens = np.sum(probs_chunk_lens_list, axis=0)
|
|
|
|
|
output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[
|
|
|
|
|
vocab_size = output_probs.shape[2]
|
|
|
|
|
output_probs_padding_len = Tmax + padding_len_batch - output_probs.shape[
|
|
|
|
|
1]
|
|
|
|
|
output_probs_padding = np.zeros(
|
|
|
|
|
(1, output_probs_padding_len, output_probs.shape[2]),
|
|
|
|
|
dtype=np.float32) # The prob padding for a piece of utterance
|
|
|
|
|
(1, output_probs_padding_len, vocab_size),
|
|
|
|
|
dtype=output_probs.
|
|
|
|
|
dtype) # The prob padding for a piece of utterance
|
|
|
|
|
output_probs = np.concatenate(
|
|
|
|
|
[output_probs, output_probs_padding], axis=1)
|
|
|
|
|
output_probs_list.append(output_probs)
|
|
|
|
@ -542,13 +556,22 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
self.autolog.times.stamp()
|
|
|
|
|
self.autolog.times.end()
|
|
|
|
|
output_probs_branch = np.concatenate(output_probs_list, axis=0)
|
|
|
|
|
output_lens_branch = np.concatenate(output_lens_list, axis=0)
|
|
|
|
|
output_probs_branch = paddle.to_tensor(output_probs_branch)
|
|
|
|
|
output_lens_branch = paddle.to_tensor(output_lens_branch)
|
|
|
|
|
return output_probs_branch, output_lens_branch
|
|
|
|
|
output_probs = np.concatenate(output_probs_list, axis=0)
|
|
|
|
|
output_lens = np.concatenate(output_lens_list, axis=0)
|
|
|
|
|
return output_probs, output_lens
|
|
|
|
|
|
|
|
|
|
def static_forward_offline(self, audio, audio_len):
|
|
|
|
|
"""
|
|
|
|
|
Parameters
|
|
|
|
|
----------
|
|
|
|
|
audio (Tensor): shape[B, T, D]
|
|
|
|
|
audio_len (Tensor): shape[B]
|
|
|
|
|
|
|
|
|
|
Returns
|
|
|
|
|
-------
|
|
|
|
|
output_probs(numpy.array): shape[B, T, vocab_size]
|
|
|
|
|
output_lens(numpy.array): shape[B]
|
|
|
|
|
"""
|
|
|
|
|
x = audio.numpy()
|
|
|
|
|
x_len = audio_len.numpy().astype(np.int64)
|
|
|
|
|
|
|
|
|
@ -574,9 +597,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
|
|
|
|
|
output_lens_handle = self.predictor.get_output_handle(output_names[1])
|
|
|
|
|
output_probs = output_handle.copy_to_cpu()
|
|
|
|
|
output_lens = output_lens_handle.copy_to_cpu()
|
|
|
|
|
output_probs_branch = paddle.to_tensor(output_probs)
|
|
|
|
|
output_lens_branch = paddle.to_tensor(output_lens)
|
|
|
|
|
return output_probs_branch, output_lens_branch
|
|
|
|
|
return output_probs, output_lens
|
|
|
|
|
|
|
|
|
|
def run_test(self):
|
|
|
|
|
try:
|
|
|
|
|