# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Beam search parameters tuning for DeepSpeech2 model."""
import functools
import sys

import numpy as np
from paddle.io import DataLoader

from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils import error_rate
from deepspeech.utils.utility import add_arguments, print_arguments


def ordid2token(texts, texts_len):
    """Convert batches of ord() ids back to text via chr()."""
    trans = []
    for text, n in zip(texts, texts_len):
        n = n.numpy().item()
        ids = text[:n]
        trans.append(''.join([chr(i) for i in ids]))
    return trans


def tune(config, args):
    """Tune decoding parameters alpha and beta incrementally over batches."""
    if args.num_alphas < 0:
        raise ValueError("num_alphas must be non-negative!")
    if args.num_betas < 0:
        raise ValueError("num_betas must be non-negative!")

    dev_dataset = ManifestDataset(
        config.data.dev_manifest,
        config.data.vocab_filepath,
        config.data.mean_std_filepath,
        augmentation_config="{}",
        max_duration=config.data.max_duration,
        min_duration=config.data.min_duration,
        stride_ms=config.data.stride_ms,
        window_ms=config.data.window_ms,
        n_fft=config.data.n_fft,
        max_freq=config.data.max_freq,
        target_sample_rate=config.data.target_sample_rate,
        specgram_type=config.data.specgram_type,
        use_dB_normalization=config.data.use_dB_normalization,
        target_dB=config.data.target_dB,
        random_seed=config.data.random_seed,
        keep_transcription_text=True)

    valid_loader = DataLoader(
        dev_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=SpeechCollator(is_training=False))

    model = DeepSpeech2Model.from_pretrained(dev_dataset, config,
                                             args.checkpoint_path)
    model.eval()

    # decoders only accept strings encoded in utf-8
    vocab_list = valid_loader.dataset.vocab_list
    errors_func = (error_rate.char_errors
                   if config.decoding.error_rate_type == 'cer' else
                   error_rate.word_errors)

    # create the (alpha, beta) grid for search
    cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    params_grid = [(alpha, beta) for alpha in cand_alphas
                   for beta in cand_betas]

    err_sum = [0.0] * len(params_grid)
    err_ave = [0.0] * len(params_grid)

    num_ins, len_refs, cur_batch = 0, 0, 0
    # initialize the external scorer
    model.decoder.init_decode(args.alpha_from, args.beta_from,
                              config.decoding.lang_model_path, vocab_list,
                              config.decoding.decoding_method)

    # incrementally tune parameters over multiple batches
    print("start tuning ...")
    for infer_data in valid_loader:
        if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
            break

        audio, text, audio_len, text_len = infer_data
        target_transcripts = ordid2token(text, text_len)
        num_ins += audio.shape[0]

        # model inference: the encoder forward pass runs once per batch
        eouts, eouts_len = model.encoder(audio, audio_len)
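        # The CTC posteriors computed below are re-scored at every grid
        # point with the standard DeepSpeech2 decoding objective
        # (Amodei et al., 2016):
        #     Q(y) = log p_ctc(y|x) + alpha * log p_lm(y) + beta * word_count(y)
        # Alpha weights the external language model and beta rewards longer
        # transcriptions, so only the beam search is re-run per (alpha, beta)
        # while the expensive acoustic model forward pass is shared.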
        probs = model.decoder.probs(eouts)

        # grid search
        for index, (alpha, beta) in enumerate(params_grid):
            print(f"tuning: alpha={alpha} beta={beta}")
            result_transcripts = model.decoder.decode_probs(
                probs.numpy(), eouts_len, vocab_list,
                config.decoding.decoding_method,
                config.decoding.lang_model_path, alpha, beta,
                config.decoding.beam_size, config.decoding.cutoff_prob,
                config.decoding.cutoff_top_n,
                config.decoding.num_proc_bsearch)

            for target, result in zip(target_transcripts, result_transcripts):
                errors, len_ref = errors_func(target, result)
                err_sum[index] += errors
                # accumulate the length of references of every batch
                # in the first iteration only
                if args.alpha_from == alpha and args.beta_from == beta:
                    len_refs += len_ref

            err_ave[index] = err_sum[index] / len_refs
            if index % 2 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()
            print("tuning: one grid point done!")

        # output on-line tuning results at the end of the current batch
        err_ave_min = min(err_ave)
        min_index = err_ave.index(err_ave_min)
        print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
              " min [%s] = %f" %
              (cur_batch, num_ins, "%.3f" % params_grid[min_index][0],
               "%.3f" % params_grid[min_index][1],
               config.decoding.error_rate_type, err_ave_min))
        cur_batch += 1

    # output WER/CER at every (alpha, beta)
    print("\nFinal %s:\n" % config.decoding.error_rate_type)
    for index in range(len(params_grid)):
        print("(alpha, beta) = (%s, %s), [%s] = %f" %
              ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1],
               config.decoding.error_rate_type, err_ave[index]))

    err_ave_min = min(err_ave)
    min_index = err_ave.index(err_ave_min)
    print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)" %
          (cur_batch, "%.3f" % params_grid[min_index][0],
           "%.3f" % params_grid[min_index][1]))
    print("finish tuning")


def main(config, args):
    tune(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    add_arg = functools.partial(add_arguments, argparser=parser)
    add_arg('num_batches', int, -1, "# of batches to tune on. "
            "Default -1, tune on the whole dev set.")
    add_arg('num_alphas', int, 45, "# of alpha candidates for tuning.")
    add_arg('num_betas', int, 8, "# of beta candidates for tuning.")
    add_arg('alpha_from', float, 1.0, "Where alpha starts tuning from.")
    add_arg('alpha_to', float, 3.2, "Where alpha ends tuning with.")
    add_arg('beta_from', float, 0.1, "Where beta starts tuning from.")
    add_arg('beta_to', float, 0.45, "Where beta ends tuning with.")
    add_arg('batch_size', int, 256, "# of samples per batch.")
    add_arg('beam_size', int, 500, "Beam search width.")
    add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.")
    add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
    add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
    args = parser.parse_args()
    print_arguments(args)

    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.data.batch_size = args.batch_size
    config.decoding.beam_size = args.beam_size
    config.decoding.num_proc_bsearch = args.num_proc_bsearch
    config.decoding.cutoff_prob = args.cutoff_prob
    config.decoding.cutoff_top_n = args.cutoff_top_n
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
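
# Example invocation (a sketch, not a committed entry point: the script name,
# config path, and checkpoint path below are placeholders; --config,
# --checkpoint_path, --opts, and --dump_config are assumed to come from
# default_argument_parser()):
#
#   python tune.py \
#       --config conf/deepspeech2.yaml \
#       --checkpoint_path checkpoints/step_final \
#       --num_alphas 45 --num_betas 8 \
#       --alpha_from 1.0 --alpha_to 3.2 \
#       --beta_from 0.1 --beta_to 0.45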