diff --git a/lm/lm_scorer.py b/lm/lm_scorer.py index 1c029e97f..de41754f9 100644 --- a/lm/lm_scorer.py +++ b/lm/lm_scorer.py @@ -42,6 +42,11 @@ class LmScorer(object): words = sentence.strip().split(' ') return len(words) + # reset alpha and beta + def reset_params(self, alpha, beta): + self._alpha = alpha + self._beta = beta + # execute evaluation def __call__(self, sentence, log=False): """Evaluation function, gathering all the different scores diff --git a/tests/test_decoders.py b/tests/test_decoders.py index a5e19b08b..99d8a8289 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -76,7 +76,7 @@ class TestDecoders(unittest.TestCase): blank_id=len(self.vocab_list)) self.assertEqual(beam_result[0][1], self.beam_search_result[1]) - def test_beam_search_nproc_decoder(self): + def test_beam_search_decoder_batch(self): beam_results = ctc_beam_search_decoder_batch( probs_split=[self.probs_seq1, self.probs_seq2], beam_size=self.beam_size, diff --git a/tune.py b/tune.py index 9cea66b90..e26bc45ce 100644 --- a/tune.py +++ b/tune.py @@ -12,6 +12,7 @@ from model import deep_speech2 from decoder import * from lm.lm_scorer import LmScorer from error_rate import wer +import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -180,10 +181,13 @@ def tune(): params_grid = [(alpha, beta) for alpha in cand_alphas for beta in cand_betas] + ext_scorer = LmScorer(args.alpha_from, args.beta_from, + args.language_model_path) ## tune parameters in loop - for (alpha, beta) in params_grid: + for alpha, beta in params_grid: wer_sum, wer_counter = 0, 0 - ext_scorer = LmScorer(alpha, beta, args.language_model_path) + # reset scorer + ext_scorer.reset_params(alpha, beta) # beam search using multiple processes beam_search_results = ctc_beam_search_decoder_batch( probs_split=probs_split,