You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/.notebook/dataloader.ipynb

389 lines
20 KiB

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "emerging-meter",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"import math\n",
"import random\n",
"import tarfile\n",
"import logging\n",
"import numpy as np\n",
"from collections import namedtuple\n",
"from functools import partial\n",
"\n",
"import paddle\n",
"from paddle.io import Dataset\n",
"from paddle.io import DataLoader\n",
"from paddle.io import BatchSampler\n",
"from paddle.io import DistributedBatchSampler\n",
"from paddle import distributed as dist\n",
"\n",
"from data_utils.utility import read_manifest\n",
"from data_utils.augmentor.augmentation import AugmentationPipeline\n",
"from data_utils.featurizer.speech_featurizer import SpeechFeaturizer\n",
"from data_utils.speech import SpeechSegment\n",
"from data_utils.normalizer import FeatureNormalizer\n",
"\n",
"\n",
"from data_utils.dataset import (\n",
" DeepSpeech2Dataset,\n",
" DeepSpeech2DistributedBatchSampler,\n",
" DeepSpeech2BatchSampler,\n",
" SpeechCollator,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "excessive-american",
"metadata": {},
"outputs": [],
"source": [
"def create_dataloader(manifest_path,\n",
"                      vocab_filepath,\n",
"                      mean_std_filepath,\n",
"                      augmentation_config='{}',\n",
"                      max_duration=float('inf'),\n",
"                      min_duration=0.0,\n",
"                      stride_ms=10.0,\n",
"                      window_ms=20.0,\n",
"                      max_freq=None,\n",
"                      specgram_type='linear',\n",
"                      use_dB_normalization=True,\n",
"                      random_seed=0,\n",
"                      keep_transcription_text=False,\n",
"                      is_training=False,\n",
"                      batch_size=1,\n",
"                      num_workers=0,\n",
"                      sortagrad=False,\n",
"                      shuffle_method=None,\n",
"                      dist=False):\n",
"    \"\"\"Build a paddle `DataLoader` over a `DeepSpeech2Dataset`.\n",
"\n",
"    `manifest_path`, `vocab_filepath`, `mean_std_filepath` and every keyword\n",
"    from `augmentation_config` through `keep_transcription_text` are forwarded\n",
"    unchanged to `DeepSpeech2Dataset`.\n",
"\n",
"    Args:\n",
"        is_training: Forwarded to the batch sampler as `shuffle`, `drop_last`\n",
"            and `sortagrad`.\n",
"        batch_size: Forwarded to the batch sampler.\n",
"        num_workers: Forwarded to `DataLoader`.\n",
"        sortagrad: NOTE(review): accepted but not used; the samplers receive\n",
"            `sortagrad=is_training` instead. Kept as is so existing callers\n",
"            see no behaviour change -- confirm the intended semantics.\n",
"        shuffle_method: Forwarded to the batch sampler.\n",
"        dist: If truthy, batches come from `DeepSpeech2DistributedBatchSampler`,\n",
"            otherwise from `DeepSpeech2BatchSampler`. Inside this function the\n",
"            name shadows the `paddle.distributed` module imported as `dist`.\n",
"\n",
"    Returns:\n",
"        A `DataLoader` whose batches are collated by `padding_batch` into\n",
"        `(padded_audios, texts, audio_lens, text_lens)`, in that order.\n",
"    \"\"\"\n",
"    dataset = DeepSpeech2Dataset(\n",
"        manifest_path,\n",
"        vocab_filepath,\n",
"        mean_std_filepath,\n",
"        augmentation_config=augmentation_config,\n",
"        max_duration=max_duration,\n",
"        min_duration=min_duration,\n",
"        stride_ms=stride_ms,\n",
"        window_ms=window_ms,\n",
"        max_freq=max_freq,\n",
"        specgram_type=specgram_type,\n",
"        use_dB_normalization=use_dB_normalization,\n",
"        random_seed=random_seed,\n",
"        keep_transcription_text=keep_transcription_text)\n",
"\n",
"    if dist:\n",
"        batch_sampler = DeepSpeech2DistributedBatchSampler(\n",
"            dataset,\n",
"            batch_size,\n",
"            num_replicas=None,\n",
"            rank=None,\n",
"            shuffle=is_training,\n",
"            drop_last=is_training,\n",
"            sortagrad=is_training,\n",
"            shuffle_method=shuffle_method)\n",
"    else:\n",
"        batch_sampler = DeepSpeech2BatchSampler(\n",
"            dataset,\n",
"            shuffle=is_training,\n",
"            batch_size=batch_size,\n",
"            drop_last=is_training,\n",
"            sortagrad=is_training,\n",
"            shuffle_method=shuffle_method)\n",
"\n",
"    def padding_batch(batch, padding_to=-1, flatten=False):\n",
"        \"\"\"Collate `(audio, text)` pairs into zero-padded numpy arrays.\n",
"\n",
"        Audio features are padded with zeros along their second axis so that\n",
"        all items in the batch share one shape. If `padding_to` is -1 the\n",
"        longest item in the batch sets the target length; otherwise\n",
"        `padding_to` is the target length and must not be smaller than any\n",
"        item. If `flatten` is True each padded feature is flattened to 1-D.\n",
"\n",
"        Texts are padded with zeros to the longest text in the batch. A `str`\n",
"        transcript is stored as its Unicode code points; anything else is\n",
"        copied as a sequence of token ids. The choice is made from the value\n",
"        itself, so it stays correct for every combination of `is_training`\n",
"        and `keep_transcription_text`.\n",
"\n",
"        Returns:\n",
"            `(padded_audios, texts, audio_lens, text_lens)` as float32, int32,\n",
"            int64 and int64 numpy arrays respectively.\n",
"\n",
"        Raises:\n",
"            ValueError: if `padding_to` is set and is smaller than the longest\n",
"                audio in the batch.\n",
"        \"\"\"\n",
"        # get target shape\n",
"        max_length = max(audio.shape[1] for audio, _ in batch)\n",
"        if padding_to != -1:\n",
"            if padding_to < max_length:\n",
"                raise ValueError(\"If padding_to is not -1, it should be larger \"\n",
"                                 \"than any instance's shape in the batch\")\n",
"            max_length = padding_to\n",
"        max_text_length = max(len(text) for _, text in batch)\n",
"\n",
"        # padding\n",
"        padded_audios, audio_lens = [], []\n",
"        texts, text_lens = [], []\n",
"        for audio, text in batch:\n",
"            padded_audio = np.zeros([audio.shape[0], max_length])\n",
"            padded_audio[:, :audio.shape[1]] = audio\n",
"            if flatten:\n",
"                padded_audio = padded_audio.flatten()\n",
"            padded_audios.append(padded_audio)\n",
"            audio_lens.append(audio.shape[1])\n",
"\n",
"            # Raw transcript -> code points; token ids are used unchanged.\n",
"            tokens = [ord(ch) for ch in text] if isinstance(text, str) else text\n",
"            padded_text = np.zeros([max_text_length])\n",
"            padded_text[:len(text)] = tokens\n",
"            texts.append(padded_text)\n",
"            text_lens.append(len(text))\n",
"\n",
"        padded_audios = np.array(padded_audios).astype('float32')\n",
"        audio_lens = np.array(audio_lens).astype('int64')\n",
"        texts = np.array(texts).astype('int32')\n",
"        text_lens = np.array(text_lens).astype('int64')\n",
"        return padded_audios, texts, audio_lens, text_lens\n",
"\n",
"    loader = DataLoader(\n",
"        dataset,\n",
"        batch_sampler=batch_sampler,\n",
"        collate_fn=padding_batch,\n",
"        num_workers=num_workers)\n",
"    return loader"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "naval-brave",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from utils.utility import add_arguments, print_arguments\n",
"\n",
"# Declare the inference options on an argparse parser; `add_arg` binds the\n",
"# parser once so each option below is a single (name, type, default, help) call.\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
"        \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('infer_manifest', str,\n",
"        'examples/aishell/data/manifest.dev',\n",
"        \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
"        'examples/aishell/data/mean_std.npz',\n",
"        \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
"        'examples/aishell/data/vocab.txt',\n",
"        \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
"        'models/lm/common_crawl_00.prune01111.trie.klm',\n",
"        \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
"        'examples/aishell/checkpoints/step_final',\n",
"        \"If None, the training starts from scratch, \"\n",
"        \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
"        'ctc_beam_search',\n",
"        \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
"        choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
"        'wer',\n",
"        \"Error rate type for evaluation.\",\n",
"        choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
"        'linear',\n",
"        \"Audio feature type. Options: linear, mfcc.\",\n",
"        choices=['linear', 'mfcc'])\n",
"# yapf: enable\n",
"\n",
"# Parse an explicit empty argument list so every option takes its default\n",
"# instead of argparse reading the process command line.\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "bearing-physics",
"metadata": {},
"outputs": [],
"source": [
"# Evaluation-style loader over the inference manifest: transcripts are kept\n",
"# as raw text and no augmentation is applied.\n",
"batch_reader = create_dataloader(\n",
"    # input files\n",
"    manifest_path=args.infer_manifest,\n",
"    vocab_filepath=args.vocab_path,\n",
"    mean_std_filepath=args.mean_std_path,\n",
"    # utterance filtering (use float('inf') as max_duration to keep everything)\n",
"    min_duration=0.0,\n",
"    max_duration=27.0,\n",
"    # feature extraction\n",
"    specgram_type=args.specgram_type,\n",
"    stride_ms=10.0,\n",
"    window_ms=20.0,\n",
"    max_freq=None,\n",
"    use_dB_normalization=True,\n",
"    augmentation_config='{}',\n",
"    random_seed=0,\n",
"    # batching\n",
"    batch_size=args.num_samples,\n",
"    is_training=False,\n",
"    keep_transcription_text=True,\n",
"    sortagrad=True,\n",
"    shuffle_method=None,\n",
"    dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "classified-melissa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[22823, 26102, 20195, 37324, 0 , 0 ],\n",
" [22238, 26469, 23601, 22909, 0 , 0 ],\n",
" [20108, 26376, 22235, 26085, 0 , 0 ],\n",
" [36824, 35201, 20445, 25345, 32654, 24863],\n",
" [29042, 27748, 21463, 23456, 0 , 0 ]])\n",
"test raw 大时代里\n",
"test raw 煲汤受宠\n",
"audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 167, 180, 186, 186])\n",
"test len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [4, 4, 4, 6, 4])\n",
"audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n",
" [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n",
" [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n",
" [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n",
" [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n",
" [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n",
" [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n",
" [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n",
" [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n",
" [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n",
" [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n",
" ...,\n",
" [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n",
" [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n",
" [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n",
" [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n",
" [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n",
" ...,\n",
" [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n",
" [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n",
" [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n",
"\n",
" [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n",
" [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n",
" [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n",
" ...,\n",
" [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n",
" [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n",
" [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n"
]
}
],
"source": [
"# Inspect the first batch only. The loader's collate function returns\n",
"# (padded_audios, texts, audio_lens, text_lens), so unpack in that order.\n",
"for idx, (audio, text, audio_len, text_len) in enumerate(batch_reader()):\n",
"    print('test', text)\n",
"    # Transcripts arrive as zero-padded Unicode code points; decode the first\n",
"    # and last rows back to strings using their true lengths.\n",
"    print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
"    print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
"    print('audio len', audio_len)\n",
"    print('test len', text_len)\n",
"    print('audio', audio)\n",
"    break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unexpected-skating",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "minus-modern",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}