Merge branch 'develop' of https://github.com/PaddlePaddle/models into ds2_pcloud
commit
bbe47a4318
@ -0,0 +1,212 @@
|
||||
"""Evaluation for DeepSpeech2 model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import distutils.util
|
||||
import argparse
|
||||
import gzip
|
||||
import paddle.v2 as paddle
|
||||
from data_utils.data import DataGenerator
|
||||
from model import deep_speech2
|
||||
from decoder import *
|
||||
from lm.lm_scorer import LmScorer
|
||||
from error_rate import wer
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
default=100,
|
||||
type=int,
|
||||
help="Minibatch size for evaluation. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_conv_layers",
|
||||
default=2,
|
||||
type=int,
|
||||
help="Convolution layer number. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_rnn_layers",
|
||||
default=3,
|
||||
type=int,
|
||||
help="RNN layer number. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--rnn_layer_size",
|
||||
default=512,
|
||||
type=int,
|
||||
help="RNN layer cell number. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--use_gpu",
|
||||
default=True,
|
||||
type=distutils.util.strtobool,
|
||||
help="Use gpu or not. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_threads_data",
|
||||
default=multiprocessing.cpu_count(),
|
||||
type=int,
|
||||
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_processes_beam_search",
|
||||
default=multiprocessing.cpu_count(),
|
||||
type=int,
|
||||
help="Number of cpu processes for beam search. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--mean_std_filepath",
|
||||
default='mean_std.npz',
|
||||
type=str,
|
||||
help="Manifest path for normalizer. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--decode_method",
|
||||
default='beam_search',
|
||||
type=str,
|
||||
help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language_model_path",
|
||||
default="lm/data/common_crawl_00.prune01111.trie.klm",
|
||||
type=str,
|
||||
help="Path for language model. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
default=0.26,
|
||||
type=float,
|
||||
help="Parameter associated with language model. (default: %(default)f)")
|
||||
parser.add_argument(
|
||||
"--beta",
|
||||
default=0.1,
|
||||
type=float,
|
||||
help="Parameter associated with word count. (default: %(default)f)")
|
||||
parser.add_argument(
|
||||
"--cutoff_prob",
|
||||
default=0.99,
|
||||
type=float,
|
||||
help="The cutoff probability of pruning"
|
||||
"in beam search. (default: %(default)f)")
|
||||
parser.add_argument(
|
||||
"--beam_size",
|
||||
default=500,
|
||||
type=int,
|
||||
help="Width for beam search decoding. (default: %(default)d)")
|
||||
parser.add_argument(
|
||||
"--specgram_type",
|
||||
default='linear',
|
||||
type=str,
|
||||
help="Feature type of audio data: 'linear' (power spectrum)"
|
||||
" or 'mfcc'. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--decode_manifest_path",
|
||||
default='datasets/manifest.test',
|
||||
type=str,
|
||||
help="Manifest path for decoding. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--model_filepath",
|
||||
default='checkpoints/params.latest.tar.gz',
|
||||
type=str,
|
||||
help="Model filepath. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--vocab_filepath",
|
||||
default='datasets/vocab/eng_vocab.txt',
|
||||
type=str,
|
||||
help="Vocabulary filepath. (default: %(default)s)")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def evaluate():
    """Evaluate on whole test data for DeepSpeech2.

    Builds the inference network, loads trained parameters, decodes every
    batch of the test manifest with the configured decoding method
    (``best_path`` or ``beam_search``) and prints the averaged WER.
    """
    # initialize data generator
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=args.num_threads_data)

    # create network config
    # paddle.data_type.dense_array is used for variable batch input.
    # The size 161 * 161 is only an placeholder value and the real shape
    # of input batch data will be induced during training.
    audio_data = paddle.layer.data(
        name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
    output_probs = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size,
        is_inference=True)

    # load parameters
    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open(args.model_filepath))

    # prepare infer data
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.decode_manifest_path,
        batch_size=args.batch_size,
        min_batch_size=1,
        sortagrad=False,
        shuffle_method=None)

    # define inferer
    inferer = paddle.inference.Inference(
        output_layer=output_probs, parameters=parameters)

    # initialize external scorer for beam search decoding
    if args.decode_method == 'beam_search':
        ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path)

    wer_counter, wer_sum = 0, 0.0
    for infer_data in batch_reader():
        # run inference; results come back flattened, one time-step row per
        # entry, so split them back into per-utterance probability sequences
        infer_results = inferer.infer(input=infer_data)
        num_steps = len(infer_results) // len(infer_data)
        probs_split = [
            infer_results[i * num_steps:(i + 1) * num_steps]
            for i in range(len(infer_data))
        ]
        # target transcription (ground truth text from the batch labels)
        target_transcription = [
            ''.join([
                data_generator.vocab_list[index] for index in infer_data[i][1]
            ]) for i, probs in enumerate(probs_split)
        ]
        # decode and accumulate WER
        if args.decode_method == "best_path":
            for i, probs in enumerate(probs_split):
                output_transcription = ctc_best_path_decoder(
                    probs_seq=probs, vocabulary=data_generator.vocab_list)
                wer_sum += wer(target_transcription[i], output_transcription)
                wer_counter += 1
        elif args.decode_method == "beam_search":
            # beam search using multiple processes
            beam_search_results = ctc_beam_search_decoder_batch(
                probs_split=probs_split,
                vocabulary=data_generator.vocab_list,
                beam_size=args.beam_size,
                blank_id=len(data_generator.vocab_list),
                num_processes=args.num_processes_beam_search,
                ext_scoring_func=ext_scorer,
                cutoff_prob=args.cutoff_prob, )
            for i, beam_search_result in enumerate(beam_search_results):
                wer_sum += wer(target_transcription[i],
                               beam_search_result[0][1])
                wer_counter += 1
        else:
            # was `decode_method` (undefined name) -- must read from args
            raise ValueError("Decoding method [%s] is not supported." %
                             args.decode_method)

    print("Final WER = %f" % (wer_sum / wer_counter))
|
||||
|
||||
|
||||
def main():
    """Entry point: initialize PaddlePaddle (single trainer) and run evaluation."""
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    evaluate()
|
||||
|
||||
|
||||
# Run evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
@ -0,0 +1,68 @@
|
||||
"""External Scorer for Beam Search Decoder."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import kenlm
|
||||
import numpy as np
|
||||
|
||||
|
||||
class LmScorer(object):
    """External scorer to evaluate a prefix or whole sentence in
    beam search decoding, including the score from n-gram language
    model and word count.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    :raises IOError: If the language model file does not exist.
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        if not os.path.isfile(model_path):
            # typo fixed: "Invaid" -> "Invalid"
            raise IOError("Invalid language model path: %s" % model_path)
        self._language_model = kenlm.LanguageModel(model_path)

    # n-gram language model scoring
    def _language_model_score(self, sentence):
        """Return P(last word | history) for *sentence* (decimal scale)."""
        # full_scores yields (log10 prob, ngram length, oov) per word;
        # take the log10 conditional probability of the last word only.
        log_cond_prob = list(
            self._language_model.full_scores(sentence, eos=False))[-1][0]
        return np.power(10, log_cond_prob)

    # word insertion term
    def _word_count(self, sentence):
        """Return the number of space-separated words in *sentence*."""
        words = sentence.strip().split(' ')
        return len(words)

    # reset alpha and beta
    def reset_params(self, alpha, beta):
        """Update the language-model weight and word-count weight in place."""
        self._alpha = alpha
        self._beta = beta

    # execute evaluation
    def __call__(self, sentence, log=False):
        """Evaluation function, gathering all the different scores
        and return the final one.

        :param sentence: The input sentence for evalutation
        :type sentence: basestring
        :param log: Whether return the score in log representation.
        :type log: bool
        :return: Evaluation score, in the decimal or log.
        :rtype: float
        """
        lm = self._language_model_score(sentence)
        word_cnt = self._word_count(sentence)
        if not log:
            score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta)
        else:
            score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt)
        return score
|
@ -0,0 +1,19 @@
|
||||
echo "Downloading language model ..."
|
||||
|
||||
mkdir data
|
||||
|
||||
LM=common_crawl_00.prune01111.trie.klm
|
||||
MD5="099a601759d467cd0a8523ff939819c5"
|
||||
|
||||
wget -c http://paddlepaddle.bj.bcebos.com/model_zoo/speech/$LM -P ./data
|
||||
|
||||
echo "Checking md5sum ..."
|
||||
md5_tmp=`md5sum ./data/$LM | awk -F[' '] '{print $1}'`
|
||||
|
||||
if [ $MD5 != $md5_tmp ]; then
|
||||
echo "Fail to download the language model!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
wget==3.2
|
||||
scipy==0.13.1
|
||||
resampy==0.1.5
|
||||
https://github.com/kpu/kenlm/archive/master.zip
|
||||
python_speech_features
|
||||
|
@ -0,0 +1,91 @@
|
||||
"""Test decoders."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
from decoder import *
|
||||
|
||||
|
||||
class TestDecoders(unittest.TestCase):
    """Unit tests for CTC best-path and beam-search decoders.

    Uses two hand-crafted 6-step probability sequences over a 6-symbol
    vocabulary (plus the implicit blank at index len(vocab_list)), with
    known decoding results for both decoders.
    """

    def setUp(self):
        # Vocabulary of 6 symbols; blank id is len(vocab_list) == 6, so each
        # probability row below has 7 entries (symbols + blank).
        self.vocab_list = ["\'", ' ', 'a', 'b', 'c', 'd']
        self.beam_size = 20
        # Fixture 1: rows are per-time-step distributions over vocab + blank.
        self.probs_seq1 = [[
            0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254,
            0.18184413, 0.16493624
        ], [
            0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462,
            0.0094893, 0.06890021
        ], [
            0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535,
            0.08424043, 0.08120984
        ], [
            0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305,
            0.05206269, 0.09772094
        ], [
            0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985,
            0.41317442, 0.01946335
        ], [
            0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937,
            0.04377724, 0.01457421
        ]]
        # Fixture 2: second independent sequence for batch decoding tests.
        self.probs_seq2 = [[
            0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441,
            0.04468023, 0.10903471
        ], [
            0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123,
            0.10219457, 0.20640612
        ], [
            0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316,
            0.12298585, 0.01654384
        ], [
            0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055,
            0.22538587, 0.13483174
        ], [
            0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313,
            0.07113197, 0.04139363
        ], [
            0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306,
            0.05294827, 0.22298418
        ]]
        # Expected transcriptions for each fixture, per decoder.
        self.best_path_result = ["ac'bdc", "b'da"]
        self.beam_search_result = ['acdc', "b'a"]

    def test_best_path_decoder_1(self):
        """Best-path decoding of fixture 1 matches the known result."""
        bst_result = ctc_best_path_decoder(self.probs_seq1, self.vocab_list)
        self.assertEqual(bst_result, self.best_path_result[0])

    def test_best_path_decoder_2(self):
        """Best-path decoding of fixture 2 matches the known result."""
        bst_result = ctc_best_path_decoder(self.probs_seq2, self.vocab_list)
        self.assertEqual(bst_result, self.best_path_result[1])

    def test_beam_search_decoder_1(self):
        """Top beam-search hypothesis for fixture 1 matches the known result."""
        beam_result = ctc_beam_search_decoder(
            probs_seq=self.probs_seq1,
            beam_size=self.beam_size,
            vocabulary=self.vocab_list,
            blank_id=len(self.vocab_list))
        self.assertEqual(beam_result[0][1], self.beam_search_result[0])

    def test_beam_search_decoder_2(self):
        """Top beam-search hypothesis for fixture 2 matches the known result."""
        beam_result = ctc_beam_search_decoder(
            probs_seq=self.probs_seq2,
            beam_size=self.beam_size,
            vocabulary=self.vocab_list,
            blank_id=len(self.vocab_list))
        self.assertEqual(beam_result[0][1], self.beam_search_result[1])

    def test_beam_search_decoder_batch(self):
        """Batch (multi-process) decoding agrees with single-sequence results."""
        beam_results = ctc_beam_search_decoder_batch(
            probs_split=[self.probs_seq1, self.probs_seq2],
            beam_size=self.beam_size,
            vocabulary=self.vocab_list,
            blank_id=len(self.vocab_list),
            num_processes=24)
        self.assertEqual(beam_results[0][0][1], self.beam_search_result[0])
        self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])
|
||||
|
||||
|
||||
# Run the test suite only when executed as a script, not on import.
if __name__ == '__main__':
    unittest.main()
|
@ -0,0 +1,224 @@
|
||||
"""Parameters tuning for DeepSpeech2 model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import distutils.util
|
||||
import argparse
|
||||
import gzip
|
||||
import paddle.v2 as paddle
|
||||
from data_utils.data import DataGenerator
|
||||
from model import deep_speech2
|
||||
from decoder import *
|
||||
from lm.lm_scorer import LmScorer
|
||||
from error_rate import wer
|
||||
import utils
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--num_samples",
|
||||
default=100,
|
||||
type=int,
|
||||
help="Number of samples for parameters tuning. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_conv_layers",
|
||||
default=2,
|
||||
type=int,
|
||||
help="Convolution layer number. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_rnn_layers",
|
||||
default=3,
|
||||
type=int,
|
||||
help="RNN layer number. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--rnn_layer_size",
|
||||
default=512,
|
||||
type=int,
|
||||
help="RNN layer cell number. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--use_gpu",
|
||||
default=True,
|
||||
type=distutils.util.strtobool,
|
||||
help="Use gpu or not. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_threads_data",
|
||||
default=multiprocessing.cpu_count(),
|
||||
type=int,
|
||||
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_processes_beam_search",
|
||||
default=multiprocessing.cpu_count(),
|
||||
type=int,
|
||||
help="Number of cpu processes for beam search. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--specgram_type",
|
||||
default='linear',
|
||||
type=str,
|
||||
help="Feature type of audio data: 'linear' (power spectrum)"
|
||||
" or 'mfcc'. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--mean_std_filepath",
|
||||
default='mean_std.npz',
|
||||
type=str,
|
||||
help="Manifest path for normalizer. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--decode_manifest_path",
|
||||
default='datasets/manifest.test',
|
||||
type=str,
|
||||
help="Manifest path for decoding. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--model_filepath",
|
||||
default='checkpoints/params.latest.tar.gz',
|
||||
type=str,
|
||||
help="Model filepath. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--vocab_filepath",
|
||||
default='datasets/vocab/eng_vocab.txt',
|
||||
type=str,
|
||||
help="Vocabulary filepath. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--beam_size",
|
||||
default=500,
|
||||
type=int,
|
||||
help="Width for beam search decoding. (default: %(default)d)")
|
||||
parser.add_argument(
|
||||
"--language_model_path",
|
||||
default="lm/data/common_crawl_00.prune01111.trie.klm",
|
||||
type=str,
|
||||
help="Path for language model. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--alpha_from",
|
||||
default=0.1,
|
||||
type=float,
|
||||
help="Where alpha starts from. (default: %(default)f)")
|
||||
parser.add_argument(
|
||||
"--num_alphas",
|
||||
default=14,
|
||||
type=int,
|
||||
help="Number of candidate alphas. (default: %(default)d)")
|
||||
parser.add_argument(
|
||||
"--alpha_to",
|
||||
default=0.36,
|
||||
type=float,
|
||||
help="Where alpha ends with. (default: %(default)f)")
|
||||
parser.add_argument(
|
||||
"--beta_from",
|
||||
default=0.05,
|
||||
type=float,
|
||||
help="Where beta starts from. (default: %(default)f)")
|
||||
parser.add_argument(
|
||||
"--num_betas",
|
||||
default=20,
|
||||
type=float,
|
||||
help="Number of candidate betas. (default: %(default)d)")
|
||||
parser.add_argument(
|
||||
"--beta_to",
|
||||
default=1.0,
|
||||
type=float,
|
||||
help="Where beta ends with. (default: %(default)f)")
|
||||
parser.add_argument(
|
||||
"--cutoff_prob",
|
||||
default=0.99,
|
||||
type=float,
|
||||
help="The cutoff probability of pruning"
|
||||
"in beam search. (default: %(default)f)")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def tune():
    """Tune parameters alpha and beta on one minibatch.

    Runs inference once on a single batch of ``num_samples`` utterances,
    then grid-searches (alpha, beta) over the configured ranges, printing
    the WER for every grid point.

    :raises ValueError: If num_alphas or num_betas is negative.
    """
    if not args.num_alphas >= 0:
        raise ValueError("num_alphas must be non-negative!")

    if not args.num_betas >= 0:
        raise ValueError("num_betas must be non-negative!")

    # initialize data generator
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=args.num_threads_data)

    # create network config
    # paddle.data_type.dense_array is used for variable batch input.
    # The size 161 * 161 is only an placeholder value and the real shape
    # of input batch data will be induced during training.
    audio_data = paddle.layer.data(
        name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
    output_probs = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size,
        is_inference=True)

    # load parameters
    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open(args.model_filepath))

    # prepare infer data
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.decode_manifest_path,
        batch_size=args.num_samples,
        sortagrad=False,
        shuffle_method=None)
    # get one batch data for tuning (next() builtin instead of the
    # Python-2-only .next() method)
    infer_data = next(batch_reader())

    # run inference once; decoding below reuses these probabilities for
    # every (alpha, beta) grid point
    infer_results = paddle.infer(
        output_layer=output_probs, parameters=parameters, input=infer_data)
    num_steps = len(infer_results) // len(infer_data)
    probs_split = [
        infer_results[i * num_steps:(i + 1) * num_steps]
        for i in range(len(infer_data))
    ]

    # create grid for search
    cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    params_grid = [(alpha, beta) for alpha in cand_alphas
                   for beta in cand_betas]

    ext_scorer = LmScorer(args.alpha_from, args.beta_from,
                          args.language_model_path)
    ## tune parameters in loop
    for alpha, beta in params_grid:
        wer_sum, wer_counter = 0, 0
        # reset scorer
        ext_scorer.reset_params(alpha, beta)
        # beam search using multiple processes
        beam_search_results = ctc_beam_search_decoder_batch(
            probs_split=probs_split,
            vocabulary=data_generator.vocab_list,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            blank_id=len(data_generator.vocab_list),
            num_processes=args.num_processes_beam_search,
            ext_scoring_func=ext_scorer, )
        for i, beam_search_result in enumerate(beam_search_results):
            target_transcription = ''.join([
                data_generator.vocab_list[index] for index in infer_data[i][1]
            ])
            wer_sum += wer(target_transcription, beam_search_result[0][1])
            wer_counter += 1

        print("alpha = %f\tbeta = %f\tWER = %f" %
              (alpha, beta, wer_sum / wer_counter))
|
||||
|
||||
|
||||
def main():
    """Entry point: initialize PaddlePaddle (single trainer) and run tuning."""
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    tune()
|
||||
|
||||
|
||||
# Run parameter tuning only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
Loading…
Reference in new issue