From c3a6d19af9a232f5d3700a65d0a9e7e44dc2e38c Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 25 Feb 2021 05:35:17 +0000 Subject: [PATCH] fix ci add tune fix gru model bugs add dataset and model test --- .travis.yml | 2 +- dataloader.ipynb | 389 +++++++++++++++++++++++++ examples/aishell/conf/deepspeech2.yaml | 10 +- examples/aishell/local/infer.sh | 2 +- examples/aishell/local/tune.sh | 20 +- examples/tiny/conf/deepspeech2.yaml | 2 +- model_utils/config.py | 2 +- model_utils/model.py | 8 +- model_utils/network.py | 16 +- tests/network_test.py | 8 +- training/trainer.py | 19 +- tune.py | 261 +++++++++-------- 12 files changed, 574 insertions(+), 165 deletions(-) create mode 100644 dataloader.ipynb diff --git a/.travis.yml b/.travis.yml index b2af6a4c4..d1f4abb50 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,7 +26,7 @@ script: - exit_code=0 - .travis/precommit.sh || exit_code=$(( exit_code | $? )) - docker run -i --rm -v "$PWD:/py_unittest" paddlepaddle/paddle:latest /bin/bash -c - 'cd /py_unittest; sh .travis/unittest.sh' || exit_code=$(( exit_code | $? )) + 'cd /py_unittest; source env.sh; bash .travis/unittest.sh' || exit_code=$(( exit_code | $? )) exit $exit_code notifications: diff --git a/dataloader.ipynb b/dataloader.ipynb new file mode 100644 index 000000000..e2b8b3a0a --- /dev/null +++ b/dataloader.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "emerging-meter", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " def convert_to_list(value, n, name, dtype=np.int):\n", + "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", + " from numpy.dual import register_func\n", + "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", + "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. 
Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " long_ = _make_signed(np.long)\n", + "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " ulong = _make_unsigned(np.long)\n" + ] + } + ], + "source": [ + "import math\n", + "import random\n", + "import tarfile\n", + "import logging\n", + "import numpy as np\n", + "from collections import namedtuple\n", + "from functools import partial\n", + "\n", + "import paddle\n", + "from paddle.io import Dataset\n", + "from paddle.io import DataLoader\n", + "from paddle.io import BatchSampler\n", + "from paddle.io import DistributedBatchSampler\n", + "from paddle import distributed as dist\n", + "\n", + "from data_utils.utility import read_manifest\n", + "from data_utils.augmentor.augmentation import AugmentationPipeline\n", + "from data_utils.featurizer.speech_featurizer import SpeechFeaturizer\n", + "from data_utils.speech import SpeechSegment\n", + "from data_utils.normalizer import FeatureNormalizer\n", + "\n", + "\n", + "from data_utils.dataset import (\n", + " DeepSpeech2Dataset,\n", + " DeepSpeech2DistributedBatchSampler,\n", + " DeepSpeech2BatchSampler,\n", + " SpeechCollator,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "excessive-american", + "metadata": {}, + "outputs": [], + "source": [ + "def create_dataloader(manifest_path,\t\n", + " vocab_filepath,\t\n", + " mean_std_filepath,\t\n", + " augmentation_config='{}',\t\n", + " max_duration=float('inf'),\t\n", + " min_duration=0.0,\t\n", + " stride_ms=10.0,\t\n", + " window_ms=20.0,\t\n", + " max_freq=None,\t\n", + " specgram_type='linear',\t\n", + " use_dB_normalization=True,\t\n", + " random_seed=0,\t\n", + " keep_transcription_text=False,\t\n", + " is_training=False,\t\n", + " batch_size=1,\t\n", + " num_workers=0,\t\n", + " sortagrad=False,\t\n", + " shuffle_method=None,\t\n", + " dist=False):\t\n", + "\n", + " dataset = DeepSpeech2Dataset(\t\n", + " manifest_path,\t\n", + " vocab_filepath,\t\n", + " mean_std_filepath,\t\n", + " augmentation_config=augmentation_config,\t\n", + " max_duration=max_duration,\t\n", + " min_duration=min_duration,\t\n", + " stride_ms=stride_ms,\t\n", + " window_ms=window_ms,\t\n", + " max_freq=max_freq,\t\n", + " specgram_type=specgram_type,\t\n", + " use_dB_normalization=use_dB_normalization,\t\n", + " random_seed=random_seed,\t\n", + " keep_transcription_text=keep_transcription_text)\t\n", + "\n", + " if dist:\t\n", + " batch_sampler = DeepSpeech2DistributedBatchSampler(\t\n", + " dataset,\t\n", + " 
batch_size,\t\n",
    "            num_replicas=None,\t\n",
    "            rank=None,\t\n",
    "            shuffle=is_training,\t\n",
    "            drop_last=is_training,\t\n",
    "            sortagrad=is_training,\t\n",
    "            shuffle_method=shuffle_method)\t\n",
    "    else:\t\n",
    "        batch_sampler = DeepSpeech2BatchSampler(\t\n",
    "            dataset,\t\n",
    "            shuffle=is_training,\t\n",
    "            batch_size=batch_size,\t\n",
    "            drop_last=is_training,\t\n",
    "            sortagrad=is_training,\t\n",
    "            shuffle_method=shuffle_method)\t\n",
    "\n",
    "    def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):\t\n",
    "        \"\"\"\t\n",
    "        Padding audio features with zeros to make them have the same shape (or\t\n",
    "        a user-defined shape) within one batch.\t\n",
    "\n",
    "        If ``padding_to`` is -1, the maximum shape in the batch will be used\t\n",
    "        as the target shape for padding. Otherwise, `padding_to` will be the\t\n",
    "        target shape (only refers to the second axis).\t\n",
    "\n",
    "        If `flatten` is True, features will be flattened to a 1-D array.\t\n",
    "        \"\"\"\t\n",
    "        new_batch = []\t\n",
    "        # get target shape\t\n",
    "        max_length = max([audio.shape[1] for audio, text in batch])\t\n",
    "        if padding_to != -1:\t\n",
    "            if padding_to < max_length:\t\n",
    "                raise ValueError(\"If padding_to is not -1, it should be larger \"\t\n",
    "                                 \"than any instance's shape in the batch\")\t\n",
    "            max_length = padding_to\t\n",
    "        max_text_length = max([len(text) for audio, text in batch])\t\n",
    "        # padding\t\n",
    "        padded_audios = []\t\n",
    "        audio_lens = []\t\n",
    "        texts, text_lens = [], []\t\n",
    "        for audio, text in batch:\t\n",
    "            padded_audio = np.zeros([audio.shape[0], max_length])\t\n",
    "            padded_audio[:, :audio.shape[1]] = audio\t\n",
    "            if flatten:\t\n",
    "                padded_audio = padded_audio.flatten()\t\n",
    "            padded_audios.append(padded_audio)\t\n",
    "            audio_lens.append(audio.shape[1])\t\n",
    "\n",
    "            padded_text = np.zeros([max_text_length])\n",
    "            if is_training:\n",
    "                padded_text[:len(text)] = text\t# ids\n",
    "            else:\n",
    "                padded_text[:len(text)] = [ord(t) for t in text] # string\n",
    "            \n",
    "            texts.append(padded_text)\t\n",
    "            text_lens.append(len(text))\t\n",
    "\n",
    "        padded_audios = np.array(padded_audios).astype('float32')\t\n",
    "        audio_lens = np.array(audio_lens).astype('int64')\t\n",
    "        texts = np.array(texts).astype('int32')\t\n",
    "        text_lens = np.array(text_lens).astype('int64')\t\n",
    "        return padded_audios, texts, audio_lens, text_lens\t\n",
    "\n",
    "    loader = DataLoader(\t\n",
    "        dataset,\t\n",
    "        batch_sampler=batch_sampler,\t\n",
    "        collate_fn=partial(padding_batch, is_training=is_training),\t\n",
    "        num_workers=num_workers)\t\n",
    "    return loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "naval-brave",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
"import argparse\n", + "import functools\n", + "from utils.utility import add_arguments, print_arguments\n", + "parser = argparse.ArgumentParser(description=__doc__)\n", + "add_arg = functools.partial(add_arguments, argparser=parser)\n", + "# yapf: disable\n", + "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", + "add_arg('beam_size', int, 500, \"Beam search width.\")\n", + "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", + "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", + "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", + "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", + "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", + "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", + "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", + "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", + "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", + "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", + "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", + " \"bi-directional RNNs. Not for GRU.\")\n", + "add_arg('infer_manifest', str,\n", + " 'examples/aishell/data/manifest.dev',\n", + " \"Filepath of manifest to infer.\")\n", + "add_arg('mean_std_path', str,\n", + " 'examples/aishell/data/mean_std.npz',\n", + " \"Filepath of normalizer's mean & std.\")\n", + "add_arg('vocab_path', str,\n", + " 'examples/aishell/data/vocab.txt',\n", + " \"Filepath of vocabulary.\")\n", + "add_arg('lang_model_path', str,\n", + " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", + " \"Filepath for language model.\")\n", + "add_arg('model_path', str,\n", + " 'examples/aishell/checkpoints/step_final',\n", + " \"If None, the training starts from scratch, \"\n", + " \"otherwise, it resumes from the pre-trained model.\")\n", + "add_arg('decoding_method', str,\n", + " 'ctc_beam_search',\n", + " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", + " choices = ['ctc_beam_search', 'ctc_greedy'])\n", + "add_arg('error_rate_type', str,\n", + " 'wer',\n", + " \"Error rate type for evaluation.\",\n", + " choices=['wer', 'cer'])\n", + "add_arg('specgram_type', str,\n", + " 'linear',\n", + " \"Audio feature type. 
Options: linear, mfcc.\",\n", + " choices=['linear', 'mfcc'])\n", + "# yapf: disable\n", + "args = parser.parse_args([])\n", + "print(vars(args))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "bearing-physics", + "metadata": {}, + "outputs": [], + "source": [ + "batch_reader = create_dataloader(\n", + " manifest_path=args.infer_manifest,\n", + " vocab_filepath=args.vocab_path,\n", + " mean_std_filepath=args.mean_std_path,\n", + " augmentation_config='{}',\n", + " #max_duration=float('inf'),\n", + " max_duration=27.0,\n", + " min_duration=0.0,\n", + " stride_ms=10.0,\n", + " window_ms=20.0,\n", + " max_freq=None,\n", + " specgram_type=args.specgram_type,\n", + " use_dB_normalization=True,\n", + " random_seed=0,\n", + " keep_transcription_text=True,\n", + " is_training=False,\n", + " batch_size=args.num_samples,\n", + " sortagrad=True,\n", + " shuffle_method=None,\n", + " dist=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "classified-melissa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", + " [[22823, 26102, 20195, 37324, 0 , 0 ],\n", + " [22238, 26469, 23601, 22909, 0 , 0 ],\n", + " [20108, 26376, 22235, 26085, 0 , 0 ],\n", + " [36824, 35201, 20445, 25345, 32654, 24863],\n", + " [29042, 27748, 21463, 23456, 0 , 0 ]])\n", + "test raw 大时代里\n", + "test raw 煲汤受宠\n", + "audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", + " [163, 167, 180, 186, 186])\n", + "test len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", + " [4, 4, 4, 6, 4])\n", + "audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", + " [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n", + " [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n", + " [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n", + " [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n", + " [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n", + "\n", + " [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n", + " [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n", + " [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n", + " [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n", + " [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n", + "\n", + " [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n", + " [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n", + " [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n", + " [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n", + " [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. 
]],\n", + "\n", + " [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n", + " [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n", + " [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n", + " ...,\n", + " [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n", + " [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n", + " [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n", + "\n", + " [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n", + " [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n", + " [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n", + " ...,\n", + " [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n", + " [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n", + " [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n" + ] + } + ], + "source": [ + "for idx, (audio, text, audio_len, text_len) in enumerate(batch_reader()):\n", + " print('test', text)\n", + " print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n", + " print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n", + " print('audio len', audio_len)\n", + " print('test len', text_len)\n", + " print('audio', audio)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "unexpected-skating", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "minus-modern", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/aishell/conf/deepspeech2.yaml b/examples/aishell/conf/deepspeech2.yaml index a0a9d6295..e2e08e1a9 100644 --- a/examples/aishell/conf/deepspeech2.yaml +++ b/examples/aishell/conf/deepspeech2.yaml @@ -19,7 +19,7 @@ data: target_dB: -20 random_seed: 0 keep_transcription_text: False - sortagrad: True + sortagrad: True shuffle_method: batch_shuffle num_workers: 0 model: @@ -27,7 +27,7 @@ model: num_rnn_layers: 3 rnn_layer_size: 1024 use_gru: True - share_rnn_weights: False + share_rnn_weights: False training: n_epoch: 30 lr: 5e-4 @@ -39,13 +39,13 @@ training: save_interval: 1000 valid_interval: 1000 decoding: - batch_size: 128 + batch_size: 10 error_rate_type: cer decoding_method: ctc_beam_search lang_model_path: models/lm/zh_giga.no_cna_cmn.prune01244.klm alpha: 2.6 beta: 5.0 beam_size: 300 - cutoff_prob: 0.99 + cutoff_prob: 1.0 cutoff_top_n: 40 - num_proc_bsearch: 8 + num_proc_bsearch: 10 diff --git a/examples/aishell/local/infer.sh b/examples/aishell/local/infer.sh index d794ccc4e..bc413be11 100644 --- a/examples/aishell/local/infer.sh +++ b/examples/aishell/local/infer.sh @@ -13,7 +13,7 @@ python3 -u ${MAIN_ROOT}/infer.py \ --device 'gpu' \ --nproc 1 \ --config conf/deepspeech2.yaml \ ---checkpoint_path ckpt/checkpoints/step-3283 +--checkpoint_path ${1} if [ $? 
-ne 0 ]; then diff --git a/examples/aishell/local/tune.sh b/examples/aishell/local/tune.sh index 1b2f83db2..a11137706 100644 --- a/examples/aishell/local/tune.sh +++ b/examples/aishell/local/tune.sh @@ -5,19 +5,19 @@ python3 -u ${MAIN_ROOT}/tune.py \ --device 'gpu' \ --nproc 1 \ --config conf/deepspeech2.yaml \ ---num_batches=-1 \ +--num_batches=10 \ --batch_size=128 \ ---beam_size=500 \ ---num_proc_bsearch=12 \ ---num_alphas=45 \ ---num_betas=8 \ ---alpha_from=1.0 \ ---alpha_to=3.2 \ ---beta_from=0.1 \ ---beta_to=0.45 \ +--beam_size=300 \ +--num_proc_bsearch=8 \ +--num_alphas=10 \ +--num_betas=10 \ +--alpha_from=0.0 \ +--alpha_to=5.0 \ +--beta_from=-6 \ +--beta_to=6 \ --cutoff_prob=1.0 \ --cutoff_top_n=40 \ ---checkpoint_path ${1} +--checkpoint_path ${1} if [ $? -ne 0 ]; then echo "Failed in tuning!" diff --git a/examples/tiny/conf/deepspeech2.yaml b/examples/tiny/conf/deepspeech2.yaml index ab4cb510a..dc7d59d47 100644 --- a/examples/tiny/conf/deepspeech2.yaml +++ b/examples/tiny/conf/deepspeech2.yaml @@ -26,7 +26,7 @@ model: num_conv_layers: 2 num_rnn_layers: 3 rnn_layer_size: 2048 - use_gru: True + use_gru: False share_rnn_weights: True training: n_epoch: 20 diff --git a/model_utils/config.py b/model_utils/config.py index 79436110f..a6b99a61d 100644 --- a/model_utils/config.py +++ b/model_utils/config.py @@ -46,7 +46,7 @@ _C.model = CN( num_conv_layers=2, #Number of stacking convolution layers. num_rnn_layers=3, #Number of stacking RNN layers. rnn_layer_size=1024, #RNN layer size (number of RNN cells). - use_gru=False, #Use gru if set True. Use simple rnn if set False. + use_gru=True, #Use gru if set True. Use simple rnn if set False. share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. 
)) diff --git a/model_utils/model.py b/model_utils/model.py index f38de6db7..6520d94a3 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -242,6 +242,7 @@ class DeepSpeech2Trainer(Trainer): num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights) if self.parallel: @@ -329,7 +330,7 @@ class DeepSpeech2Trainer(Trainer): sortagrad=config.data.sortagrad, shuffle_method=config.data.shuffle_method) - collate_fn = SpeechCollator() + collate_fn = SpeechCollator(is_training=True) self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -449,7 +450,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): output_dir.mkdir(parents=True, exist_ok=True) else: output_dir = Path( - self.args.checkpoint_path).expanduser().parent / "infer" + self.args.checkpoint_path).expanduser().parent.parent / "infer" output_dir.mkdir(parents=True, exist_ok=True) self.output_dir = output_dir @@ -485,6 +486,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights) if self.parallel: @@ -498,6 +500,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def setup_dataloader(self): config = self.config + # return raw text test_dataset = DeepSpeech2Dataset( config.data.test_manifest, config.data.vocab_filepath, @@ -516,6 +519,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): random_seed=config.data.random_seed, keep_transcription_text=True) + # return text ord id self.test_loader = DataLoader( test_dataset, batch_size=config.decoding.batch_size, diff --git a/model_utils/network.py b/model_utils/network.py index b536865be..c8fa95dc4 100644 --- a/model_utils/network.py +++ b/model_utils/network.py @@ -22,6 +22,8 @@ from paddle import nn from paddle.nn import functional as F from paddle.nn import initializer as I +from utils import checkpoint + from decoders.swig_wrapper import Scorer from decoders.swig_wrapper import ctc_greedy_decoder from decoders.swig_wrapper import ctc_beam_search_decoder_batch @@ -89,7 +91,7 @@ class ConvBn(nn.Layer): stride=stride, padding=padding, weight_attr=None, - bias_attr=None, + bias_attr=False, data_format='NCHW') self.bn = nn.BatchNorm2D( @@ -387,6 +389,7 @@ class BiGRUWithBN(nn.Layer): def __init__(self, i_size, h_size, act): super().__init__() hidden_size = h_size * 3 + self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) self.fw_bn = nn.BatchNorm1D( hidden_size, bias_attr=None, data_format='NLC') @@ -494,7 +497,7 @@ class DeepSpeech2(nn.Layer): dict_size, num_conv_layers=2, num_rnn_layers=3, - rnn_size=256, + rnn_size=1024, use_gru=False, share_rnn_weights=True): super().__init__() @@ -684,9 +687,10 @@ class DeepSpeech2(nn.Layer): lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes): _, probs, _ = self.predict(audio, audio_len) - return self.decode_probs( - probs, vocab_list, decoding_method, lang_model_path, beam_alpha, - beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes) + return self.decode_probs(probs.numpy(), vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, + beam_size, cutoff_prob, cutoff_top_n, + num_processes) def from_pretrained(self, checkpoint_path): """Build a model from a pretrained model. 
@@ -704,7 +708,7 @@ class DeepSpeech2(nn.Layer): The model build from pretrined result. """ checkpoint.load_parameters(self, checkpoint_path=checkpoint_path) - return model + return def ctc_loss(logits, diff --git a/tests/network_test.py b/tests/network_test.py index ddd3991ed..7e35c05cc 100644 --- a/tests/network_test.py +++ b/tests/network_test.py @@ -63,7 +63,7 @@ if __name__ == '__main__': rnn_size=1024, use_gru=True, share_rnn_weights=False, ) - probs = model2(audio, text, audio_len, text_len) + logits, probs, logits_len = model2(audio, text, audio_len, text_len) print('probs.shape', probs.shape) print("-----------------") @@ -75,7 +75,7 @@ if __name__ == '__main__': rnn_size=1024, use_gru=False, share_rnn_weights=True, ) - probs = model3(audio, text, audio_len, text_len) + logits, probs, logits_len = model3(audio, text, audio_len, text_len) print('probs.shape', probs.shape) print("-----------------") @@ -87,7 +87,7 @@ if __name__ == '__main__': rnn_size=1024, use_gru=True, share_rnn_weights=True, ) - probs = model4(audio, text, audio_len, text_len) + logits, probs, logits_len = model4(audio, text, audio_len, text_len) print('probs.shape', probs.shape) print("-----------------") @@ -99,6 +99,6 @@ if __name__ == '__main__': rnn_size=1024, use_gru=False, share_rnn_weights=False, ) - probs = model5(audio, text, audio_len, text_len) + logits, probs, logits_len = model5(audio, text, audio_len, text_len) print('probs.shape', probs.shape) print("-----------------") diff --git a/training/trainer.py b/training/trainer.py index 1dcca5aab..3fac31d70 100644 --- a/training/trainer.py +++ b/training/trainer.py @@ -267,30 +267,29 @@ class Trainer(): when = 'D' backup = 7 format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' + formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S') logger = logging.getLogger(__name__) logger.setLevel("INFO") - formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S') - stream_handler = logging.StreamHandler() stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) - if not hasattr(self, 'output_dir'): - self.logger = logger - return + # if not hasattr(self, 'output_dir'): + # self.logger = logger + # return log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank()) # file_handler = logging.FileHandler(str(log_file)) # file_handler.setFormatter(formatter) # logger.addHandler(file_handler) - handler = logging.handlers.TimedRotatingFileHandler( - str(self.output_dir / "warning.log"), when=when, backupCount=backup) - handler.setLevel(logging.WARNING) - handler.setFormatter(formatter) - logger.addHandler(handler) + # handler = logging.handlers.TimedRotatingFileHandler( + # str(self.output_dir / "warning.log"), when=when, backupCount=backup) + # handler.setLevel(logging.WARNING) + # handler.setFormatter(formatter) + # logger.addHandler(handler) # global logger stdout = False diff --git a/tune.py b/tune.py index ad48bcb67..b269265ae 100644 --- a/tune.py +++ b/tune.py @@ -21,107 +21,69 @@ import functools import gzip import logging import paddle.fluid as fluid -import _init_paths -from data_utils.data import DataGenerator -from model_utils.model import DeepSpeech2Model + +from training.cli import default_argument_parser +from model_utils.config import get_cfg_defaults + +from data_utils.dataset import SpeechCollator +from data_utils.dataset import DeepSpeech2Dataset +from data_utils.dataset import DeepSpeech2DistributedBatchSampler +from data_utils.dataset import DeepSpeech2BatchSampler +from 
paddle.io import DataLoader + +from model_utils.network import DeepSpeech2 +from model_utils.network import DeepSpeech2Loss + from utils.error_rate import char_errors, word_errors from utils.utility import add_arguments, print_arguments -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('num_batches', int, -1, "# of batches tuning on. " - "Default -1, on whole dev set.") -add_arg('batch_size', int, 256, "# of samples per batch.") -add_arg('trainer_count', int, 8, "# of Trainers (CPUs or GPUs).") - -add_arg('beam_size', int, 500, "Beam search width.") -add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.") -add_arg('num_alphas', int, 45, "# of alpha candidates for tuning.") -add_arg('num_betas', int, 8, "# of beta candidates for tuning.") -add_arg('alpha_from', float, 1.0, "Where alpha starts tuning from.") -add_arg('alpha_to', float, 3.2, "Where alpha ends tuning with.") -add_arg('beta_from', float, 0.1, "Where beta starts tuning from.") -add_arg('beta_to', float, 0.45, "Where beta ends tuning with.") -add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.") -add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.") - -add_arg('num_conv_layers', int, 2, "# of convolution layers.") -add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") -add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") -add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") -add_arg('use_gpu', bool, True, "Use GPU or not.") -add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " - "bi-directional RNNs. Not for GRU.") - -add_arg('tune_manifest', str, - 'data/librispeech/manifest.dev-clean', - "Filepath of manifest to tune.") -add_arg('mean_std_path', str, - 'data/librispeech/mean_std.npz', - "Filepath of normalizer's mean & std.") -add_arg('vocab_path', str, - 'data/librispeech/vocab.txt', - "Filepath of vocabulary.") -add_arg('lang_model_path', str, - 'models/lm/common_crawl_00.prune01111.trie.klm', - "Filepath for language model.") -add_arg('model_path', str, - './checkpoints/libri/params.latest.tar.gz', - "If None, the training starts from scratch, " - "otherwise, it resumes from the pre-trained model.") -add_arg('error_rate_type', str, - 'wer', - "Error rate type for evaluation.", - choices=['wer', 'cer']) -add_arg('specgram_type', str, - 'linear', - "Audio feature type. 
Options: linear, mfcc.", - choices=['linear', 'mfcc']) -# yapf: disable -args = parser.parse_args() - - -def tune(): + +def tune(config, args): """Tune parameters alpha and beta incrementally.""" if not args.num_alphas >= 0: raise ValueError("num_alphas must be non-negative!") if not args.num_betas >= 0: raise ValueError("num_betas must be non-negative!") - if args.use_gpu: - place = fluid.CUDAPlace(0) - else: - place = fluid.CPUPlace() - - data_generator = DataGenerator( - vocab_filepath=args.vocab_path, - mean_std_filepath=args.mean_std_path, - augmentation_config='{}', - specgram_type=args.specgram_type, - keep_transcription_text=True, - place = place, - is_training = False) - - batch_reader = data_generator.batch_reader_creator( - manifest_path=args.tune_manifest, - batch_size=args.batch_size, - sortagrad=False, - shuffle_method=None) - - ds2_model = DeepSpeech2Model( - vocab_size=data_generator.vocab_size, - num_conv_layers=args.num_conv_layers, - num_rnn_layers=args.num_rnn_layers, - rnn_layer_size=args.rnn_layer_size, - use_gru=args.use_gru, - place=place, - init_from_pretrained_model=args.model_path, - share_rnn_weights=args.share_rnn_weights) + dev_dataset = DeepSpeech2Dataset( + config.data.dev_manifest, + config.data.vocab_filepath, + config.data.mean_std_filepath, + augmentation_config="{}", + max_duration=config.data.max_duration, + min_duration=config.data.min_duration, + stride_ms=config.data.stride_ms, + window_ms=config.data.window_ms, + n_fft=config.data.n_fft, + max_freq=config.data.max_freq, + target_sample_rate=config.data.target_sample_rate, + specgram_type=config.data.specgram_type, + use_dB_normalization=config.data.use_dB_normalization, + target_dB=config.data.target_dB, + random_seed=config.data.random_seed, + keep_transcription_text=True) + + valid_loader = DataLoader( + dev_dataset, + batch_size=config.data.batch_size, + shuffle=False, + drop_last=False, + collate_fn=SpeechCollator(is_training=False)) + + model = DeepSpeech2( + feat_size=valid_loader.dataset.feature_size, + dict_size=valid_loader.dataset.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + share_rnn_weights=config.model.share_rnn_weights) + model.from_pretrained(args.checkpoint_path) + model.eval() # decoders only accept string encoded in utf-8 - vocab_list = [chars for chars in data_generator.vocab_list] - errors_func = char_errors if args.error_rate_type == 'cer' else word_errors + vocab_list = valid_loader.dataset.vocab_list + errors_func = char_errors if config.decoding.error_rate_type == 'cer' else word_errors + # create grid for search cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas) @@ -131,34 +93,42 @@ def tune(): err_sum = [0.0 for i in range(len(params_grid))] err_ave = [0.0 for i in range(len(params_grid))] - num_ins, len_refs, cur_batch = 0, 0, 0 # initialize external scorer - ds2_model.init_ext_scorer(args.alpha_from, args.beta_from, - args.lang_model_path, vocab_list) + model.init_decode(args.alpha_from, args.beta_from, + config.decoding.lang_model_path, vocab_list, + config.decoding.decoding_method) ## incremental tuning parameters over multiple batches - ds2_model.logger.info("start tuning ...") - for infer_data in batch_reader(): + print("start tuning ...") + for infer_data in valid_loader(): if (args.num_batches >= 0) and (cur_batch >= args.num_batches): break - probs_split = 
ds2_model.infer_batch_probs( - infer_data=infer_data, - feeding_dict=data_generator.feeding) - target_transcripts = infer_data[1] - num_ins += len(target_transcripts) + def ordid2token(texts, texts_len): + """ ord() id to chr() chr """ + trans = [] + for text, n in zip(texts, texts_len): + n = n.numpy().item() + ids = text[:n] + trans.append(''.join([chr(i) for i in ids])) + return trans + + audio, text, audio_len, text_len = infer_data + _, probs, _ = model.predict(audio, audio_len) + target_transcripts = ordid2token(text, text_len) + num_ins += audio.shape[0] + # grid search for index, (alpha, beta) in enumerate(params_grid): - result_transcripts = ds2_model.decode_batch_beam_search( - probs_split=probs_split, - beam_alpha=alpha, - beam_beta=beta, - beam_size=args.beam_size, - cutoff_prob=args.cutoff_prob, - cutoff_top_n=args.cutoff_top_n, - vocab_list=vocab_list, - num_processes=args.num_proc_bsearch) + print(f"tuneing: alpha={alpha} beta={beta}") + result_transcripts = model.decode_probs( + probs.numpy(), vocab_list, config.decoding.decoding_method, + config.decoding.lang_model_path, alpha, beta, + config.decoding.beam_size, config.decoding.cutoff_prob, + config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch) + for target, result in zip(target_transcripts, result_transcripts): + #print(f"tuneing: {target} {result}") errors, len_ref = errors_func(target, result) err_sum[index] += errors @@ -171,37 +141,80 @@ def tune(): if index % 2 == 0: sys.stdout.write('.') sys.stdout.flush() + print(f"tuneing: one grid done!") # output on-line tuning result at the end of current batch err_ave_min = min(err_ave) min_index = err_ave.index(err_ave_min) print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), " - " min [%s] = %f" %(cur_batch, num_ins, - "%.3f" % params_grid[min_index][0], - "%.3f" % params_grid[min_index][1], - args.error_rate_type, err_ave_min)) + " min [%s] = %f" % + (cur_batch, num_ins, "%.3f" % params_grid[min_index][0], "%.3f" % + params_grid[min_index][1], args.error_rate_type, err_ave_min)) cur_batch += 1 # output WER/CER at every (alpha, beta) - print("\nFinal %s:\n" % args.error_rate_type) + print("\nFinal %s:\n" % config.decoding.error_rate_type) for index in range(len(params_grid)): - print("(alpha, beta) = (%s, %s), [%s] = %f" - % ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1], - args.error_rate_type, err_ave[index])) + print("(alpha, beta) = (%s, %s), [%s] = %f" % + ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1], + config.decoding.error_rate_type, err_ave[index])) err_ave_min = min(err_ave) min_index = err_ave.index(err_ave_min) - print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)" - % (cur_batch, "%.3f" % params_grid[min_index][0], - "%.3f" % params_grid[min_index][1])) + print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)" % + (cur_batch, "%.3f" % params_grid[min_index][0], + "%.3f" % params_grid[min_index][1])) ds2_model.logger.info("finish tuning") -def main(): +def main_sp(config, args): + tune(config, args) + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + add_arg = functools.partial(add_arguments, argparser=parser) + add_arg('num_batches', int, -1, "# of batches tuning on. 
" + "Default -1, on whole dev set.") + add_arg('num_alphas', int, 45, "# of alpha candidates for tuning.") + add_arg('num_betas', int, 8, "# of beta candidates for tuning.") + add_arg('alpha_from', float, 1.0, "Where alpha starts tuning from.") + add_arg('alpha_to', float, 3.2, "Where alpha ends tuning with.") + add_arg('beta_from', float, 0.1, "Where beta starts tuning from.") + add_arg('beta_to', float, 0.45, "Where beta ends tuning with.") + + add_arg('batch_size', int, 256, "# of samples per batch.") + add_arg('beam_size', int, 500, "Beam search width.") + add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.") + add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.") + add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.") + + args = parser.parse_args() print_arguments(args) - tune() + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + + config.data.batch_size = args.batch_size + config.decoding.beam_size = args.beam_size + config.decoding.num_proc_bsearch = args.num_proc_bsearch + config.decoding.cutoff_prob = args.cutoff_prob + config.decoding.cutoff_top_n = args.cutoff_top_n + + config.freeze() + print(config) + + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) -if __name__ == '__main__': - main() + main(config, args)