simplify the code

pull/786/head
huangyuxin 3 years ago
parent 1f050a4d01
commit 317ffea5e5

@ -410,30 +410,42 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
if self.args.model_type == "online": if self.args.model_type == "online":
output_probs_branch, output_lens_branch = self.static_forward_online( output_probs, output_lens = self.static_forward_online(audio,
audio, audio_len) audio_len)
elif self.args.model_type == "offline": elif self.args.model_type == "offline":
output_probs_branch, output_lens_branch = self.static_forward_offline( output_probs, output_lens = self.static_forward_offline(audio,
audio, audio_len) audio_len)
else: else:
raise Exception("wrong model type") raise Exception("wrong model type")
self.predictor.clear_intermediate_tensor() self.predictor.clear_intermediate_tensor()
self.predictor.try_shrink_memory() self.predictor.try_shrink_memory()
self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path, self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
vocab_list, cfg.decoding_method) vocab_list, cfg.decoding_method)
result_transcripts = self.model.decoder.decode_probs( result_transcripts = self.model.decoder.decode_probs(
output_probs_branch.numpy(), output_lens_branch, vocab_list, output_probs, output_lens, vocab_list, cfg.decoding_method,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch)
cfg.num_proc_bsearch)
return result_transcripts return result_transcripts
def static_forward_online(self, audio, audio_len): def static_forward_online(self, audio, audio_len,
decoder_chunk_size: int=1):
"""
Parameters
----------
audio (Tensor): shape[B, T, D]
audio_len (Tensor): shape[B]
decoder_chunk_size(int)
Returns
-------
output_probs(numpy.array): shape[B, T, vocab_size]
output_lens(numpy.array): shape[B]
"""
output_probs_list = [] output_probs_list = []
output_lens_list = [] output_lens_list = []
decoder_chunk_size = 1
subsampling_rate = self.model.encoder.conv.subsampling_rate subsampling_rate = self.model.encoder.conv.subsampling_rate
receptive_field_length = self.model.encoder.conv.receptive_field_length receptive_field_length = self.model.encoder.conv.receptive_field_length
chunk_stride = subsampling_rate * decoder_chunk_size chunk_stride = subsampling_rate * decoder_chunk_size
@ -441,41 +453,42 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
) * subsampling_rate + receptive_field_length ) * subsampling_rate + receptive_field_length
x_batch = audio.numpy() x_batch = audio.numpy()
batch_size = x_batch.shape[0] batch_size, Tmax, x_dim = x_batch.shape
x_len_batch = audio_len.numpy().astype(np.int64) x_len_batch = audio_len.numpy().astype(np.int64)
max_len_batch = x_batch.shape[1]
batch_padding_len = chunk_stride - ( padding_len_batch = chunk_stride - (
max_len_batch - chunk_size Tmax - chunk_size
) % chunk_stride # The length of padding for the batch ) % chunk_stride # The length of padding for the batch
x_list = np.split(x_batch, batch_size, axis=0) x_list = np.split(x_batch, batch_size, axis=0)
x_len_list = np.split(x_len_batch, x_batch.shape[0], axis=0) x_len_list = np.split(x_len_batch, batch_size, axis=0)
for x, x_len in zip(x_list, x_len_list): for x, x_len in zip(x_list, x_len_list):
self.autolog.times.start() self.autolog.times.start()
self.autolog.times.stamp() self.autolog.times.stamp()
assert (chunk_size <= x_len[0]) x_len = x_len[0]
assert (chunk_size <= x_len)
eouts_chunk_list = [] if (x_len - chunk_size) % chunk_stride != 0:
eouts_chunk_lens_list = [] padding_len_x = chunk_stride - (x_len - chunk_size
) % chunk_stride
else:
padding_len_x = 0
padding_len_x = chunk_stride - (x_len[0] - chunk_size
) % chunk_stride
padding = np.zeros( padding = np.zeros(
(x.shape[0], padding_len_x, x.shape[2]), dtype=np.float32) (x.shape[0], padding_len_x, x.shape[2]), dtype=x.dtype)
padded_x = np.concatenate([x, padding], axis=1) padded_x = np.concatenate([x, padding], axis=1)
num_chunk = (x_len[0] + padding_len_x - chunk_size num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
) / chunk_stride + 1
num_chunk = int(num_chunk) num_chunk = int(num_chunk)
chunk_state_h_box = np.zeros( chunk_state_h_box = np.zeros(
(self.config.model.num_rnn_layers, 1, (self.config.model.num_rnn_layers, 1,
self.config.model.rnn_layer_size), self.config.model.rnn_layer_size),
dtype=np.float32) dtype=x.dtype)
chunk_state_c_box = np.zeros( chunk_state_c_box = np.zeros(
(self.config.model.num_rnn_layers, 1, (self.config.model.num_rnn_layers, 1,
self.config.model.rnn_layer_size), self.config.model.rnn_layer_size),
dtype=np.float32) dtype=x.dtype)
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
audio_handle = self.predictor.get_input_handle(input_names[0]) audio_handle = self.predictor.get_input_handle(input_names[0])
@ -489,16 +502,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
start = i * chunk_stride start = i * chunk_stride
end = start + chunk_size end = start + chunk_size
x_chunk = padded_x[:, start:end, :] x_chunk = padded_x[:, start:end, :]
x_len_left = np.where(x_len - i * chunk_stride < 0, if x_len < i * chunk_stride:
np.zeros_like(x_len, dtype=np.int64), x_chunk_lens = 0
x_len - i * chunk_stride) else:
x_chunk_len_tmp = np.ones_like( x_chunk_lens = min(x_len - i * chunk_stride, chunk_size)
x_len, dtype=np.int64) * chunk_size
x_chunk_lens = np.where(x_len_left < x_chunk_len_tmp, if (x_chunk_lens <
x_len_left, x_chunk_len_tmp)
if (x_chunk_lens[0] <
receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob
break break
x_chunk_lens = np.array([x_chunk_lens])
audio_handle.reshape(x_chunk.shape) audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk) audio_handle.copy_from_cpu(x_chunk)
@ -530,11 +542,13 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
probs_chunk_lens_list.append(output_chunk_lens) probs_chunk_lens_list.append(output_chunk_lens)
output_probs = np.concatenate(probs_chunk_list, axis=1) output_probs = np.concatenate(probs_chunk_list, axis=1)
output_lens = np.sum(probs_chunk_lens_list, axis=0) output_lens = np.sum(probs_chunk_lens_list, axis=0)
output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[ vocab_size = output_probs.shape[2]
output_probs_padding_len = Tmax + padding_len_batch - output_probs.shape[
1] 1]
output_probs_padding = np.zeros( output_probs_padding = np.zeros(
(1, output_probs_padding_len, output_probs.shape[2]), (1, output_probs_padding_len, vocab_size),
dtype=np.float32) # The prob padding for a piece of utterance dtype=output_probs.
dtype) # The prob padding for a piece of utterance
output_probs = np.concatenate( output_probs = np.concatenate(
[output_probs, output_probs_padding], axis=1) [output_probs, output_probs_padding], axis=1)
output_probs_list.append(output_probs) output_probs_list.append(output_probs)
@ -542,13 +556,22 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
self.autolog.times.stamp() self.autolog.times.stamp()
self.autolog.times.stamp() self.autolog.times.stamp()
self.autolog.times.end() self.autolog.times.end()
output_probs_branch = np.concatenate(output_probs_list, axis=0) output_probs = np.concatenate(output_probs_list, axis=0)
output_lens_branch = np.concatenate(output_lens_list, axis=0) output_lens = np.concatenate(output_lens_list, axis=0)
output_probs_branch = paddle.to_tensor(output_probs_branch) return output_probs, output_lens
output_lens_branch = paddle.to_tensor(output_lens_branch)
return output_probs_branch, output_lens_branch
def static_forward_offline(self, audio, audio_len): def static_forward_offline(self, audio, audio_len):
"""
Parameters
----------
audio (Tensor): shape[B, T, D]
audio_len (Tensor): shape[B]
Returns
-------
output_probs(numpy.array): shape[B, T, vocab_size]
output_lens(numpy.array): shape[B]
"""
x = audio.numpy() x = audio.numpy()
x_len = audio_len.numpy().astype(np.int64) x_len = audio_len.numpy().astype(np.int64)
@ -574,9 +597,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
output_lens_handle = self.predictor.get_output_handle(output_names[1]) output_lens_handle = self.predictor.get_output_handle(output_names[1])
output_probs = output_handle.copy_to_cpu() output_probs = output_handle.copy_to_cpu()
output_lens = output_lens_handle.copy_to_cpu() output_lens = output_lens_handle.copy_to_cpu()
output_probs_branch = paddle.to_tensor(output_probs) return output_probs, output_lens
output_lens_branch = paddle.to_tensor(output_lens)
return output_probs_branch, output_lens_branch
def run_test(self): def run_test(self):
try: try:

Loading…
Cancel
Save