You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
192 lines
7.5 KiB
192 lines
7.5 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Beam search parameters tuning for DeepSpeech2 model."""
|
|
import functools
|
|
import sys
|
|
|
|
import numpy as np
|
|
from paddle.io import DataLoader
|
|
|
|
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
|
|
from deepspeech.io.collator import SpeechCollator
|
|
from deepspeech.io.dataset import ManifestDataset
|
|
from deepspeech.models.deepspeech2 import DeepSpeech2Model
|
|
from deepspeech.training.cli import default_argument_parser
|
|
from deepspeech.utils import error_rate
|
|
from deepspeech.utils.utility import add_arguments
|
|
from deepspeech.utils.utility import print_arguments
|
|
|
|
|
|
def _ordid2token(texts, texts_len):
    """Turn batches of character ids back into transcript strings.

    Each entry of ``texts`` is truncated to the matching length taken from
    ``texts_len`` and every remaining id is mapped through ``chr()``.

    Args:
        texts: Iterable of id sequences, one per utterance.
        texts_len: Iterable of lengths; every item must support
            ``.numpy().item()`` (presumably framework tensors produced by the
            collator -- confirm against the data pipeline).

    Returns:
        list: One decoded string per utterance.
    """
    return [
        ''.join(chr(i) for i in text[:n.numpy().item()])
        for text, n in zip(texts, texts_len)
    ]


def tune(config, args):
    """Tune parameters alpha and beta incrementally.

    A grid of ``num_alphas`` x ``num_betas`` candidates is built with
    ``np.linspace``. Every batch of the dev set is run through the model once
    and then decoded once per grid point; the error counts returned by the
    CER/WER function are accumulated per grid point and divided by the total
    reference length seen so far. The current best point is printed after
    every batch, and the full table plus the final optimum at the end.

    Args:
        config: Configuration node. It is defrosted and modified in place
            (``data.manifest``, ``data.augmentation_config``,
            ``data.keep_transcription_text``) and is not re-frozen here.
        args: Parsed command-line namespace. Reads ``num_alphas``,
            ``num_betas``, ``alpha_from``, ``alpha_to``, ``beta_from``,
            ``beta_to``, ``num_batches`` and ``checkpoint_path``.

    Raises:
        ValueError: If ``num_alphas`` or ``num_betas`` is negative, or if
            either is zero (the search grid would be empty).
    """
    if not args.num_alphas >= 0:
        raise ValueError("num_alphas must be non-negative!")
    if not args.num_betas >= 0:
        raise ValueError("num_betas must be non-negative!")
    # An empty grid used to fail only at min([]) after the model had been
    # loaded; reject it before doing any expensive work.
    if args.num_alphas == 0 or args.num_betas == 0:
        raise ValueError(
            "num_alphas and num_betas must both be positive, "
            "otherwise the search grid is empty!")

    config.defrost()
    # Select the dev manifest for the dataset. The key was previously
    # misspelled ("manfiest"), which on a defrosted config only creates a new,
    # unused entry -- assumes the dataset reads `config.data.manifest`.
    config.data.manifest = config.data.dev_manifest
    config.data.augmentation_config = ""
    config.data.keep_transcription_text = True
    dev_dataset = ManifestDataset.from_config(config)

    valid_loader = DataLoader(
        dev_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=SpeechCollator(keep_transcription_text=True))

    model = DeepSpeech2Model.from_pretrained(valid_loader, config,
                                             args.checkpoint_path)
    model.eval()

    # decoders only accept string encoded in utf-8
    vocab_list = valid_loader.dataset.vocab_list
    if config.decoding.error_rate_type == 'cer':
        errors_func = error_rate.char_errors
    else:
        errors_func = error_rate.word_errors

    # create grid for search
    cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    params_grid = [(alpha, beta) for alpha in cand_alphas
                   for beta in cand_betas]

    # Per grid point: accumulated error count and running average error rate.
    err_sum = [0.0] * len(params_grid)
    err_ave = [0.0] * len(params_grid)

    num_ins, len_refs, cur_batch = 0, 0, 0
    # initialize external scorer
    model.decoder.init_decode(args.alpha_from, args.beta_from,
                              config.decoding.lang_model_path, vocab_list,
                              config.decoding.decoding_method)
    # incremental tuning parameters over multiple batches
    print("start tuning ...")
    for infer_data in valid_loader():
        # A negative num_batches means "no limit".
        if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
            break

        audio, audio_len, text, text_len = infer_data
        target_transcripts = _ordid2token(text, text_len)
        num_ins += audio.shape[0]

        # model infer
        eouts, eouts_len = model.encoder(audio, audio_len)
        probs = model.decoder.softmax(eouts)
        # The probabilities do not depend on (alpha, beta): convert them once
        # per batch instead of once per grid point.
        probs_np = probs.numpy()

        # grid search
        for index, (alpha, beta) in enumerate(params_grid):
            print(f"tuneing: alpha={alpha} beta={beta}")
            result_transcripts = model.decoder.decode_probs(
                probs_np, eouts_len, vocab_list,
                config.decoding.decoding_method,
                config.decoding.lang_model_path, alpha, beta,
                config.decoding.beam_size, config.decoding.cutoff_prob,
                config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch)

            for target, result in zip(target_transcripts, result_transcripts):
                errors, len_ref = errors_func(target, result)
                err_sum[index] += errors

                # Accumulate the reference lengths of every batch exactly
                # once, on the first grid point. np.linspace always starts at
                # `start`, so index 0 is (alpha_from, beta_from); testing the
                # index avoids counting the lengths several times when the
                # grid contains duplicate points.
                if index == 0:
                    len_refs += len_ref

            err_ave[index] = err_sum[index] / len_refs
            if index % 2 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()
            print("tuneing: one grid done!")

        # output on-line tuning result at the end of current batch
        err_ave_min = min(err_ave)
        min_index = err_ave.index(err_ave_min)
        print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
              " min [%s] = %f" %
              (cur_batch, num_ins, "%.3f" % params_grid[min_index][0],
               "%.3f" % params_grid[min_index][1],
               config.decoding.error_rate_type, err_ave_min))
        cur_batch += 1

    # output WER/CER at every (alpha, beta)
    print("\nFinal %s:\n" % config.decoding.error_rate_type)
    for (alpha, beta), err in zip(params_grid, err_ave):
        print("(alpha, beta) = (%s, %s), [%s] = %f" %
              ("%.3f" % alpha, "%.3f" % beta,
               config.decoding.error_rate_type, err))

    err_ave_min = min(err_ave)
    min_index = err_ave.index(err_ave_min)
    print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)" %
          (cur_batch, "%.3f" % params_grid[min_index][0],
           "%.3f" % params_grid[min_index][1]))

    print("finish tuning")
|
|
|
|
|
|
def main(config, args):
    """Script entry point: hand the configuration and CLI arguments to ``tune``.

    Args:
        config: Configuration object, forwarded unchanged to ``tune``.
        args: Parsed command-line arguments, forwarded unchanged to ``tune``.
    """
    tune(config=config, args=args)
|
|
|
|
|
|
if __name__ == "__main__":
    # The base parser comes from deepspeech.training.cli; the tuning-specific
    # flags below are registered on top of it.
    parser = default_argument_parser()
    add_arg = functools.partial(add_arguments, argparser=parser)

    # One (name, type, default, help) entry per extra flag, in display order.
    for flag_spec in (
        ('num_batches', int, -1,
         "# of batches tuning on. Default -1, on whole dev set."),
        ('num_alphas', int, 45, "# of alpha candidates for tuning."),
        ('num_betas', int, 8, "# of beta candidates for tuning."),
        ('alpha_from', float, 1.0, "Where alpha starts tuning from."),
        ('alpha_to', float, 3.2, "Where alpha ends tuning with."),
        ('beta_from', float, 0.1, "Where beta starts tuning from."),
        ('beta_to', float, 0.45, "Where beta ends tuning with."),
        ('batch_size', int, 256, "# of samples per batch."),
        ('beam_size', int, 500, "Beam search width."),
        ('num_proc_bsearch', int, 8, "# of CPUs for beam search."),
        ('cutoff_prob', float, 1.0, "Cutoff probability for pruning."),
        ('cutoff_top_n', int, 40, "Cutoff number for pruning."),
    ):
        add_arg(*flag_spec)

    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    # Layer the configuration: defaults, then the optional config file, then
    # the optional list of overrides.
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)

    # The command-line values (or their defaults) always replace whatever the
    # merged configuration held for these settings.
    config.data.batch_size = args.batch_size
    for option in ('beam_size', 'num_proc_bsearch', 'cutoff_prob',
                   'cutoff_top_n'):
        setattr(config.decoding, option, getattr(args, option))

    config.freeze()
    print(config)

    # Optionally write the effective configuration to the requested path.
    dump_path = args.dump_config
    if dump_path:
        with open(dump_path, 'w') as f:
            print(config, file=f)

    main(config, args)
|