diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index f386336ad..74d9b2050 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -270,24 +270,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         vocab_list = self.test_loader.collate_fn.vocab_list
 
         target_transcripts = self.ordid2token(texts, texts_len)
-        self.autolog.times.start()
-        self.autolog.times.stamp()
-        result_transcripts = self.model.decode(
-            audio,
-            audio_len,
-            vocab_list,
-            decoding_method=cfg.decoding_method,
-            lang_model_path=cfg.lang_model_path,
-            beam_alpha=cfg.alpha,
-            beam_beta=cfg.beta,
-            beam_size=cfg.beam_size,
-            cutoff_prob=cfg.cutoff_prob,
-            cutoff_top_n=cfg.cutoff_top_n,
-            num_processes=cfg.num_proc_bsearch)
-        self.autolog.times.stamp()
-        self.autolog.times.stamp()
-        self.autolog.times.end()
+        result_transcripts = self.compute_result_transcripts(audio, audio_len,
+                                                              vocab_list, cfg)
 
         for utt, target, result in zip(utts, target_transcripts,
                                        result_transcripts):
             errors, len_ref = errors_func(target, result)
@@ -308,6 +293,26 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             error_rate=errors_sum / len_refs,
             error_rate_type=cfg.error_rate_type)
 
+    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
+        self.autolog.times.start()
+        self.autolog.times.stamp()
+        result_transcripts = self.model.decode(
+            audio,
+            audio_len,
+            vocab_list,
+            decoding_method=cfg.decoding_method,
+            lang_model_path=cfg.lang_model_path,
+            beam_alpha=cfg.alpha,
+            beam_beta=cfg.beta,
+            beam_size=cfg.beam_size,
+            cutoff_prob=cfg.cutoff_prob,
+            cutoff_top_n=cfg.cutoff_top_n,
+            num_processes=cfg.num_proc_bsearch)
+        self.autolog.times.stamp()
+        self.autolog.times.stamp()
+        self.autolog.times.end()
+        return result_transcripts
+
     @mp_tools.rank_zero_only
     @paddle.no_grad()
     def test(self):
@@ -403,21 +408,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def compute_metrics(self,
-                        utts,
-                        audio,
-                        audio_len,
-                        texts,
-                        texts_len,
-                        fout=None):
-        cfg = self.config.decoding
-
-        errors_sum, len_refs, num_ins = 0.0, 0, 0
-        errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
-        error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
-
-        vocab_list = self.test_loader.collate_fn.vocab_list
-
+    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
         if self.args.model_type == "online":
             output_probs_branch, output_lens_branch = self.static_forward_online(
                 audio, audio_len)
@@ -437,31 +428,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
             cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
             cfg.num_proc_bsearch)
 
-        target_transcripts = self.ordid2token(texts, texts_len)
-        for utt, target, result in zip(utts, target_transcripts,
-                                       result_transcripts):
-            errors, len_ref = errors_func(target, result)
-            errors_sum += errors
-            len_refs += len_ref
-            num_ins += 1
-            if fout:
-                fout.write(utt + " " + result + "\n")
-            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
-                        (target, result))
-            logger.info("Current error rate [%s] = %f" %
-                        (cfg.error_rate_type, error_rate_func(target, result)))
-
-        return dict(
-            errors_sum=errors_sum,
-            len_refs=len_refs,
-            num_ins=num_ins,
-            error_rate=errors_sum / len_refs,
-            error_rate_type=cfg.error_rate_type)
+        return result_transcripts
 
     def static_forward_online(self, audio, audio_len):
         output_probs_list = []
         output_lens_list = []
-        decoder_chunk_size = 8
+        decoder_chunk_size = 1
         subsampling_rate = self.model.encoder.conv.subsampling_rate
         receptive_field_length = self.model.encoder.conv.receptive_field_length
         chunk_stride = subsampling_rate * decoder_chunk_size
@@ -553,27 +525,27 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
                 output_chunk_lens = output_lens_handle.copy_to_cpu()
                 chunk_state_h_box = output_state_h_handle.copy_to_cpu()
                 chunk_state_c_box = output_state_c_handle.copy_to_cpu()
-                output_chunk_probs = paddle.to_tensor(output_chunk_probs)
-                output_chunk_lens = paddle.to_tensor(output_chunk_lens)
                 probs_chunk_list.append(output_chunk_probs)
                 probs_chunk_lens_list.append(output_chunk_lens)
-            output_probs = paddle.concat(probs_chunk_list, axis=1)
-            output_lens = paddle.add_n(probs_chunk_lens_list)
+            output_probs = np.concatenate(probs_chunk_list, axis=1)
+            output_lens = np.sum(probs_chunk_lens_list, axis=0)
             output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[
                 1]
-            output_probs_padding = paddle.zeros(
+            output_probs_padding = np.zeros(
                 (1, output_probs_padding_len, output_probs.shape[2]),
-                dtype="float32")  # The prob padding for a piece of utterance
+                dtype=np.float32)  # The prob padding for a piece of utterance
-            output_probs = paddle.concat(
+            output_probs = np.concatenate(
                 [output_probs, output_probs_padding], axis=1)
             output_probs_list.append(output_probs)
             output_lens_list.append(output_lens)
         self.autolog.times.stamp()
         self.autolog.times.stamp()
         self.autolog.times.end()
-        output_probs_branch = paddle.concat(output_probs_list, axis=0)
-        output_lens_branch = paddle.concat(output_lens_list, axis=0)
+        output_probs_branch = np.concatenate(output_probs_list, axis=0)
+        output_lens_branch = np.concatenate(output_lens_list, axis=0)
+        output_probs_branch = paddle.to_tensor(output_probs_branch)
+        output_lens_branch = paddle.to_tensor(output_lens_branch)
         return output_probs_branch, output_lens_branch
 
     def static_forward_offline(self, audio, audio_len):
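
Note on the refactor: decoding is now isolated behind a compute_result_transcripts() hook. DeepSpeech2Tester.compute_metrics() keeps the shared error-rate bookkeeping and calls the hook, and DeepSpeech2ExportTester overrides only the hook to run the exported static-graph predictor. A minimal sketch of this pattern (simplified names, not the actual classes):

    class Tester:
        def compute_metrics(self, audio, audio_len, vocab_list, cfg):
            # Shared bookkeeping lives here; only decoding varies by subclass.
            transcripts = self.compute_result_transcripts(audio, audio_len,
                                                          vocab_list, cfg)
            return transcripts

        def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
            # Default path: decode with the dynamic-graph model.
            return ["dynamic decode"]

    class ExportTester(Tester):
        def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
            # Exported-model path: run the static-graph predictor instead.
            return ["static predictor decode"]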
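
The static_forward_online() changes move per-chunk accumulation from paddle ops to NumPy: chunk probabilities stay as NumPy arrays, are concatenated and padded per utterance, and only the final batched arrays are converted to Paddle tensors (two to_tensor calls per batch instead of two per chunk). A standalone sketch of that accumulation, with illustrative shapes and the gather_chunks name not taken from the patch:

    import numpy as np
    import paddle

    def gather_chunks(chunks, chunk_lens, padded_len):
        # chunks: list of (1, T_i, V) float32 probability arrays, one per chunk.
        probs = np.concatenate(chunks, axis=1)     # (1, sum(T_i), V)
        lens = np.sum(chunk_lens, axis=0)          # total decoded frames
        pad = np.zeros((1, padded_len - probs.shape[1], probs.shape[2]),
                       dtype=np.float32)           # pad to the batch max length
        probs = np.concatenate([probs, pad], axis=1)
        # Convert to tensors once at the end rather than once per chunk.
        return paddle.to_tensor(probs), paddle.to_tensor(lens)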