make the code simpler

pull/786/head
huangyuxin 3 years ago
parent 7ab022e1cc
commit 1f050a4d01

@@ -270,24 +270,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         vocab_list = self.test_loader.collate_fn.vocab_list
         target_transcripts = self.ordid2token(texts, texts_len)
 
-        self.autolog.times.start()
-        self.autolog.times.stamp()
-        result_transcripts = self.model.decode(
-            audio,
-            audio_len,
-            vocab_list,
-            decoding_method=cfg.decoding_method,
-            lang_model_path=cfg.lang_model_path,
-            beam_alpha=cfg.alpha,
-            beam_beta=cfg.beta,
-            beam_size=cfg.beam_size,
-            cutoff_prob=cfg.cutoff_prob,
-            cutoff_top_n=cfg.cutoff_top_n,
-            num_processes=cfg.num_proc_bsearch)
-        self.autolog.times.stamp()
-        self.autolog.times.stamp()
-        self.autolog.times.end()
+        result_transcripts = self.compute_result_transcripts(audio, audio_len,
+                                                             vocab_list, cfg)
 
         for utt, target, result in zip(utts, target_transcripts,
                                        result_transcripts):
             errors, len_ref = errors_func(target, result)
@@ -308,6 +293,26 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             error_rate=errors_sum / len_refs,
             error_rate_type=cfg.error_rate_type)
 
+    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
+        self.autolog.times.start()
+        self.autolog.times.stamp()
+        result_transcripts = self.model.decode(
+            audio,
+            audio_len,
+            vocab_list,
+            decoding_method=cfg.decoding_method,
+            lang_model_path=cfg.lang_model_path,
+            beam_alpha=cfg.alpha,
+            beam_beta=cfg.beta,
+            beam_size=cfg.beam_size,
+            cutoff_prob=cfg.cutoff_prob,
+            cutoff_top_n=cfg.cutoff_top_n,
+            num_processes=cfg.num_proc_bsearch)
+        self.autolog.times.stamp()
+        self.autolog.times.stamp()
+        self.autolog.times.end()
+        return result_transcripts
+
     @mp_tools.rank_zero_only
     @paddle.no_grad()
     def test(self):
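The shape of this change is a template-method extraction: compute_metrics keeps the shared error-rate loop in one place, and the decode step moves behind the new compute_result_transcripts hook, which DeepSpeech2ExportTester overrides in the hunks below. A minimal, self-contained sketch of the pattern; the toy classes and the uppercase "decode" are invented stand-ins for model.decode and the exported static graph:

    # Hypothetical sketch only; class names and bodies are invented.
    class BaseTester:
        def compute_metrics(self, audio, targets):
            # The shared scoring loop lives in one place...
            results = self.compute_result_transcripts(audio)
            errors = sum(r != t for r, t in zip(results, targets))
            return dict(errors_sum=errors, num_ins=len(targets))

        def compute_result_transcripts(self, audio):
            # ...and only the decode step varies per subclass.
            raise NotImplementedError

    class DynamicGraphTester(BaseTester):  # stands in for DeepSpeech2Tester
        def compute_result_transcripts(self, audio):
            return [a.upper() for a in audio]  # placeholder for model.decode(...)

    class StaticGraphTester(BaseTester):  # stands in for DeepSpeech2ExportTester
        def compute_result_transcripts(self, audio):
            return [a.upper() for a in audio]  # placeholder for the exported graph

    print(DynamicGraphTester().compute_metrics(["hi"], ["HI"]))
    # -> {'errors_sum': 0, 'num_ins': 1}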
@@ -403,21 +408,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def compute_metrics(self,
-                        utts,
-                        audio,
-                        audio_len,
-                        texts,
-                        texts_len,
-                        fout=None):
-        cfg = self.config.decoding
-        errors_sum, len_refs, num_ins = 0.0, 0, 0
-        errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
-        error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
-        vocab_list = self.test_loader.collate_fn.vocab_list
-
+    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
         if self.args.model_type == "online":
             output_probs_branch, output_lens_branch = self.static_forward_online(
                 audio, audio_len)
@@ -437,31 +428,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
             cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
             cfg.num_proc_bsearch)
 
-        target_transcripts = self.ordid2token(texts, texts_len)
-        for utt, target, result in zip(utts, target_transcripts,
-                                       result_transcripts):
-            errors, len_ref = errors_func(target, result)
-            errors_sum += errors
-            len_refs += len_ref
-            num_ins += 1
-            if fout:
-                fout.write(utt + " " + result + "\n")
-            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
-                        (target, result))
-            logger.info("Current error rate [%s] = %f" %
-                        (cfg.error_rate_type, error_rate_func(target, result)))
-
-        return dict(
-            errors_sum=errors_sum,
-            len_refs=len_refs,
-            num_ins=num_ins,
-            error_rate=errors_sum / len_refs,
-            error_rate_type=cfg.error_rate_type)
+        return result_transcripts
 
     def static_forward_online(self, audio, audio_len):
         output_probs_list = []
         output_lens_list = []
-        decoder_chunk_size = 8
+        decoder_chunk_size = 1
         subsampling_rate = self.model.encoder.conv.subsampling_rate
         receptive_field_length = self.model.encoder.conv.receptive_field_length
         chunk_stride = subsampling_rate * decoder_chunk_size
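A quick worked example of the stride arithmetic in the hunk above. The chunk_stride formula is taken from the last context line; subsampling_rate = 4 is an assumed value for illustration, not something this diff shows:

    # subsampling_rate=4 is an assumption, not read from this diff.
    subsampling_rate = 4
    for decoder_chunk_size in (8, 1):
        chunk_stride = subsampling_rate * decoder_chunk_size
        print(decoder_chunk_size, "->", chunk_stride)
    # 8 -> 32, 1 -> 4: with decoder_chunk_size=1 each streaming step advances
    # over far fewer input frames, so the loop runs more, smaller forward passes.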
@@ -553,27 +525,27 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
                 output_chunk_lens = output_lens_handle.copy_to_cpu()
                 chunk_state_h_box = output_state_h_handle.copy_to_cpu()
                 chunk_state_c_box = output_state_c_handle.copy_to_cpu()
 
-                output_chunk_probs = paddle.to_tensor(output_chunk_probs)
-                output_chunk_lens = paddle.to_tensor(output_chunk_lens)
                 probs_chunk_list.append(output_chunk_probs)
                 probs_chunk_lens_list.append(output_chunk_lens)
-            output_probs = paddle.concat(probs_chunk_list, axis=1)
-            output_lens = paddle.add_n(probs_chunk_lens_list)
+            output_probs = np.concatenate(probs_chunk_list, axis=1)
+            output_lens = np.sum(probs_chunk_lens_list, axis=0)
             output_probs_padding_len = max_len_batch + batch_padding_len - output_probs.shape[
                 1]
-            output_probs_padding = paddle.zeros(
+            output_probs_padding = np.zeros(
                 (1, output_probs_padding_len, output_probs.shape[2]),
-                dtype="float32")  # The prob padding for a piece of utterance
-            output_probs = paddle.concat(
+                dtype=np.float32)  # The prob padding for a piece of utterance
+            output_probs = np.concatenate(
                 [output_probs, output_probs_padding], axis=1)
             output_probs_list.append(output_probs)
             output_lens_list.append(output_lens)
         self.autolog.times.stamp()
         self.autolog.times.stamp()
         self.autolog.times.end()
-        output_probs_branch = paddle.concat(output_probs_list, axis=0)
-        output_lens_branch = paddle.concat(output_lens_list, axis=0)
+        output_probs_branch = np.concatenate(output_probs_list, axis=0)
+        output_lens_branch = np.concatenate(output_lens_list, axis=0)
+        output_probs_branch = paddle.to_tensor(output_probs_branch)
+        output_lens_branch = paddle.to_tensor(output_lens_branch)
         return output_probs_branch, output_lens_branch
 
     def static_forward_offline(self, audio, audio_len):
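The last hunk swaps per-chunk paddle.to_tensor conversions for NumPy accumulation: chunk outputs from copy_to_cpu are already NumPy arrays, so they are joined with np.concatenate / np.sum and only the final batch result is wrapped in a Paddle tensor once, instead of building a tensor per chunk just to concatenate them again. A small self-contained sketch with invented shapes:

    import numpy as np

    # Invented shapes: 3 chunks of probs, each (batch=1, time=5, vocab=10).
    probs_chunk_list = [np.ones((1, 5, 10), dtype=np.float32) for _ in range(3)]
    probs_chunk_lens_list = [np.array([5]) for _ in range(3)]

    output_probs = np.concatenate(probs_chunk_list, axis=1)  # (1, 15, 10)
    output_lens = np.sum(probs_chunk_lens_list, axis=0)      # array([15])

    # Pad the time axis the way the diff pads each utterance to the batch max.
    padding = np.zeros((1, 4, output_probs.shape[2]), dtype=np.float32)
    output_probs = np.concatenate([output_probs, padding], axis=1)  # (1, 19, 10)

    # Only here would the result cross back into Paddle, once per batch:
    # output_probs_branch = paddle.to_tensor(output_probs)
    print(output_probs.shape, output_lens)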
