diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 74d9b205..0e0e83c0 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -410,30 +410,42 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): if self.args.model_type == "online": - output_probs_branch, output_lens_branch = self.static_forward_online( - audio, audio_len) + output_probs, output_lens = self.static_forward_online(audio, + audio_len) elif self.args.model_type == "offline": - output_probs_branch, output_lens_branch = self.static_forward_offline( - audio, audio_len) + output_probs, output_lens = self.static_forward_offline(audio, + audio_len) else: raise Exception("wrong model type") + self.predictor.clear_intermediate_tensor() self.predictor.try_shrink_memory() + self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path, vocab_list, cfg.decoding_method) result_transcripts = self.model.decoder.decode_probs( - output_probs_branch.numpy(), output_lens_branch, vocab_list, - cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, - cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, - cfg.num_proc_bsearch) + output_probs, output_lens, vocab_list, cfg.decoding_method, + cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, + cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) return result_transcripts - def static_forward_online(self, audio, audio_len): + def static_forward_online(self, audio, audio_len, + decoder_chunk_size: int=1): + """ + Parameters + ---------- + audio (Tensor): shape[B, T, D] + audio_len (Tensor): shape[B] + decoder_chunk_size(int) + Returns + ------- + output_probs(numpy.array): shape[B, T, vocab_size] + output_lens(numpy.array): shape[B] + """ output_probs_list = [] output_lens_list = [] - decoder_chunk_size = 1 subsampling_rate = self.model.encoder.conv.subsampling_rate receptive_field_length = self.model.encoder.conv.receptive_field_length chunk_stride = subsampling_rate * decoder_chunk_size @@ -441,41 +453,42 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): ) * subsampling_rate + receptive_field_length x_batch = audio.numpy() - batch_size = x_batch.shape[0] + batch_size, Tmax, x_dim = x_batch.shape x_len_batch = audio_len.numpy().astype(np.int64) - max_len_batch = x_batch.shape[1] - batch_padding_len = chunk_stride - ( - max_len_batch - chunk_size + + padding_len_batch = chunk_stride - ( + Tmax - chunk_size ) % chunk_stride # The length of padding for the batch x_list = np.split(x_batch, batch_size, axis=0) - x_len_list = np.split(x_len_batch, x_batch.shape[0], axis=0) + x_len_list = np.split(x_len_batch, batch_size, axis=0) for x, x_len in zip(x_list, x_len_list): self.autolog.times.start() self.autolog.times.stamp() - assert (chunk_size <= x_len[0]) + x_len = x_len[0] + assert (chunk_size <= x_len) - eouts_chunk_list = [] - eouts_chunk_lens_list = [] + if (x_len - chunk_size) % chunk_stride != 0: + padding_len_x = chunk_stride - (x_len - chunk_size + ) % chunk_stride + else: + padding_len_x = 0 - padding_len_x = chunk_stride - (x_len[0] - chunk_size - ) % chunk_stride padding = np.zeros( - (x.shape[0], padding_len_x, x.shape[2]), dtype=np.float32) + (x.shape[0], padding_len_x, x.shape[2]), dtype=x.dtype) padded_x = np.concatenate([x, padding], axis=1) - num_chunk = (x_len[0] + padding_len_x - chunk_size - ) / chunk_stride + 1 + num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1 num_chunk = int(num_chunk) chunk_state_h_box = np.zeros( (self.config.model.num_rnn_layers, 1, self.config.model.rnn_layer_size), - dtype=np.float32) + dtype=x.dtype) chunk_state_c_box = np.zeros( (self.config.model.num_rnn_layers, 1, self.config.model.rnn_layer_size), - dtype=np.float32) + dtype=x.dtype) input_names = self.predictor.get_input_names() audio_handle = self.predictor.get_input_handle(input_names[0]) @@ -489,16 +502,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): start = i * chunk_stride end = start + chunk_size x_chunk = padded_x[:, start:end, :] - x_len_left = np.where(x_len - i * chunk_stride < 0, - np.zeros_like(x_len, dtype=np.int64), - x_len - i * chunk_stride) - x_chunk_len_tmp = np.ones_like( - x_len, dtype=np.int64) * chunk_size - x_chunk_lens = np.where(x_len_left < x_chunk_len_tmp, - x_len_left, x_chunk_len_tmp) - if (x_chunk_lens[0] < + if x_len < i * chunk_stride: + x_chunk_lens = 0 + else: + x_chunk_lens = min(x_len - i * chunk_stride, chunk_size) + + if (x_chunk_lens < receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob break + x_chunk_lens = np.array([x_chunk_lens]) audio_handle.reshape(x_chunk.shape) audio_handle.copy_from_cpu(x_chunk) @@ -530,11 +542,13 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): probs_chunk_lens_list.append(output_chunk_lens) output_probs = np.concatenate(probs_chunk_list, axis=1) output_lens = np.sum(probs_chunk_lens_list, axis=0) - output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[ + vocab_size = output_probs.shape[2] + output_probs_padding_len = Tmax + padding_len_batch - output_probs.shape[ 1] output_probs_padding = np.zeros( - (1, output_probs_padding_len, output_probs.shape[2]), - dtype=np.float32) # The prob padding for a piece of utterance + (1, output_probs_padding_len, vocab_size), + dtype=output_probs. + dtype) # The prob padding for a piece of utterance output_probs = np.concatenate( [output_probs, output_probs_padding], axis=1) output_probs_list.append(output_probs) @@ -542,13 +556,22 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): self.autolog.times.stamp() self.autolog.times.stamp() self.autolog.times.end() - output_probs_branch = np.concatenate(output_probs_list, axis=0) - output_lens_branch = np.concatenate(output_lens_list, axis=0) - output_probs_branch = paddle.to_tensor(output_probs_branch) - output_lens_branch = paddle.to_tensor(output_lens_branch) - return output_probs_branch, output_lens_branch + output_probs = np.concatenate(output_probs_list, axis=0) + output_lens = np.concatenate(output_lens_list, axis=0) + return output_probs, output_lens def static_forward_offline(self, audio, audio_len): + """ + Parameters + ---------- + audio (Tensor): shape[B, T, D] + audio_len (Tensor): shape[B] + + Returns + ------- + output_probs(numpy.array): shape[B, T, vocab_size] + output_lens(numpy.array): shape[B] + """ x = audio.numpy() x_len = audio_len.numpy().astype(np.int64) @@ -574,9 +597,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): output_lens_handle = self.predictor.get_output_handle(output_names[1]) output_probs = output_handle.copy_to_cpu() output_lens = output_lens_handle.copy_to_cpu() - output_probs_branch = paddle.to_tensor(output_probs) - output_lens_branch = paddle.to_tensor(output_lens) - return output_probs_branch, output_lens_branch + return output_probs, output_lens def run_test(self): try: