test and infer with decoding

pull/522/head
Hui Zhang 5 years ago
parent 314886d4c5
commit 4111d9ea1a

@ -11,6 +11,7 @@
```
`run_data.sh` downloads the dataset, generates manifests, computes the normalizer's statistics, and builds the vocabulary. Once data preparation is done, you will find the data (only part of LibriSpeech) downloaded in `${MAIN_ROOT}/dataset/librispeech` and the corresponding manifest files generated in `${PWD}/data`, along with a mean/stddev file and a vocabulary file. It only has to be run the very first time you use this dataset; the results are reusable for all further experiments.
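Each manifest is a JSON-lines file describing one utterance per line. A minimal sketch of inspecting one, assuming the usual DeepSpeech manifest fields (`audio_filepath`, `duration`, `text`):

```python
import json

# Peek at the first few utterances of a generated manifest.
# Field names are an assumption based on the standard DeepSpeech
# manifest schema; adjust if the repo's data prep differs.
with open("data/manifest.test-clean") as f:
    for line in list(f)[:3]:
        utt = json.loads(line)
        print(utt["audio_filepath"], utt["duration"], utt["text"])
```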
- Train your own ASR model
```bash
@ -18,6 +19,7 @@
```
`run_train.sh` starts a training job, with training logs printed to stdout and a model checkpoint saved to `${PWD}/checkpoints` at every pass/epoch. These checkpoints can be used to resume training, and for inference, evaluation, and deployment.
- Case inference with an existing model
```bash
@ -29,6 +31,7 @@
```bash
sh local/run_infer_golden.sh
```
- Evaluate an existing model
```bash

@ -38,14 +38,14 @@ training:
  save_interval: 1000
  valid_interval: 1000
decoding:
  batch_size: 128
  alpha: 2.5
  beta: 0.3
  beam_size: 500
  cutoff_prob: 1.0
  cutoff_top_n: 40
  decoding_method: ctc_beam_search
  error_rate_type: wer
  lang_model_path: models/lm/common_crawl_00.prune01111.trie.klm
  num_proc_bsearch: 8
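For reference, a hedged sketch of how these values feed the CTC beam-search wrappers imported in `model_utils/model.py` below (`Scorer`, `ctc_beam_search_decoder_batch`); `probs_split` and `vocab_list` are assumed to come from the acoustic model and the dataset, and argument names follow the old flag-based `infer.py`:

```python
# Sketch only: maps the `decoding:` config onto the repo's decoders.
# `vocab_list` (list of characters) and `probs_split` (per-utterance
# softmax outputs) are assumed to be computed earlier.
from decoders.swig_wrapper import Scorer, ctc_beam_search_decoder_batch

scorer = Scorer(
    alpha=2.5,    # decoding.alpha: language-model weight
    beta=0.3,     # decoding.beta: word-insertion weight
    model_path="models/lm/common_crawl_00.prune01111.trie.klm",
    vocabulary=vocab_list)

beam_results = ctc_beam_search_decoder_batch(
    probs_split=probs_split,
    vocabulary=vocab_list,
    beam_size=500,            # decoding.beam_size
    num_processes=8,          # decoding.num_proc_bsearch
    cutoff_prob=1.0,          # prune characters below this cumulative prob
    cutoff_top_n=40,          # keep at most 40 characters per step
    ext_scoring_func=scorer)

# Each entry is a beam of (score, transcript); take the best hypothesis.
transcripts = [beams[0][1] for beams in beam_results]
```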

@ -8,36 +8,17 @@ if [ $? -ne 0 ]; then
fi
cd - > /dev/null
# infer
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python3 -u ${MAIN_ROOT}/infer.py \
--device 'gpu' \
--nproc 1 \
--config conf/deepspeech2.yaml \
--output ckpt

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi

exit 0

@ -8,32 +8,6 @@ if [ $? -ne 0 ]; then
fi
cd - > /dev/null
# evaluate model
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python3 -u ${MAIN_ROOT}/test.py \
--device 'gpu' \

@ -1,20 +1,7 @@
#! /usr/bin/env bash
# train model
export FLAGS_sync_nccl_allreduce=0

CUDA_VISIBLE_DEVICES=0,1,2,3 \
python3 -u ${MAIN_ROOT}/train.py \
--device 'gpu' \

@ -1,4 +1,5 @@
#!/bin/bash
set -e
source path.sh

@ -13,186 +13,32 @@
# limitations under the License.
"""Inferer for DeepSpeech2 model."""
from model_utils.config import get_cfg_defaults
from model_utils.model import DeepSpeech2Tester as Tester
from training.cli import default_argument_parser
from utils.utility import print_arguments


def main_sp(config, args):
    exp = Tester(config, args)
    exp.setup()
    exp.run_test()


def main(config, args):
    main_sp(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    print_arguments(args)

    # https://yaml.org/type/float.html
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
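The `merge_from_file`/`merge_from_list` calls are the standard yacs pattern, so individual entries can be overridden per run without editing the YAML; a minimal sketch (the exact `--opts` CLI spelling and the `config.decoding.*` attribute paths are assumptions based on the code and config above):

```python
# Sketch of the yacs override pattern used above: keys are dotted paths
# into the config, given as alternating key/value pairs.
from model_utils.config import get_cfg_defaults

config = get_cfg_defaults()
config.merge_from_file("conf/deepspeech2.yaml")
# Assumed CLI equivalent: --opts decoding.beam_size 300 decoding.alpha 2.0
config.merge_from_list(["decoding.beam_size", 300, "decoding.alpha", 2.0])
config.freeze()
print(config.decoding.beam_size, config.decoding.alpha)
```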

@ -39,7 +39,7 @@ from decoders.swig_wrapper import Scorer
from decoders.swig_wrapper import ctc_greedy_decoder
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
from utils.error_rate import char_errors, word_errors, cer, wer


class DeepSpeech2Trainer(Trainer):

@ -255,16 +255,10 @@ class DeepSpeech2Trainer(Trainer):
        self.logger.info("Setup train/valid Dataloader!")


class DeepSpeech2Tester(DeepSpeech2Trainer):
    def __init__(self, config, args):
        super().__init__(config, args)

    def id2token(self, texts, texts_len, vocab_list):
        trans = []
        for text, n in zip(texts, texts_len):

@ -281,6 +275,7 @@ class DeepSpeech2Tester(Trainer):
        errors_sum, len_refs, num_ins = 0.0, 0, 0
        errors_func = char_errors if cfg.error_rate_type == 'cer' else word_errors
        error_rate_func = cer if cfg.error_rate_type == 'cer' else wer
        vocab_list = self.test_loader.dataset.vocab_list
        target_transcripts = self.id2token(texts, texts_len, vocab_list)

@ -301,6 +296,11 @@ class DeepSpeech2Tester(Trainer):
            errors_sum += errors
            len_refs += len_ref
            num_ins += 1
            self.logger.info(
                "\nTarget Transcription: %s\nOutput Transcription: %s" %
                (target, result))
            self.logger.info("Current error rate [%s] = %f" % (
                cfg.error_rate_type, error_rate_func(target, result)))
        return dict(
            errors_sum=errors_sum,

@ -377,8 +377,6 @@ class DeepSpeech2Tester(Trainer):
            self.test()
        except KeyboardInterrupt:
            exit(-1)

    def setup_model(self):
        config = self.config
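The tester accumulates `errors_sum` and `len_refs` across the whole test set; a standalone sketch of that corpus-level computation, using `utils.error_rate.word_errors` the same way the loop above does:

```python
# word_errors/char_errors return (edit_distance, reference_length); the
# reported WER/CER is the ratio of the running sums, not a mean of
# per-utterance rates.
from utils.error_rate import word_errors

refs = ["the quick brown fox", "jumps over the lazy dog"]
hyps = ["the quick brown fox", "jump over the lazy dog"]

errors_sum, len_refs = 0.0, 0
for ref, hyp in zip(refs, hyps):
    errors, len_ref = word_errors(ref, hyp)  # edit distance, # ref words
    errors_sum += errors
    len_refs += len_ref
print("WER = %f" % (errors_sum / len_refs))  # 1 substitution / 9 words ~ 0.111
```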

@ -27,49 +27,6 @@ from model_utils.config import get_cfg_defaults
from model_utils.model import DeepSpeech2Tester as Tester
from utils.error_rate import char_errors, word_errors


def main_sp(config, args):
    exp = Tester(config, args)
