You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
389 lines
20 KiB
389 lines
20 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "emerging-meter",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
|
|
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
|
|
" def convert_to_list(value, n, name, dtype=np.int):\n",
|
|
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
|
|
" from numpy.dual import register_func\n",
|
|
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
|
|
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
|
|
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
|
|
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
|
|
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
|
|
" long_ = _make_signed(np.long)\n",
|
|
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
|
|
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
|
|
" ulong = _make_unsigned(np.long)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import math\n",
|
|
"import random\n",
|
|
"import tarfile\n",
|
|
"import logging\n",
|
|
"import numpy as np\n",
|
|
"from collections import namedtuple\n",
|
|
"from functools import partial\n",
|
|
"\n",
|
|
"import paddle\n",
|
|
"from paddle.io import Dataset\n",
|
|
"from paddle.io import DataLoader\n",
|
|
"from paddle.io import BatchSampler\n",
|
|
"from paddle.io import DistributedBatchSampler\n",
|
|
"from paddle import distributed as dist\n",
|
|
"\n",
|
|
"from data_utils.utility import read_manifest\n",
|
|
"from data_utils.augmentor.augmentation import AugmentationPipeline\n",
|
|
"from data_utils.featurizer.speech_featurizer import SpeechFeaturizer\n",
|
|
"from data_utils.speech import SpeechSegment\n",
|
|
"from data_utils.normalizer import FeatureNormalizer\n",
|
|
"\n",
|
|
"\n",
|
|
"from data_utils.dataset import (\n",
|
|
" DeepSpeech2Dataset,\n",
|
|
" DeepSpeech2DistributedBatchSampler,\n",
|
|
" DeepSpeech2BatchSampler,\n",
|
|
" SpeechCollator,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "excessive-american",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def create_dataloader(manifest_path,\n",
|
|
"                      vocab_filepath,\n",
|
|
"                      mean_std_filepath,\n",
|
|
"                      augmentation_config='{}',\n",
|
|
"                      max_duration=float('inf'),\n",
|
|
"                      min_duration=0.0,\n",
|
|
"                      stride_ms=10.0,\n",
|
|
"                      window_ms=20.0,\n",
|
|
"                      max_freq=None,\n",
|
|
"                      specgram_type='linear',\n",
|
|
"                      use_dB_normalization=True,\n",
|
|
"                      random_seed=0,\n",
|
|
"                      keep_transcription_text=False,\n",
|
|
"                      is_training=False,\n",
|
|
"                      batch_size=1,\n",
|
|
"                      num_workers=0,\n",
|
|
"                      sortagrad=False,\n",
|
|
"                      shuffle_method=None,\n",
|
|
"                      dist=False):\n",
|
|
"    \"\"\"Build a ``paddle.io.DataLoader`` over a DeepSpeech2 manifest.\n",
|
|
"\n",
|
|
"    Each batch is the tuple ``(padded_audios, texts, audio_lens, text_lens)``\n",
|
|
"    returned by the nested ``padding_batch`` collate function.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        manifest_path, vocab_filepath, mean_std_filepath, augmentation_config,\n",
|
|
"        max_duration, min_duration, stride_ms, window_ms, max_freq,\n",
|
|
"        specgram_type, use_dB_normalization, random_seed,\n",
|
|
"        keep_transcription_text: forwarded unchanged to ``DeepSpeech2Dataset``.\n",
|
|
"        is_training: passed to the batch sampler as ``shuffle`` and\n",
|
|
"            ``drop_last``; it also forces ``sortagrad`` on.\n",
|
|
"        batch_size: number of samples per batch.\n",
|
|
"        num_workers: number of ``DataLoader`` worker processes.\n",
|
|
"        sortagrad: forwarded to the batch sampler (OR-ed with ``is_training``).\n",
|
|
"        shuffle_method: forwarded to the batch sampler.\n",
|
|
"        dist: use ``DeepSpeech2DistributedBatchSampler`` instead of\n",
|
|
"            ``DeepSpeech2BatchSampler``. This flag shadows the module-level\n",
|
|
"            ``paddle.distributed`` alias inside this function.\n",
|
|
"\n",
|
|
"    Returns:\n",
|
|
"        A ``paddle.io.DataLoader`` yielding padded batches.\n",
|
|
"    \"\"\"\n",
|
|
"    dataset = DeepSpeech2Dataset(\n",
|
|
"        manifest_path,\n",
|
|
"        vocab_filepath,\n",
|
|
"        mean_std_filepath,\n",
|
|
"        augmentation_config=augmentation_config,\n",
|
|
"        max_duration=max_duration,\n",
|
|
"        min_duration=min_duration,\n",
|
|
"        stride_ms=stride_ms,\n",
|
|
"        window_ms=window_ms,\n",
|
|
"        max_freq=max_freq,\n",
|
|
"        specgram_type=specgram_type,\n",
|
|
"        use_dB_normalization=use_dB_normalization,\n",
|
|
"        random_seed=random_seed,\n",
|
|
"        keep_transcription_text=keep_transcription_text)\n",
|
|
"\n",
|
|
"    # ``sortagrad`` used to be ignored (only ``is_training`` was forwarded);\n",
|
|
"    # honour an explicit request while keeping it on for training as before.\n",
|
|
"    use_sortagrad = sortagrad or is_training\n",
|
|
"    if dist:\n",
|
|
"        batch_sampler = DeepSpeech2DistributedBatchSampler(\n",
|
|
"            dataset,\n",
|
|
"            batch_size,\n",
|
|
"            num_replicas=None,\n",
|
|
"            rank=None,\n",
|
|
"            shuffle=is_training,\n",
|
|
"            drop_last=is_training,\n",
|
|
"            sortagrad=use_sortagrad,\n",
|
|
"            shuffle_method=shuffle_method)\n",
|
|
"    else:\n",
|
|
"        batch_sampler = DeepSpeech2BatchSampler(\n",
|
|
"            dataset,\n",
|
|
"            shuffle=is_training,\n",
|
|
"            batch_size=batch_size,\n",
|
|
"            drop_last=is_training,\n",
|
|
"            sortagrad=use_sortagrad,\n",
|
|
"            shuffle_method=shuffle_method)\n",
|
|
"\n",
|
|
"    def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):\n",
|
|
"        \"\"\"Zero-pad the audio features and transcripts of one batch.\n",
|
|
"\n",
|
|
"        ``batch`` is a sequence of ``(audio, text)`` pairs. Each ``audio``\n",
|
|
"        array is padded along its second axis. If ``padding_to`` is -1, the\n",
|
|
"        maximum length in the batch is the target; otherwise ``padding_to`` is\n",
|
|
"        the target and must not be smaller than any instance's length.\n",
|
|
"        If ``flatten`` is True, each padded feature array is flattened to 1-D.\n",
|
|
"\n",
|
|
"        String characters in a transcript are stored as Unicode code points\n",
|
|
"        (``ord``); any other element (e.g. a token id) is stored as-is.\n",
|
|
"        ``is_training`` is accepted for backward compatibility only: it used\n",
|
|
"        to select the encoding and crashed whenever it disagreed with the\n",
|
|
"        actual transcript type.\n",
|
|
"\n",
|
|
"        Returns:\n",
|
|
"            ``(padded_audios, texts, audio_lens, text_lens)`` as float32,\n",
|
|
"            int32, int64 and int64 numpy arrays.\n",
|
|
"\n",
|
|
"        Raises:\n",
|
|
"            ValueError: if ``padding_to`` is smaller than the longest audio.\n",
|
|
"        \"\"\"\n",
|
|
"        # target length along the second axis\n",
|
|
"        max_length = max(audio.shape[1] for audio, text in batch)\n",
|
|
"        if padding_to != -1:\n",
|
|
"            if padding_to < max_length:\n",
|
|
"                raise ValueError(\"If padding_to is not -1, it should not be smaller \"\n",
|
|
"                                 \"than any instance's shape in the batch\")\n",
|
|
"            max_length = padding_to\n",
|
|
"        max_text_length = max(len(text) for audio, text in batch)\n",
|
|
"\n",
|
|
"        padded_audios = []\n",
|
|
"        audio_lens = []\n",
|
|
"        texts, text_lens = [], []\n",
|
|
"        for audio, text in batch:\n",
|
|
"            padded_audio = np.zeros([audio.shape[0], max_length])\n",
|
|
"            padded_audio[:, :audio.shape[1]] = audio\n",
|
|
"            if flatten:\n",
|
|
"                padded_audio = padded_audio.flatten()\n",
|
|
"            padded_audios.append(padded_audio)\n",
|
|
"            audio_lens.append(audio.shape[1])\n",
|
|
"\n",
|
|
"            padded_text = np.zeros([max_text_length])\n",
|
|
"            # decide the encoding per element instead of from ``is_training``\n",
|
|
"            padded_text[:len(text)] = [\n",
|
|
"                ord(t) if isinstance(t, str) else t for t in text]\n",
|
|
"            texts.append(padded_text)\n",
|
|
"            text_lens.append(len(text))\n",
|
|
"\n",
|
|
"        padded_audios = np.array(padded_audios).astype('float32')\n",
|
|
"        audio_lens = np.array(audio_lens).astype('int64')\n",
|
|
"        texts = np.array(texts).astype('int32')\n",
|
|
"        text_lens = np.array(text_lens).astype('int64')\n",
|
|
"        return padded_audios, texts, audio_lens, text_lens\n",
|
|
"\n",
|
|
"    loader = DataLoader(\n",
|
|
"        dataset,\n",
|
|
"        batch_sampler=batch_sampler,\n",
|
|
"        collate_fn=partial(padding_batch, is_training=is_training),\n",
|
|
"        num_workers=num_workers)\n",
|
|
"    return loader"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "naval-brave",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import sys\n",
|
|
"import argparse\n",
|
|
"import functools\n",
|
|
"from utils.utility import add_arguments, print_arguments\n",
|
|
"parser = argparse.ArgumentParser(description=__doc__)\n",
|
|
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
|
|
"# yapf: disable\n",
|
|
"add_arg('num_samples',      int,    5,      \"# of samples to infer.\")\n",
|
|
"add_arg('beam_size',        int,    500,    \"Beam search width.\")\n",
|
|
"add_arg('num_proc_bsearch', int,    8,      \"# of CPUs for beam search.\")\n",
|
|
"add_arg('num_conv_layers',  int,    2,      \"# of convolution layers.\")\n",
|
|
"add_arg('num_rnn_layers',   int,    3,      \"# of recurrent layers.\")\n",
|
|
"add_arg('rnn_layer_size',   int,    2048,   \"# of recurrent cells per layer.\")\n",
|
|
"add_arg('alpha',            float,  2.5,    \"Coef of LM for beam search.\")\n",
|
|
"add_arg('beta',             float,  0.3,    \"Coef of WC for beam search.\")\n",
|
|
"add_arg('cutoff_prob',      float,  1.0,    \"Cutoff probability for pruning.\")\n",
|
|
"add_arg('cutoff_top_n',     int,    40,     \"Cutoff number for pruning.\")\n",
|
|
"add_arg('use_gru',          bool,   False,  \"Use GRUs instead of simple RNNs.\")\n",
|
|
"add_arg('use_gpu',          bool,   True,   \"Use GPU or not.\")\n",
|
|
"add_arg('share_rnn_weights',bool,   True,   \"Share input-hidden weights across \"\n",
|
|
"                                            \"bi-directional RNNs. Not for GRU.\")\n",
|
|
"add_arg('infer_manifest',   str,\n",
|
|
"        'examples/aishell/data/manifest.dev',\n",
|
|
"        \"Filepath of manifest to infer.\")\n",
|
|
"add_arg('mean_std_path',    str,\n",
|
|
"        'examples/aishell/data/mean_std.npz',\n",
|
|
"        \"Filepath of normalizer's mean & std.\")\n",
|
|
"add_arg('vocab_path',       str,\n",
|
|
"        'examples/aishell/data/vocab.txt',\n",
|
|
"        \"Filepath of vocabulary.\")\n",
|
|
"add_arg('lang_model_path',  str,\n",
|
|
"        'models/lm/common_crawl_00.prune01111.trie.klm',\n",
|
|
"        \"Filepath for language model.\")\n",
|
|
"add_arg('model_path',       str,\n",
|
|
"        'examples/aishell/checkpoints/step_final',\n",
|
|
"        \"Path of the trained model checkpoint to load for inference.\")\n",
|
|
"add_arg('decoding_method',  str,\n",
|
|
"        'ctc_beam_search',\n",
|
|
"        \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
|
|
"        choices=['ctc_beam_search', 'ctc_greedy'])\n",
|
|
"add_arg('error_rate_type',  str,\n",
|
|
"        'wer',\n",
|
|
"        \"Error rate type for evaluation.\",\n",
|
|
"        choices=['wer', 'cer'])\n",
|
|
"add_arg('specgram_type',    str,\n",
|
|
"        'linear',\n",
|
|
"        \"Audio feature type. Options: linear, mfcc.\",\n",
|
|
"        choices=['linear', 'mfcc'])\n",
|
|
"# yapf: enable\n",
|
|
"# parse an empty list so the kernel's own argv is not parsed; defaults apply\n",
|
|
"args = parser.parse_args([])\n",
|
|
"print(vars(args))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "bearing-physics",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"batch_reader = create_dataloader(\n",
|
|
"    manifest_path=args.infer_manifest,\n",
|
|
"    vocab_filepath=args.vocab_path,\n",
|
|
"    mean_std_filepath=args.mean_std_path,\n",
|
|
"    augmentation_config='{}',\n",
|
|
"    # cap utterance duration instead of the unbounded default float('inf')\n",
|
|
"    max_duration=27.0,\n",
|
|
"    min_duration=0.0,\n",
|
|
"    stride_ms=10.0,\n",
|
|
"    window_ms=20.0,\n",
|
|
"    max_freq=None,\n",
|
|
"    specgram_type=args.specgram_type,\n",
|
|
"    use_dB_normalization=True,\n",
|
|
"    random_seed=0,\n",
|
|
"    # presumably keeps raw transcript text; the next cell decodes it with chr()\n",
|
|
"    keep_transcription_text=True,\n",
|
|
"    is_training=False,\n",
|
|
"    batch_size=args.num_samples,\n",
|
|
"    sortagrad=True,\n",
|
|
"    shuffle_method=None,\n",
|
|
"    dist=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"id": "classified-melissa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
|
|
" [[22823, 26102, 20195, 37324, 0 , 0 ],\n",
|
|
" [22238, 26469, 23601, 22909, 0 , 0 ],\n",
|
|
" [20108, 26376, 22235, 26085, 0 , 0 ],\n",
|
|
" [36824, 35201, 20445, 25345, 32654, 24863],\n",
|
|
" [29042, 27748, 21463, 23456, 0 , 0 ]])\n",
|
|
"test raw 大时代里\n",
|
|
"test raw 煲汤受宠\n",
|
|
"audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
|
|
" [163, 167, 180, 186, 186])\n",
|
|
"test len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
|
|
" [4, 4, 4, 6, 4])\n",
|
|
"audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
|
|
" [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n",
|
|
" [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n",
|
|
" [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n",
|
|
" ...,\n",
|
|
" [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n",
|
|
" [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n",
|
|
" [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n",
|
|
"\n",
|
|
" [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n",
|
|
" [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n",
|
|
" [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n",
|
|
" ...,\n",
|
|
" [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n",
|
|
" [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n",
|
|
" [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n",
|
|
"\n",
|
|
" [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n",
|
|
" [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n",
|
|
" [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n",
|
|
" ...,\n",
|
|
" [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n",
|
|
" [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n",
|
|
" [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n",
|
|
"\n",
|
|
" [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n",
|
|
" [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n",
|
|
" [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n",
|
|
" ...,\n",
|
|
" [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n",
|
|
" [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n",
|
|
" [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n",
|
|
"\n",
|
|
" [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n",
|
|
" [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n",
|
|
" [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n",
|
|
" ...,\n",
|
|
" [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n",
|
|
" [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n",
|
|
" [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# The collate function returns (padded_audios, texts, audio_lens, text_lens);\n",
|
|
"# unpack in that order and inspect only the first batch.\n",
|
|
"audio, text, audio_len, text_len = next(iter(batch_reader()))\n",
|
|
"print('text', text)\n",
|
|
"print(\"text raw\", ''.join(chr(i) for i in text[0][:int(text_len[0])]))\n",
|
|
"print(\"text raw\", ''.join(chr(i) for i in text[-1][:int(text_len[-1])]))\n",
|
|
"print('audio len', audio_len)\n",
|
|
"print('text len', text_len)\n",
|
|
"print('audio', audio)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "unexpected-skating",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "minus-modern",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
} |