final refinement of old data provider: enable pruning, add evaluation, and clean up code

pull/2/head
Yibing Liu 7 years ago
parent a633eb9cc6
commit ff01d048d3

@ -5,7 +5,6 @@
import os import os
from itertools import groupby from itertools import groupby
import numpy as np import numpy as np
import copy
import kenlm import kenlm
import multiprocessing import multiprocessing
@ -73,11 +72,25 @@ class Scorer(object):
return len(words) return len(words)
# execute evaluation # execute evaluation
def __call__(self, sentence): def __call__(self, sentence, log=False):
"""
Evaluation function
:param sentence: The input sentence for evalutation
:type sentence: basestring
:param log: Whether return the score in log representation.
:type log: bool
:return: Evaluation score, in the decimal or log.
:rtype: float
"""
lm = self.language_model_score(sentence) lm = self.language_model_score(sentence)
word_cnt = self.word_count(sentence) word_cnt = self.word_count(sentence)
if log == False:
score = np.power(lm, self._alpha) \ score = np.power(lm, self._alpha) \
* np.power(word_cnt, self._beta) * np.power(word_cnt, self._beta)
else:
score = self._alpha * np.log(lm) \
+ self._beta * np.log(word_cnt)
return score return score
@ -85,13 +98,14 @@ def ctc_beam_search_decoder(probs_seq,
beam_size, beam_size,
vocabulary, vocabulary,
blank_id=0, blank_id=0,
cutoff_prob=1.0,
ext_scoring_func=None, ext_scoring_func=None,
nproc=False): nproc=False):
''' '''
Beam search decoder for CTC-trained network, using beam search with width Beam search decoder for CTC-trained network, using beam search with width
beam_size to find many paths to one label, return beam_size labels in beam_size to find many paths to one label, return beam_size labels in
the order of probabilities. The implementation is based on Prefix Beam the descending order of probabilities. The implementation is based on Prefix
Search(https://arxiv.org/abs/1408.2873), and the unclear part is Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned, need to be verified. redesigned, need to be verified.
:param probs_seq: 2-D list with length num_time_steps, each element :param probs_seq: 2-D list with length num_time_steps, each element
@ -102,22 +116,25 @@ def ctc_beam_search_decoder(probs_seq,
:type beam_size: int :type beam_size: int
:param vocabulary: Vocabulary list. :param vocabulary: Vocabulary list.
:type vocabulary: list :type vocabulary: list
:param blank_id: ID of blank, default 0.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External defined scoring function for :param ext_scoring_func: External defined scoring function for
partially decoded sentence, e.g. word count partially decoded sentence, e.g. word count
and language model. and language model.
:type external_scoring_function: function :type external_scoring_function: function
:param blank_id: id of blank, default 0.
:type blank_id: int
:param nproc: Whether the decoder used in multiprocesses. :param nproc: Whether the decoder used in multiprocesses.
:type nproc: bool :type nproc: bool
:return: Decoding log probability and result string. :return: Decoding log probabilities and result sentences in descending order.
:rtype: list :rtype: list
''' '''
# dimension check # dimension check
for prob_list in probs_seq: for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1: if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary") raise ValueError("probs dimension mismatched with vocabulary")
num_time_steps = len(probs_seq) num_time_steps = len(probs_seq)
# blank_id check # blank_id check
@ -137,19 +154,35 @@ def ctc_beam_search_decoder(probs_seq,
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0} probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
## extend prefix in loop ## extend prefix in loop
for time_step in range(num_time_steps): for time_step in xrange(num_time_steps):
# the set containing candidate prefixes # the set containing candidate prefixes
prefix_set_next = {} prefix_set_next = {}
probs_b_cur, probs_nb_cur = {}, {} probs_b_cur, probs_nb_cur = {}, {}
for l in prefix_set_prev:
prob = probs_seq[time_step] prob = probs_seq[time_step]
prob_idx = [[i, prob[i]] for i in xrange(len(prob))]
cutoff_len = len(prob_idx)
#If pruning is enabled
if (cutoff_prob < 1.0):
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len = 0
cum_prob = 0.0
for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
prob_idx = prob_idx[0:cutoff_len]
for l in prefix_set_prev:
if not prefix_set_next.has_key(l): if not prefix_set_next.has_key(l):
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0 probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
# extend prefix by travering vocabulary # extend prefix by travering prob_idx
for c in range(0, probs_dim): for index in xrange(cutoff_len):
c, prob_c = prob_idx[index][0], prob_idx[index][1]
if c == blank_id: if c == blank_id:
probs_b_cur[l] += prob[c] * ( probs_b_cur[l] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l]) probs_b_prev[l] + probs_nb_prev[l])
else: else:
last_char = l[-1] last_char = l[-1]
@ -159,18 +192,18 @@ def ctc_beam_search_decoder(probs_seq,
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0 probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
if new_char == last_char: if new_char == last_char:
probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l] probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
probs_nb_cur[l] += prob[c] * probs_nb_prev[l] probs_nb_cur[l] += prob_c * probs_nb_prev[l]
elif new_char == ' ': elif new_char == ' ':
if (ext_scoring_func is None) or (len(l) == 1): if (ext_scoring_func is None) or (len(l) == 1):
score = 1.0 score = 1.0
else: else:
prefix = l[1:] prefix = l[1:]
score = ext_scoring_func(prefix) score = ext_scoring_func(prefix)
probs_nb_cur[l_plus] += score * prob[c] * ( probs_nb_cur[l_plus] += score * prob_c * (
probs_b_prev[l] + probs_nb_prev[l]) probs_b_prev[l] + probs_nb_prev[l])
else: else:
probs_nb_cur[l_plus] += prob[c] * ( probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l]) probs_b_prev[l] + probs_nb_prev[l])
# add l_plus into prefix_set_next # add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[ prefix_set_next[l_plus] = probs_nb_cur[
@ -203,6 +236,7 @@ def ctc_beam_search_decoder_nproc(probs_split,
beam_size, beam_size,
vocabulary, vocabulary,
blank_id=0, blank_id=0,
cutoff_prob=1.0,
ext_scoring_func=None, ext_scoring_func=None,
num_processes=None): num_processes=None):
''' '''
@ -216,16 +250,19 @@ def ctc_beam_search_decoder_nproc(probs_split,
:type beam_size: int :type beam_size: int
:param vocabulary: Vocabulary list. :param vocabulary: Vocabulary list.
:type vocabulary: list :type vocabulary: list
:param blank_id: ID of blank, default 0.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External defined scoring function for :param ext_scoring_func: External defined scoring function for
partially decoded sentence, e.g. word count partially decoded sentence, e.g. word count
and language model. and language model.
:type external_scoring_function: function :type external_scoring_function: function
:param blank_id: id of blank, default 0.
:type blank_id: int
:param num_processes: Number of processes, default None, equal to the :param num_processes: Number of processes, default None, equal to the
number of CPUs. number of CPUs.
:type num_processes: int :type num_processes: int
:return: Decoding log probability and result string. :return: Decoding log probabilities and result sentences in descending order.
:rtype: list :rtype: list
''' '''
@ -243,7 +280,8 @@ def ctc_beam_search_decoder_nproc(probs_split,
pool = multiprocessing.Pool(processes=num_processes) pool = multiprocessing.Pool(processes=num_processes)
results = [] results = []
for i, probs_list in enumerate(probs_split): for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, None, nproc) args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args)) results.append(pool.apply_async(ctc_beam_search_decoder, args))
pool.close() pool.close()

@ -0,0 +1,214 @@
"""
Evaluation for a simplifed version of Baidu DeepSpeech2 model.
"""
import paddle.v2 as paddle
import distutils.util
import argparse
import gzip
from audio_data_utils import DataGenerator
from model import deep_speech2
from decoder import *
from error_rate import wer
# Command-line options for the evaluation run. Defaults target the
# LibriSpeech setup used elsewhere in this project (manifests, vocab,
# KenLM language model paths) — confirm paths before running.
parser = argparse.ArgumentParser(
    description='Simplified version of DeepSpeech2 evaluation.')
parser.add_argument(
    "--num_samples",
    default=100,
    type=int,
    help="Number of samples for evaluation. (default: %(default)s)")
parser.add_argument(
    "--num_conv_layers",
    default=2,
    type=int,
    help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
    "--num_rnn_layers",
    default=3,
    type=int,
    help="RNN layer number. (default: %(default)s)")
parser.add_argument(
    "--rnn_layer_size",
    default=512,
    type=int,
    help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--decode_method",
    default='beam_search_nproc',
    type=str,
    help="Method for ctc decoding, best_path, "
    "beam_search or beam_search_nproc. (default: %(default)s)")
parser.add_argument(
    "--language_model_path",
    default="./data/1Billion.klm",
    type=str,
    help="Path for language model. (default: %(default)s)")
parser.add_argument(
    "--alpha",
    default=0.26,
    type=float,
    help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
    "--beta",
    default=0.1,
    type=float,
    help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
    "--cutoff_prob",
    default=0.99,
    type=float,
    # BUG FIX: the two adjacent string literals previously concatenated
    # to "pruningin beam search" — a trailing space is required.
    help="The cutoff probability of pruning "
    "in beam search. (default: %(default)f)")
parser.add_argument(
    "--beam_size",
    default=500,
    type=int,
    help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
    "--normalizer_manifest_path",
    default='data/manifest.libri.train-clean-100',
    type=str,
    help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
    "--decode_manifest_path",
    default='data/manifest.libri.test-clean',
    type=str,
    help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
    "--model_filepath",
    default='./params.tar.gz',
    type=str,
    help="Model filepath. (default: %(default)s)")
parser.add_argument(
    "--vocab_filepath",
    default='data/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")
args = parser.parse_args()
def evaluate():
    """
    Evaluate on whole test data for DeepSpeech2.

    Runs inference batch by batch over the test manifest and accumulates
    the word error rate (WER) against the reference transcriptions, using
    the decoding method selected by ``args.decode_method`` (``best_path``,
    ``beam_search`` or ``beam_search_nproc``).

    :raises ValueError: If ``args.decode_method`` is not supported.
    """
    # initialize data generator
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        normalizer_manifest_path=args.normalizer_manifest_path,
        normalizer_num_samples=200,
        max_duration=20.0,
        min_duration=0.0,
        stride_ms=10,
        window_ms=20)

    # create network config
    dict_size = data_generator.vocabulary_size()
    vocab_list = data_generator.vocabulary_list()
    audio_data = paddle.layer.data(
        name="audio_spectrogram",
        height=161,
        width=2000,
        type=paddle.data_type.dense_vector(322000))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(dict_size))
    output_probs = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=dict_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size,
        is_inference=True)

    # load parameters
    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open(args.model_filepath))

    # prepare infer data
    # NOTE(review): `feeding` looks unused here — presumably kept for a
    # later inferer.infer(feeding=...) call; confirm before removing.
    feeding = data_generator.data_name_feeding()
    test_batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.decode_manifest_path,
        batch_size=args.num_samples,
        padding_to=2000,
        flatten=True,
        sort_by_duration=False,
        shuffle=False)

    # define inferer
    inferer = paddle.inference.Inference(
        output_layer=output_probs, parameters=parameters)

    # initialize external scorer only for the beam-search decoders
    if args.decode_method == 'beam_search' or \
            args.decode_method == 'beam_search_nproc':
        ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)

    wer_counter, wer_sum = 0, 0.0
    for infer_data in test_batch_reader():
        # run inference; results come back flattened, one row per time step
        infer_results = inferer.infer(input=infer_data)
        num_steps = len(infer_results) / len(infer_data)
        probs_split = [
            infer_results[i * num_steps:(i + 1) * num_steps]
            for i in xrange(0, len(infer_data))
        ]
        # best path (greedy) decode
        if args.decode_method == "best_path":
            for i, probs in enumerate(probs_split):
                output_transcription = ctc_decode(
                    probs_seq=probs, vocabulary=vocab_list, method="best_path")
                target_transcription = ''.join(
                    [vocab_list[index] for index in infer_data[i][1]])
                wer_sum += wer(target_transcription, output_transcription)
                wer_counter += 1
        # beam search decode in single process
        elif args.decode_method == "beam_search":
            for i, probs in enumerate(probs_split):
                target_transcription = ''.join(
                    [vocab_list[index] for index in infer_data[i][1]])
                beam_search_result = ctc_beam_search_decoder(
                    probs_seq=probs,
                    vocabulary=vocab_list,
                    beam_size=args.beam_size,
                    blank_id=len(vocab_list),
                    ext_scoring_func=ext_scorer,
                    cutoff_prob=args.cutoff_prob, )
                # results are sorted best-first: [0][1] is the top sentence
                wer_sum += wer(target_transcription, beam_search_result[0][1])
                wer_counter += 1
        # beam search using multiple processes
        elif args.decode_method == "beam_search_nproc":
            beam_search_nproc_results = ctc_beam_search_decoder_nproc(
                probs_split=probs_split,
                vocabulary=vocab_list,
                beam_size=args.beam_size,
                blank_id=len(vocab_list),
                ext_scoring_func=ext_scorer,
                cutoff_prob=args.cutoff_prob, )
            for i, beam_search_result in enumerate(beam_search_nproc_results):
                target_transcription = ''.join(
                    [vocab_list[index] for index in infer_data[i][1]])
                wer_sum += wer(target_transcription, beam_search_result[0][1])
                wer_counter += 1
        else:
            # BUG FIX: previously raised with undefined name `method`
            # (NameError masked the intended ValueError).
            raise ValueError("Decoding method [%s] is not supported." %
                             args.decode_method)
        print("Cur WER = %f" % (wer_sum / wer_counter))
    print("Final WER = %f" % (wer_sum / wer_counter))
def main():
    # Entry point: initialize PaddlePaddle (single trainer, GPU per flag)
    # and run the full evaluation pass.
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    evaluate()


if __name__ == '__main__':
    main()

@ -9,14 +9,14 @@ import gzip
from audio_data_utils import DataGenerator from audio_data_utils import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import * from decoder import *
import kenlm
from error_rate import wer from error_rate import wer
import time
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 inference.') description='Simplified version of DeepSpeech2 inference.')
parser.add_argument( parser.add_argument(
"--num_samples", "--num_samples",
default=10, default=100,
type=int, type=int,
help="Number of samples for inference. (default: %(default)s)") help="Number of samples for inference. (default: %(default)s)")
parser.add_argument( parser.add_argument(
@ -46,7 +46,7 @@ parser.add_argument(
help="Manifest path for normalizer. (default: %(default)s)") help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--decode_manifest_path", "--decode_manifest_path",
default='data/manifest.libri.test-clean', default='data/manifest.libri.test-100sample',
type=str, type=str,
help="Manifest path for decoding. (default: %(default)s)") help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument( parser.add_argument(
@ -63,11 +63,13 @@ parser.add_argument(
"--decode_method", "--decode_method",
default='beam_search_nproc', default='beam_search_nproc',
type=str, type=str,
help="Method for ctc decoding, best_path, beam_search or beam_search_nproc. (default: %(default)s)" help="Method for ctc decoding:"
) " best_path,"
" beam_search, "
" or beam_search_nproc. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--beam_size", "--beam_size",
default=50, default=500,
type=int, type=int,
help="Width for beam search decoding. (default: %(default)d)") help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument( parser.add_argument(
@ -82,14 +84,20 @@ parser.add_argument(
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--alpha", "--alpha",
default=0.0, default=0.26,
type=float, type=float,
help="Parameter associated with language model. (default: %(default)f)") help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument( parser.add_argument(
"--beta", "--beta",
default=0.0, default=0.1,
type=float, type=float,
help="Parameter associated with word count. (default: %(default)f)") help="Parameter associated with word count. (default: %(default)f)")
parser.add_argument(
"--cutoff_prob",
default=0.99,
type=float,
help="The cutoff probability of pruning"
"in beam search. (default: %(default)f)")
args = parser.parse_args() args = parser.parse_args()
@ -154,6 +162,7 @@ def infer():
## decode and print ## decode and print
# best path decode # best path decode
wer_sum, wer_counter = 0, 0 wer_sum, wer_counter = 0, 0
total_time = 0.0
if args.decode_method == "best_path": if args.decode_method == "best_path":
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
target_transcription = ''.join( target_transcription = ''.join(
@ -177,11 +186,12 @@ def infer():
probs_seq=probs, probs_seq=probs,
vocabulary=vocab_list, vocabulary=vocab_list,
beam_size=args.beam_size, beam_size=args.beam_size,
ext_scoring_func=ext_scorer, blank_id=len(vocab_list),
blank_id=len(vocab_list)) cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
print("\nTarget Transcription:\t%s" % target_transcription) print("\nTarget Transcription:\t%s" % target_transcription)
for index in range(args.num_results_per_sample): for index in xrange(args.num_results_per_sample):
result = beam_search_result[index] result = beam_search_result[index]
#output: index, log prob, beam result #output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1])) print("Beam %d: %f \t%s" % (index, result[0], result[1]))
@ -190,21 +200,21 @@ def infer():
wer_counter += 1 wer_counter += 1
print("cur wer = %f , average wer = %f" % print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter)) (wer_cur, wer_sum / wer_counter))
# beam search using multiple processes
elif args.decode_method == "beam_search_nproc": elif args.decode_method == "beam_search_nproc":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
beam_search_nproc_results = ctc_beam_search_decoder_nproc( beam_search_nproc_results = ctc_beam_search_decoder_nproc(
probs_split=probs_split, probs_split=probs_split,
vocabulary=vocab_list, vocabulary=vocab_list,
beam_size=args.beam_size, beam_size=args.beam_size,
ext_scoring_func=ext_scorer, blank_id=len(vocab_list),
blank_id=len(vocab_list)) cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
for i, beam_search_result in enumerate(beam_search_nproc_results): for i, beam_search_result in enumerate(beam_search_nproc_results):
target_transcription = ''.join( target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]]) [vocab_list[index] for index in infer_data[i][1]])
print("\nTarget Transcription:\t%s" % target_transcription) print("\nTarget Transcription:\t%s" % target_transcription)
for index in range(args.num_results_per_sample): for index in xrange(args.num_results_per_sample):
result = beam_search_result[index] result = beam_search_result[index]
#output: index, log prob, beam result #output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1])) print("Beam %d: %f \t%s" % (index, result[0], result[1]))

@ -1,5 +1,5 @@
""" """
Tune parameters for beam search decoder in Deep Speech 2. Parameters tuning for beam search decoder in Deep Speech 2.
""" """
import paddle.v2 as paddle import paddle.v2 as paddle
@ -12,7 +12,7 @@ from decoder import *
from error_rate import wer from error_rate import wer
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Parameters tuning script for ctc beam search decoder in Deep Speech 2.' description='Parameters tuning for ctc beam search decoder in Deep Speech 2.'
) )
parser.add_argument( parser.add_argument(
"--num_samples", "--num_samples",
@ -82,34 +82,40 @@ parser.add_argument(
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--alpha_from", "--alpha_from",
default=0.0, default=0.1,
type=float, type=float,
help="Where alpha starts from, <= alpha_to. (default: %(default)f)") help="Where alpha starts from. (default: %(default)f)")
parser.add_argument( parser.add_argument(
"--alpha_stride", "--num_alphas",
default=0.001, default=14,
type=float, type=int,
help="Step length for varying alpha. (default: %(default)f)") help="Number of candidate alphas. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--alpha_to", "--alpha_to",
default=0.01, default=0.36,
type=float, type=float,
help="Where alpha ends with, >= alpha_from. (default: %(default)f)") help="Where alpha ends with. (default: %(default)f)")
parser.add_argument( parser.add_argument(
"--beta_from", "--beta_from",
default=0.0, default=0.05,
type=float, type=float,
help="Where beta starts from, <= beta_to. (default: %(default)f)") help="Where beta starts from. (default: %(default)f)")
parser.add_argument( parser.add_argument(
"--beta_stride", "--num_betas",
default=0.01, default=20,
type=float, type=float,
help="Step length for varying beta. (default: %(default)f)") help="Number of candidate betas. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--beta_to", "--beta_to",
default=0.0, default=1.0,
type=float,
help="Where beta ends with. (default: %(default)f)")
parser.add_argument(
"--cutoff_prob",
default=0.99,
type=float, type=float,
help="Where beta ends with, >= beta_from. (default: %(default)f)") help="The cutoff probability of pruning"
"in beam search. (default: %(default)f)")
args = parser.parse_args() args = parser.parse_args()
@ -118,15 +124,11 @@ def tune():
Tune parameters alpha and beta on one minibatch. Tune parameters alpha and beta on one minibatch.
""" """
if not args.alpha_from <= args.alpha_to: if not args.num_alphas >= 0:
raise ValueError("alpha_from <= alpha_to doesn't satisfy!") raise ValueError("num_alphas must be non-negative!")
if not args.alpha_stride > 0:
raise ValueError("alpha_stride shouldn't be negative!")
if not args.beta_from <= args.beta_to: if not args.num_betas >= 0:
raise ValueError("beta_from <= beta_to doesn't satisfy!") raise ValueError("num_betas must be non-negative!")
if not args.beta_stride > 0:
raise ValueError("beta_stride shouldn't be negative!")
# initialize data generator # initialize data generator
data_generator = DataGenerator( data_generator = DataGenerator(
@ -171,6 +173,7 @@ def tune():
flatten=True, flatten=True,
sort_by_duration=False, sort_by_duration=False,
shuffle=False) shuffle=False)
# get one batch data for tuning
infer_data = test_batch_reader().next() infer_data = test_batch_reader().next()
# run inference # run inference
@ -182,11 +185,12 @@ def tune():
for i in xrange(0, len(infer_data)) for i in xrange(0, len(infer_data))
] ]
cand_alpha = np.arange(args.alpha_from, args.alpha_to + args.alpha_stride, # create grid for search
args.alpha_stride) cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
cand_beta = np.arange(args.beta_from, args.beta_to + args.beta_stride, cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
args.beta_stride) params_grid = [(alpha, beta) for alpha in cand_alphas
params_grid = [(alpha, beta) for alpha in cand_alpha for beta in cand_beta] for beta in cand_betas]
## tune parameters in loop ## tune parameters in loop
for (alpha, beta) in params_grid: for (alpha, beta) in params_grid:
wer_sum, wer_counter = 0, 0 wer_sum, wer_counter = 0, 0
@ -200,8 +204,9 @@ def tune():
probs_seq=probs, probs_seq=probs,
vocabulary=vocab_list, vocabulary=vocab_list,
beam_size=args.beam_size, beam_size=args.beam_size,
ext_scoring_func=ext_scorer, blank_id=len(vocab_list),
blank_id=len(vocab_list)) cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
wer_sum += wer(target_transcription, beam_search_result[0][1]) wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1 wer_counter += 1
# beam search using multiple processes # beam search using multiple processes
@ -210,9 +215,9 @@ def tune():
probs_split=probs_split, probs_split=probs_split,
vocabulary=vocab_list, vocabulary=vocab_list,
beam_size=args.beam_size, beam_size=args.beam_size,
ext_scoring_func=ext_scorer, cutoff_prob=args.cutoff_prob,
blank_id=len(vocab_list), blank_id=len(vocab_list),
num_processes=1) ext_scoring_func=ext_scorer, )
for i, beam_search_result in enumerate(beam_search_nproc_results): for i, beam_search_result in enumerate(beam_search_nproc_results):
target_transcription = ''.join( target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]]) [vocab_list[index] for index in infer_data[i][1]])

Loading…
Cancel
Save