From 314886d4c5612eade0cf8bf6d28edb6824215be0 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 19 Feb 2021 06:46:53 +0000 Subject: [PATCH] test with decoding --- model_utils/model.py | 73 +++++++++++++++++++++++++++++++++++++++--- model_utils/network.py | 18 ++++++++--- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/model_utils/model.py b/model_utils/model.py index a063e4aae..b880510da 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -39,6 +39,8 @@ from decoders.swig_wrapper import Scorer from decoders.swig_wrapper import ctc_greedy_decoder from decoders.swig_wrapper import ctc_beam_search_decoder_batch +from utils.error_rate import char_errors, word_errors + class DeepSpeech2Trainer(Trainer): def __init__(self, config, args): @@ -263,24 +265,83 @@ class DeepSpeech2Tester(Trainer): loss = self.criterion(logits, texts, logits_len, texts_len) return loss + def id2token(self, texts, texts_len, vocab_list): + trans = [] + for text, n in zip(texts, texts_len): + n = n.numpy().item() + ids = text[:n] + trans.append(''.join([vocab_list[i] for i in ids])) + return np.array(trans) + def compute_metrics(self, inputs, outputs): + cfg = self.config.decoding + _, texts, _, texts_len = inputs - logits, _, logits_len = outputs - pass + logits, probs, logits_len = outputs + + errors_sum, len_refs, num_ins = 0.0, 0, 0 + errors_func = char_errors if cfg.error_rate_type == 'cer' else word_errors + + vocab_list = self.test_loader.dataset.vocab_list + target_transcripts = self.id2token(texts, texts_len, vocab_list) + result_transcripts = self.model.decode_probs( + probs.numpy(), + vocab_list, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch) + + for target, result in zip(target_transcripts, result_transcripts): + errors, len_ref = errors_func(target, result) + errors_sum += errors + len_refs += len_ref + num_ins += 1 + + return dict( + errors_sum=errors_sum, + len_refs=len_refs, + num_ins=num_ins, + error_rate=errors_sum / len_refs, + error_rate_type=cfg.error_rate_type) @mp_tools.rank_zero_only @paddle.no_grad() def test(self): self.model.eval() losses = defaultdict(list) + + cfg = self.config + # decoders only accept string encoded in utf-8 + vocab_list = self.test_loader.dataset.vocab_list + self.model.init_decode( + beam_alpha=cfg.decoding.alpha, + beam_beta=cfg.decoding.beta, + lang_model_path=cfg.decoding.lang_model_path, + vocab_list=vocab_list, + decoding_method=cfg.decoding.decoding_method) + + error_rate_type = None + errors_sum, len_refs, num_ins = 0.0, 0, 0 + for i, batch in enumerate(self.test_loader): audio, text, audio_len, text_len = batch outputs = self.model.predict(audio, audio_len) loss = self.compute_losses(batch, outputs) - metrics = self.compute_metrics(batch, outputs) - losses['test_loss'].append(float(loss)) + metrics = self.compute_metrics(batch, outputs) + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + self.logger.info("Error rate [%s] (%d/?) = %f" % + (error_rate_type, num_ins, errors_sum / len_refs)) + # write visual log losses = {k: np.mean(v) for k, v in losses.items()} @@ -289,6 +350,8 @@ class DeepSpeech2Tester(Trainer): msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items()) + msg += ", Final error rate [%s] (%d/%d) = %f" % ( + error_rate_type, num_ins, num_ins, errors_sum / len_refs) self.logger.info(msg) def setup(self): @@ -359,7 +422,7 @@ class DeepSpeech2Tester(Trainer): collate_fn = SpeechCollator() self.test_loader = DataLoader( test_dataset, - batch_size=config.data.batch_size, + batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, collate_fn=collate_fn) diff --git a/model_utils/network.py b/model_utils/network.py index e756996d5..2c310b855 100644 --- a/model_utils/network.py +++ b/model_utils/network.py @@ -646,11 +646,10 @@ class DeepSpeech2(nn.Layer): self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, vocab_list) - @paddle.no_grad() - def decode(self, audio, audio_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes): - _, probs, _ = self.predict(audio, audio_len) + def decode_probs(self, probs, vocab_list, decoding_method, lang_model_path, + beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes): + """ probs: activation after softmax """ if decoding_method == "ctc_greedy": result_transcripts = self._decode_batch_greedy( probs_split=probs, vocab_list=vocab_list) @@ -668,6 +667,15 @@ class DeepSpeech2(nn.Layer): raise ValueError(f"Not support: {decoding_method}") return result_transcripts + @paddle.no_grad() + def decode(self, audio, audio_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes): + _, probs, _ = self.predict(audio, audio_len) + return self.decode_probs( + probs, vocab_list, decoding_method, lang_model_path, beam_alpha, + beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes) + class DeepSpeech2Loss(nn.Layer): def __init__(self, vocab_size):