enable resetting params in scorer

pull/2/head
Yibing Liu 8 years ago
parent 90d83bf739
commit 37e98df74d

@ -42,6 +42,11 @@ class LmScorer(object):
words = sentence.strip().split(' ') words = sentence.strip().split(' ')
return len(words) return len(words)
# reset alpha and beta
def reset_params(self, alpha, beta):
self._alpha = alpha
self._beta = beta
# execute evaluation # execute evaluation
def __call__(self, sentence, log=False): def __call__(self, sentence, log=False):
"""Evaluation function, gathering all the different scores """Evaluation function, gathering all the different scores

@ -76,7 +76,7 @@ class TestDecoders(unittest.TestCase):
blank_id=len(self.vocab_list)) blank_id=len(self.vocab_list))
self.assertEqual(beam_result[0][1], self.beam_search_result[1]) self.assertEqual(beam_result[0][1], self.beam_search_result[1])
def test_beam_search_nproc_decoder(self): def test_beam_search_decoder_batch(self):
beam_results = ctc_beam_search_decoder_batch( beam_results = ctc_beam_search_decoder_batch(
probs_split=[self.probs_seq1, self.probs_seq2], probs_split=[self.probs_seq1, self.probs_seq2],
beam_size=self.beam_size, beam_size=self.beam_size,

@ -12,6 +12,7 @@ from model import deep_speech2
from decoder import * from decoder import *
from lm.lm_scorer import LmScorer from lm.lm_scorer import LmScorer
from error_rate import wer from error_rate import wer
import utils
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
@ -180,10 +181,13 @@ def tune():
params_grid = [(alpha, beta) for alpha in cand_alphas params_grid = [(alpha, beta) for alpha in cand_alphas
for beta in cand_betas] for beta in cand_betas]
ext_scorer = LmScorer(args.alpha_from, args.beta_from,
args.language_model_path)
## tune parameters in loop ## tune parameters in loop
for (alpha, beta) in params_grid: for alpha, beta in params_grid:
wer_sum, wer_counter = 0, 0 wer_sum, wer_counter = 0, 0
ext_scorer = LmScorer(alpha, beta, args.language_model_path) # reset scorer
ext_scorer.reset_params(alpha, beta)
# beam search using multiple processes # beam search using multiple processes
beam_search_results = ctc_beam_search_decoder_batch( beam_search_results = ctc_beam_search_decoder_batch(
probs_split=probs_split, probs_split=probs_split,

Loading…
Cancel
Save