refactor data

pull/578/head
Hui Zhang 5 years ago
parent 553aa35989
commit a7244593b9

@ -0,0 +1,511 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "downtown-invalid",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x\n"
]
},
{
"data": {
"text/plain": [
"'/workspace/DeepSpeech-2.x'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%cd ..\n",
"%pwd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "emerging-meter",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" def convert_to_list(value, n, name, dtype=np.int):\n"
]
}
],
"source": [
"import math\n",
"import random\n",
"import tarfile\n",
"import logging\n",
"import numpy as np\n",
"from collections import namedtuple\n",
"from functools import partial\n",
"\n",
"import paddle\n",
"from paddle.io import Dataset\n",
"from paddle.io import DataLoader\n",
"from paddle.io import BatchSampler\n",
"from paddle.io import DistributedBatchSampler\n",
"from paddle import distributed as dist\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "excessive-american",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "naval-brave",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"WARNING:root:register user softmax to paddle, remove this when fixed!\n",
"WARNING:root:register user log_softmax to paddle, remove this when fixed!\n",
"WARNING:root:register user sigmoid to paddle, remove this when fixed!\n",
"WARNING:root:register user log_sigmoid to paddle, remove this when fixed!\n",
"WARNING:root:register user relu to paddle, remove this when fixed!\n",
"WARNING:root:override cat of paddle if exists or register, remove this when fixed!\n",
"WARNING:root:override item of paddle.Tensor if exists or register, remove this when fixed!\n",
"WARNING:root:override long of paddle.Tensor if exists or register, remove this when fixed!\n",
"WARNING:root:override new_full of paddle.Tensor if exists or register, remove this when fixed!\n",
"WARNING:root:override eq of paddle.Tensor if exists or register, remove this when fixed!\n",
"WARNING:root:override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n",
"WARNING:root:override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n",
"WARNING:root:register user view to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user view_as to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user masked_fill to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user masked_fill_ to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user fill_ to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user repeat to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user softmax to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user sigmoid to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user relu to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user type_as to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user to to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user float to paddle.Tensor, remove this when fixed!\n",
"WARNING:root:register user glu to paddle.nn.functional, remove this when fixed!\n",
"WARNING:root:override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n",
"WARNING:root:register user Module to paddle.nn, remove this when fixed!\n",
"WARNING:root:register user ModuleList to paddle.nn, remove this when fixed!\n",
"WARNING:root:register user GLU to paddle.nn, remove this when fixed!\n",
"WARNING:root:register user ConstantPad2d to paddle.nn, remove this when fixed!\n",
"WARNING:root:register user export to paddle.jit, remove this when fixed!\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/tiny/s1/data/spm_bpe', 'infer_manifest': 'examples/tiny/s1/data/manifest.tiny', 'mean_std_path': 'examples/tiny/s1/data/mean_std.npz', 'vocab_path': 'examples/tiny/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/tiny/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n"
]
}
],
"source": [
"import sys\n",
"import argparse\n",
"import functools\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"parser = argparse.ArgumentParser(description=__doc__)\n",
"add_arg = functools.partial(add_arguments, argparser=parser)\n",
"# yapf: disable\n",
"add_arg('num_samples', int, 5, \"# of samples to infer.\")\n",
"add_arg('beam_size', int, 500, \"Beam search width.\")\n",
"add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n",
"add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n",
"add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n",
"add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n",
"add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n",
"add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n",
"add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n",
"add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n",
"add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n",
"add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n",
"add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n",
" \"bi-directional RNNs. Not for GRU.\")\n",
"add_arg('unit_type', str,\n",
" 'char',\n",
" \"Options: char, word, spm.\",\n",
" choices=['char', 'word', 'spm'])\n",
"add_arg('spm_model_prefix', str,\n",
" 'examples/tiny/s1/data/spm_bpe',\n",
" \"spm model prefix.\",)\n",
"add_arg('infer_manifest', str,\n",
" 'examples/tiny/s1/data/manifest.tiny',\n",
" \"Filepath of manifest to infer.\")\n",
"add_arg('mean_std_path', str,\n",
" 'examples/tiny/s1/data/mean_std.npz',\n",
" \"Filepath of normalizer's mean & std.\")\n",
"add_arg('vocab_path', str,\n",
" 'examples/tiny/s1/data/vocab.txt',\n",
" \"Filepath of vocabulary.\")\n",
"add_arg('lang_model_path', str,\n",
" 'models/lm/common_crawl_00.prune01111.trie.klm',\n",
" \"Filepath for language model.\")\n",
"add_arg('model_path', str,\n",
" 'examples/tiny/s1/checkpoints/step_final',\n",
" \"If None, the training starts from scratch, \"\n",
" \"otherwise, it resumes from the pre-trained model.\")\n",
"add_arg('decoding_method', str,\n",
" 'ctc_beam_search',\n",
" \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n",
" choices = ['ctc_beam_search', 'ctc_greedy'])\n",
"add_arg('error_rate_type', str,\n",
" 'wer',\n",
" \"Error rate type for evaluation.\",\n",
" choices=['wer', 'cer'])\n",
"add_arg('specgram_type', str,\n",
" 'fbank',\n",
" \"Audio feature type. Options: linear, mfcc.\",\n",
" choices=['linear', 'mfcc'])\n",
"add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n",
"add_arg('delta_delta', bool, False, \"delta delta\")\n",
"# yapf: disable\n",
"args = parser.parse_args([])\n",
"print(vars(args))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bearing-physics",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n",
" from numpy.dual import register_func\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" long_ = _make_signed(np.long)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" ulong = _make_unsigned(np.long)\n"
]
}
],
"source": [
"# batch_reader = create_dataloader(\n",
"# manifest_path=args.infer_manifest,\n",
"# vocab_filepath=args.vocab_path,\n",
"# mean_std_filepath=args.mean_std_path,\n",
"# augmentation_config='{}',\n",
"# #max_duration=float('inf'),\n",
"# max_duration=27.0,\n",
"# min_duration=0.0,\n",
"# stride_ms=10.0,\n",
"# window_ms=20.0,\n",
"# max_freq=None,\n",
"# specgram_type=args.specgram_type,\n",
"# use_dB_normalization=True,\n",
"# random_seed=0,\n",
"# keep_transcription_text=True,\n",
"# is_training=False,\n",
"# batch_size=args.num_samples,\n",
"# sortagrad=True,\n",
"# shuffle_method=None,\n",
"# dist=False)\n",
"\n",
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n",
"from deepspeech.frontend.speech import SpeechSegment\n",
"from deepspeech.frontend.normalizer import FeatureNormalizer\n",
"\n",
"\n",
"from deepspeech.io.collator import SpeechCollator\n",
"from deepspeech.io.dataset import ManifestDataset\n",
"from deepspeech.io.sampler import (\n",
" SortagradDistributedBatchSampler,\n",
" SortagradBatchSampler,\n",
")\n",
"from deepspeech.io import create_dataloader\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" unit_type=args.unit_type,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" spm_model_prefix=args.spm_model_prefix,\n",
" augmentation_config='{}',\n",
" max_input_len=27.0,\n",
" min_input_len=0.0,\n",
" max_output_len=float('inf'),\n",
" min_output_len=0.0,\n",
" max_output_input_ratio=float('inf'),\n",
" min_output_input_ratio=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=True,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" num_workers=0,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "classified-melissa",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n",
"/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" if arr.dtype == np.object:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test: Tensor(shape=[5, 23], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[116, 104, 101, 32, 116, 119, 101, 110, 116, 105, 101, 115, -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ],\n",
" [119, 104, 101, 114, 101, 32, 105, 115, 32, 116, 104, 97, 116, -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ],\n",
" [116, 101, 110, 32, 115, 101, 99, 111, 110, 100, 115, -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ],\n",
" [104, 101, 32, 100, 111, 101, 115, 110, 39, 116, 32, 119, 111, 114, 107, 32, 97, 116, 32, 97, 108, 108, -1 ],\n",
" [119, 104, 101, 114, 101, 32, 105, 115, 32, 109, 121, 32, 98, 114, 111, 116, 104, 101, 114, 32, 110, 111, 119]])\n",
"test raw: the twenties\n",
"test raw: where is my brother now\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 173, 184, 190, 203])\n",
"test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [12, 13, 11, 22, 23])\n",
"audio: Tensor(shape=[5, 203, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[-51.32406616, -17.91388321, 0.00000000 , ..., -26.66350746, -27.46039391, -27.22303963],\n",
" [-15.19027233, -20.52460480, 0.00000000 , ..., -28.47811317, -26.87953568, -25.13592339],\n",
" [-22.80181694, -19.48889351, 0.00000000 , ..., -29.96320724, -25.96619034, -24.57164192],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-15.38297653, -18.95307732, 0.00000000 , ..., -15.22777271, -16.46900940, -12.32327461],\n",
" [-14.06289291, -12.69954872, 0.00000000 , ..., -15.68012810, -16.92030334, -13.49134445],\n",
" [-19.78544235, -11.63046265, 0.00000000 , ..., -14.35409069, -14.82787228, -15.72653484],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-22.65289879, -21.11938667, 0.00000000 , ..., -31.80981827, -30.58669853, -28.68988228],\n",
" [-31.04699135, -21.68680763, 0.00000000 , ..., -29.90789604, -30.31726456, -30.99709320],\n",
" [-18.16406441, -17.50658417, 0.00000000 , ..., -29.47821617, -29.77137375, -30.45121002],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-16.17608452, -15.22302818, 0.00000000 , ..., -8.82944202 , -7.88900328 , -6.10806322 ],\n",
" [-19.40717316, -12.32932186, 0.00000000 , ..., -8.05214977 , -8.03145599 , -7.35137606 ],\n",
" [-11.01850796, -13.20147514, 0.00000000 , ..., -9.65334892 , -8.96987629 , -9.13897228 ],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-16.55369759, -16.95514297, 0.00000000 , ..., -7.00301647 , -6.53273058 , -10.14600754],\n",
" [-19.51947975, -14.86818218, 0.00000000 , ..., -6.82891273 , -6.22576237 , -9.42883873 ],\n",
" [-15.26447582, -22.26662445, 0.00000000 , ..., -13.31693172, -11.05612659, -12.70977211],\n",
" ...,\n",
" [-4.81728077 , -10.65084648, 0.00000000 , ..., 3.19982862 , 8.42359638 , 7.95100546 ],\n",
" [-7.54755068 , -12.56441689, 0.00000000 , ..., 4.12789631 , 6.98472023 , 7.79936218 ],\n",
" [-8.79256725 , -11.23776722, 0.00000000 , ..., 1.31829071 , 1.30352044 , 6.80789280 ]]])\n"
]
}
],
"source": [
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test:', text)\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('audio len:', audio_len)\n",
" print('test len:', text_len)\n",
" print('audio:', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "unexpected-skating",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 9,
"id": "minus-modern",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"test: Tensor(shape=[5, 23], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[87, 37, 26, 1, 87, 97, 26, 61, 87, 38, 26, 82, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],\n",
" [97, 37, 26, 79, 26, 1, 38, 82, 1, 87, 37, 3, 87, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],\n",
" [87, 26, 61, 1, 82, 26, 18, 64, 61, 25, 82, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],\n",
" [37, 26, 1, 25, 64, 26, 82, 61, 2, 87, 1, 97, 64, 79, 52, 1, 3, 87, 1, 3, 53, 53, -1],\n",
" [97, 37, 26, 79, 26, 1, 38, 82, 1, 58, 102, 1, 17, 79, 64, 87, 37, 26, 79, 1, 61, 64, 97]])\n",
"test raw: W%\u001a\u0001Wa\u001a=W&\u001aR\n",
"test raw: a%\u001aO\u001a\u0001&R\u0001:f\u0001\u0011O@W%\u001aO\u0001=@a\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 173, 184, 190, 203])\n",
"test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [12, 13, 11, 22, 23])\n",
"audio: Tensor(shape=[5, 203, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [[[-51.32406616, -17.91388321, 0.00000000 , ..., -26.66350746, -27.46039391, -27.22303963],\n",
" [-15.19027233, -20.52460480, 0.00000000 , ..., -28.47811317, -26.87953568, -25.13592339],\n",
" [-22.80181694, -19.48889351, 0.00000000 , ..., -29.96320724, -25.96619034, -24.57164192],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-15.38297653, -18.95307732, 0.00000000 , ..., -15.22777271, -16.46900940, -12.32327461],\n",
" [-14.06289291, -12.69954872, 0.00000000 , ..., -15.68012810, -16.92030334, -13.49134445],\n",
" [-19.78544235, -11.63046265, 0.00000000 , ..., -14.35409069, -14.82787228, -15.72653484],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-22.65289879, -21.11938667, 0.00000000 , ..., -31.80981827, -30.58669853, -28.68988228],\n",
" [-31.04699135, -21.68680763, 0.00000000 , ..., -29.90789604, -30.31726456, -30.99709320],\n",
" [-18.16406441, -17.50658417, 0.00000000 , ..., -29.47821617, -29.77137375, -30.45121002],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-16.17608452, -15.22302818, 0.00000000 , ..., -8.82944202 , -7.88900328 , -6.10806322 ],\n",
" [-19.40717316, -12.32932186, 0.00000000 , ..., -8.05214977 , -8.03145599 , -7.35137606 ],\n",
" [-11.01850796, -13.20147514, 0.00000000 , ..., -9.65334892 , -8.96987629 , -9.13897228 ],\n",
" ...,\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n",
"\n",
" [[-16.55369759, -16.95514297, 0.00000000 , ..., -7.00301647 , -6.53273058 , -10.14600754],\n",
" [-19.51947975, -14.86818218, 0.00000000 , ..., -6.82891273 , -6.22576237 , -9.42883873 ],\n",
" [-15.26447582, -22.26662445, 0.00000000 , ..., -13.31693172, -11.05612659, -12.70977211],\n",
" ...,\n",
" [-4.81728077 , -10.65084648, 0.00000000 , ..., 3.19982862 , 8.42359638 , 7.95100546 ],\n",
" [-7.54755068 , -12.56441689, 0.00000000 , ..., 4.12789631 , 6.98472023 , 7.79936218 ],\n",
" [-8.79256725 , -11.23776722, 0.00000000 , ..., 1.31829071 , 1.30352044 , 6.80789280 ]]])\n"
]
}
],
"source": [
"keep_transcription_text=False\n",
"batch_reader = create_dataloader(\n",
" manifest_path=args.infer_manifest,\n",
" unit_type=args.unit_type,\n",
" vocab_filepath=args.vocab_path,\n",
" mean_std_filepath=args.mean_std_path,\n",
" spm_model_prefix=args.spm_model_prefix,\n",
" augmentation_config='{}',\n",
" max_input_len=27.0,\n",
" min_input_len=0.0,\n",
" max_output_len=float('inf'),\n",
" min_output_len=0.0,\n",
" max_output_input_ratio=float('inf'),\n",
" min_output_input_ratio=0.0,\n",
" stride_ms=10.0,\n",
" window_ms=20.0,\n",
" max_freq=None,\n",
" specgram_type=args.specgram_type,\n",
" feat_dim=args.feat_dim,\n",
" delta_delta=args.delta_delta,\n",
" use_dB_normalization=True,\n",
" random_seed=0,\n",
" keep_transcription_text=keep_transcription_text,\n",
" is_training=False,\n",
" batch_size=args.num_samples,\n",
" num_workers=0,\n",
" sortagrad=True,\n",
" shuffle_method=None,\n",
" dist=False)\n",
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test:', text)\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('audio len:', audio_len)\n",
" print('test len:', text_len)\n",
" print('audio:', audio)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "chronic-diagram",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -0,0 +1,648 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "choice-lender",
"metadata": {},
"outputs": [],
"source": [
"eng=\"one minute a voice said and the time buzzer sounded\"\n",
"chn=\"可控是病毒武器最基本的要求\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ruled-kuwait",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"o\n",
"n\n",
"e\n",
" \n",
"m\n",
"i\n",
"n\n",
"u\n",
"t\n",
"e\n",
" \n",
"a\n",
" \n",
"v\n",
"o\n",
"i\n",
"c\n",
"e\n",
" \n",
"s\n",
"a\n",
"i\n",
"d\n",
" \n",
"a\n",
"n\n",
"d\n",
" \n",
"t\n",
"h\n",
"e\n",
" \n",
"t\n",
"i\n",
"m\n",
"e\n",
" \n",
"b\n",
"u\n",
"z\n",
"z\n",
"e\n",
"r\n",
" \n",
"s\n",
"o\n",
"u\n",
"n\n",
"d\n",
"e\n",
"d\n"
]
}
],
"source": [
"for char in eng:\n",
" print(char)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "passive-petite",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"可\n",
"控\n",
"是\n",
"病\n",
"毒\n",
"武\n",
"器\n",
"最\n",
"基\n",
"本\n",
"的\n",
"要\n",
"求\n"
]
}
],
"source": [
"for char in chn:\n",
" print(char)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "olympic-realtor",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"one\n",
"minute\n",
"a\n",
"voice\n",
"said\n",
"and\n",
"the\n",
"time\n",
"buzzer\n",
"sounded\n"
]
}
],
"source": [
"for word in eng.split():\n",
" print(word)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "induced-enhancement",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"可控是病毒武器最基本的要求\n"
]
}
],
"source": [
"for word in chn.split():\n",
" print(word)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "lovely-bottle",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'StringIO'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-3e4825b8299f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mStringIO\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'StringIO'"
]
}
],
"source": [
"import StringIO"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "interested-cardiff",
"metadata": {},
"outputs": [],
"source": [
"from io import StringIO"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "portable-ivory",
"metadata": {},
"outputs": [],
"source": [
"inputs = StringIO()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "compatible-destination",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"64"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs.write(\"nor is mister quilter's manner less interesting than his matter\" + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "federal-margin",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nor is mister quilter's manner less interesting than his matternor is mister quilter's manner less interesting than his matter\n",
"\n"
]
}
],
"source": [
"print(inputs.getvalue())"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "consecutive-entity",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"64"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs.write(\"nor is mister quilter's manner less interesting than his matter\" + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "desirable-anxiety",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nor is mister quilter's manner less interesting than his matternor is mister quilter's manner less interesting than his matter\n",
"nor is mister quilter's manner less interesting than his matter\n",
"\n"
]
}
],
"source": [
"print(inputs.getvalue())"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "employed-schedule",
"metadata": {},
"outputs": [],
"source": [
"import tempfile"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "unlikely-honduras",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['__class__', '__del__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__ne__', '__new__', '__next__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_checkClosed', '_checkReadable', '_checkSeekable', '_checkWritable', '_dealloc_warn', '_finalizing', 'close', 'closed', 'detach', 'fileno', 'flush', 'isatty', 'mode', 'name', 'peek', 'raw', 'read', 'read1', 'readable', 'readinto', 'readinto1', 'readline', 'readlines', 'seek', 'seekable', 'tell', 'truncate', 'writable', 'write', 'writelines']\n",
"57\n"
]
}
],
"source": [
"with tempfile.TemporaryFile() as fp:\n",
" print(dir(fp))\n",
" print(fp.name)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "needed-trail",
"metadata": {},
"outputs": [],
"source": [
"a = tempfile.mkstemp(suffix=None, prefix='test', dir=None, text=False)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "hazardous-choir",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['__add__', '__class__', '__contains__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getnewargs__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__mul__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'count', 'index']\n"
]
}
],
"source": [
"print(dir(a))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "front-sauce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(57, '/tmp/test27smzbzc')\n"
]
}
],
"source": [
"print(a)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "shared-wages",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<built-in method index of tuple object at 0x7f999b525648>\n"
]
}
],
"source": [
"print(a.index)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "charged-carnival",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_closer', 'close', 'delete', 'file', 'name']\n",
"/tmp/tmpfjn7mygy\n"
]
}
],
"source": [
"fp= tempfile.NamedTemporaryFile(mode='w', delete=False)\n",
"print(dir(fp))\n",
"print(fp.name)\n",
"fp.close()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "religious-terror",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/tmp/tmpfjn7mygy\n"
]
}
],
"source": [
"import os\n",
"os.path.exists(fp.name)\n",
"print(fp.name)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "communist-gospel",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function BufferedRandom.write(buffer, /)>"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fp.write"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "simplified-clarity",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'example'"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"s='/home/ubuntu/python/example.py'\n",
"os.path.splitext(os.path.basename(s))[0]"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "popular-genius",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "studied-burner",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_items([('hello', 1), ('world', 1)])\n"
]
}
],
"source": [
"counter = Counter()\n",
"counter.update([\"hello\"])\n",
"counter.update([\"world\"])\n",
"print(counter.items())"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "mineral-ceremony",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_items([('h', 1), ('e', 1), ('l', 3), ('o', 2), ('w', 1), ('r', 1), ('d', 1)])\n"
]
}
],
"source": [
"counter = Counter()\n",
"counter.update(\"hello\")\n",
"counter.update(\"world\")\n",
"print(counter.items())"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "nonprofit-freedom",
"metadata": {},
"outputs": [],
"source": [
"counter.update(list(\"hello\"))"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "extended-methodology",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_items([('h', 2), ('e', 2), ('l', 5), ('o', 3), ('w', 1), ('r', 1), ('d', 1)])\n"
]
}
],
"source": [
"print(counter.items())"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "grand-benjamin",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['h', 'e', 'l', 'l', 'o']"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(\"hello\")"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "marine-fundamentals",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{}\n"
]
}
],
"source": [
"from io import StringIO\n",
"a = StringIO(initial_value='{}', newline='')\n",
"print(a.read())"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "suitable-charlotte",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "expected str, bytes or os.PathLike object, not _io.StringIO",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-56-4323a912120d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: expected str, bytes or os.PathLike object, not _io.StringIO"
]
}
],
"source": [
"with io.open(a) as f:\n",
" print(f.read())"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "institutional-configuration",
"metadata": {},
"outputs": [],
"source": [
"io.open?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pregnant-modem",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -83,27 +83,11 @@ def inference(config, args):
def start_server(config, args): def start_server(config, args):
"""Start the ASR server""" """Start the ASR server"""
dataset = ManifestDataset( config.data.manfiest = config.data.test_manifest
config.data.test_manifest, config.data.augmentation_config = io.StringIO(
config.data.unit_type, initial_value='{}', newline='')
config.data.vocab_filepath, config.data.keep_transcription_text = True
config.data.mean_std_filepath, dataset = ManifestDataset.from_config(config)
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
model = DeepSpeech2Model.from_pretrained(dataset, config, model = DeepSpeech2Model.from_pretrained(dataset, config,
args.checkpoint_path) args.checkpoint_path)

@ -35,27 +35,12 @@ from deepspeech.io.dataset import ManifestDataset
def start_server(config, args): def start_server(config, args):
"""Start the ASR server""" """Start the ASR server"""
dataset = ManifestDataset( config.data.manfiest = config.data.test_manifest
config.data.test_manifest, config.data.augmentation_config = io.StringIO(
config.data.unit_type, initial_value='{}', newline='')
config.data.vocab_filepath, config.data.keep_transcription_text = True
config.data.mean_std_filepath, dataset = ManifestDataset.from_config(config)
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
model = DeepSpeech2Model.from_pretrained(dataset, config, model = DeepSpeech2Model.from_pretrained(dataset, config,
args.checkpoint_path) args.checkpoint_path)
model.eval() model.eval()

@ -41,34 +41,18 @@ def tune(config, args):
if not args.num_betas >= 0: if not args.num_betas >= 0:
raise ValueError("num_betas must be non-negative!") raise ValueError("num_betas must be non-negative!")
dev_dataset = ManifestDataset( config.data.manfiest = config.data.dev_manifest
config.data.dev_manifest, config.data.augmentation_config = io.StringIO(
config.data.unit_type, initial_value='{}', newline='')
config.data.vocab_filepath, config.data.keep_transcription_text = True
config.data.mean_std_filepath, dev_dataset = ManifestDataset.from_config(config)
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
valid_loader = DataLoader( valid_loader = DataLoader(
dev_dataset, dev_dataset,
batch_size=config.data.batch_size, batch_size=config.data.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=SpeechCollator(is_training=False)) collate_fn=SpeechCollator(keep_transcription_text=True))
model = DeepSpeech2Model.from_pretrained(dev_dataset, config, model = DeepSpeech2Model.from_pretrained(dev_dataset, config,
args.checkpoint_path) args.checkpoint_path)

@ -145,52 +145,15 @@ class DeepSpeech2Trainer(Trainer):
def setup_dataloader(self): def setup_dataloader(self):
config = self.config config = self.config
config.data.keep_transcription_text = False
train_dataset = ManifestDataset( config.data.manfiest = config.data.train_manifest
config.data.train_manifest, train_dataset = ManifestDataset.from_config(config)
config.data.unit_type,
config.data.vocab_filepath, config.data.manfiest = config.data.dev_manifest
config.data.mean_std_filepath, config.data.augmentation_config = io.StringIO(
spm_model_prefix=config.data.spm_model_prefix, initial_value='{}', newline='')
augmentation_config=io.open( dev_dataset = ManifestDataset.from_config(config)
config.data.augmentation_config, mode='r',
encoding='utf8').read(),
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=False)
dev_dataset = ManifestDataset(
config.data.dev_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=False)
if self.parallel: if self.parallel:
batch_sampler = SortagradDistributedBatchSampler( batch_sampler = SortagradDistributedBatchSampler(
@ -211,7 +174,7 @@ class DeepSpeech2Trainer(Trainer):
sortagrad=config.data.sortagrad, sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method) shuffle_method=config.data.shuffle_method)
collate_fn = SpeechCollator(is_training=True) collate_fn = SpeechCollator(keep_transcription_text=False)
self.train_loader = DataLoader( self.train_loader = DataLoader(
train_dataset, train_dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
@ -367,27 +330,12 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def setup_dataloader(self): def setup_dataloader(self):
config = self.config config = self.config
# return raw text # return raw text
test_dataset = ManifestDataset(
config.data.test_manifest, config.data.manfiest = config.data.test_manifest
config.data.unit_type, config.data.augmentation_config = io.StringIO(
config.data.vocab_filepath, initial_value='{}', newline='')
config.data.mean_std_filepath, config.data.keep_transcription_text = True
spm_model_prefix=config.data.spm_model_prefix, test_dataset = ManifestDataset.from_config(config)
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=True)
# return text ord id # return text ord id
self.test_loader = DataLoader( self.test_loader = DataLoader(
@ -395,7 +343,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=SpeechCollator(is_training=False)) collate_fn=SpeechCollator(keep_transcription_text=True))
self.logger.info("Setup test Dataloader!") self.logger.info("Setup test Dataloader!")
def setup_output_dir(self): def setup_output_dir(self):

@ -42,6 +42,7 @@ class SpeechSegment(AudioSegment):
""" """
AudioSegment.__init__(self, samples, sample_rate) AudioSegment.__init__(self, samples, sample_rate)
self._transcript = transcript self._transcript = transcript
# must init `tokens` with `token_ids` at the same time
self._tokens = tokens self._tokens = tokens
self._token_ids = token_ids self._token_ids = token_ids
@ -183,7 +184,7 @@ class SpeechSegment(AudioSegment):
@property @property
def has_token(self): def has_token(self):
if self._tokens or self._token_ids: if self._tokens and self._token_ids:
return True return True
return False return False

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools
import numpy as np
from paddle.io import DataLoader from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator from deepspeech.io.collator import SpeechCollator
@ -26,12 +28,18 @@ def create_dataloader(manifest_path,
mean_std_filepath, mean_std_filepath,
spm_model_prefix, spm_model_prefix,
augmentation_config='{}', augmentation_config='{}',
max_duration=float('inf'), max_input_len=float('inf'),
min_duration=0.0, min_input_len=0.0,
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0,
stride_ms=10.0, stride_ms=10.0,
window_ms=20.0, window_ms=20.0,
max_freq=None, max_freq=None,
specgram_type='linear', specgram_type='linear',
feat_dim=None,
delta_delta=False,
use_dB_normalization=True, use_dB_normalization=True,
random_seed=0, random_seed=0,
keep_transcription_text=False, keep_transcription_text=False,
@ -43,20 +51,24 @@ def create_dataloader(manifest_path,
dist=False): dist=False):
dataset = ManifestDataset( dataset = ManifestDataset(
manifest_path, manifest_path=manifest_path,
unit_type, unit_type=unit_type,
vocab_filepath, vocab_filepath=vocab_filepath,
mean_std_filepath, mean_std_filepath=mean_std_filepath,
spm_model_prefix=spm_model_prefix, spm_model_prefix=spm_model_prefix,
augmentation_config=augmentation_config, augmentation_config=augmentation_config,
max_duration=max_duration, max_input_len=max_input_len,
min_duration=min_duration, min_input_len=min_input_len,
max_output_len=max_output_len,
min_output_len=min_output_len,
max_output_input_ratio=max_output_input_ratio,
min_output_input_ratio=min_output_input_ratio,
stride_ms=stride_ms, stride_ms=stride_ms,
window_ms=window_ms, window_ms=window_ms,
max_freq=max_freq, max_freq=max_freq,
specgram_type=specgram_type, specgram_type=specgram_type,
feat_dim=config.data.feat_dim, feat_dim=feat_dim,
delta_delta=config.data.delat_delta, delta_delta=delta_delta,
use_dB_normalization=use_dB_normalization, use_dB_normalization=use_dB_normalization,
random_seed=random_seed, random_seed=random_seed,
keep_transcription_text=keep_transcription_text) keep_transcription_text=keep_transcription_text)
@ -80,7 +92,10 @@ def create_dataloader(manifest_path,
sortagrad=is_training, sortagrad=is_training,
shuffle_method=shuffle_method) shuffle_method=shuffle_method)
def padding_batch(batch, padding_to=-1, flatten=False, is_training=True): def padding_batch(batch,
padding_to=-1,
flatten=False,
keep_transcription_text=True):
""" """
Padding audio features with zeros to make them have the same shape (or Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach. a user-defined shape) within one bach.
@ -113,10 +128,10 @@ def create_dataloader(manifest_path,
audio_lens.append(audio.shape[1]) audio_lens.append(audio.shape[1])
padded_text = np.zeros([max_text_length]) padded_text = np.zeros([max_text_length])
if is_training: if keep_transcription_text:
padded_text[:len(text)] = text #ids
else:
padded_text[:len(text)] = [ord(t) for t in text] # string padded_text[:len(text)] = [ord(t) for t in text] # string
else:
padded_text[:len(text)] = text #ids
texts.append(padded_text) texts.append(padded_text)
text_lens.append(len(text)) text_lens.append(len(text))
@ -124,11 +139,13 @@ def create_dataloader(manifest_path,
audio_lens = np.array(audio_lens).astype('int64') audio_lens = np.array(audio_lens).astype('int64')
texts = np.array(texts).astype('int32') texts = np.array(texts).astype('int32')
text_lens = np.array(text_lens).astype('int64') text_lens = np.array(text_lens).astype('int64')
return padded_audios, texts, audio_lens, text_lens return padded_audios, audio_lens, texts, text_lens
#collate_fn=functools.partial(padding_batch, keep_transcription_text=keep_transcription_text),
collate_fn = SpeechCollator(keep_transcription_text=keep_transcription_text)
loader = DataLoader( loader = DataLoader(
dataset, dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
collate_fn=partial(padding_batch, is_training=is_training), collate_fn=collate_fn,
num_workers=num_workers) num_workers=num_workers)
return loader return loader

@ -25,14 +25,14 @@ __all__ = ["SpeechCollator"]
class SpeechCollator(): class SpeechCollator():
def __init__(self, is_training=True): def __init__(self, keep_transcription_text=True):
""" """
Padding audio features with zeros to make them have the same shape (or Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach. a user-defined shape) within one bach.
if ``is_training`` is True, text is token ids else is raw string. if ``keep_transcription_text`` is False, text is token ids else is raw string.
""" """
self._is_training = is_training self._keep_transcription_text = keep_transcription_text
def __call__(self, batch): def __call__(self, batch):
"""batch examples """batch examples
@ -61,15 +61,15 @@ class SpeechCollator():
# for training, text is token ids # for training, text is token ids
# else text is string, convert to unicode ord # else text is string, convert to unicode ord
tokens = [] tokens = []
if self._is_training: if self._keep_transcription_text:
tokens = text # token ids assert isinstance(text, str), type(text)
else:
assert isinstance(text, str)
tokens = [ord(t) for t in text] tokens = [ord(t) for t in text]
else:
tokens = text # token ids
tokens = tokens if isinstance(tokens, np.ndarray) else np.array( tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
tokens, dtype=np.int64) tokens, dtype=np.int64)
texts.append(tokens) texts.append(tokens)
text_lens.append(len(text)) text_lens.append(tokens.shape[0])
padded_audios = pad_sequence( padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D] audios, padding_value=0.0).astype(np.float32) #[B, T, D]

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import io
import math import math
import random import random
import tarfile import tarfile
@ -43,8 +44,12 @@ class ManifestDataset(Dataset):
mean_std_filepath, mean_std_filepath,
spm_model_prefix=None, spm_model_prefix=None,
augmentation_config='{}', augmentation_config='{}',
max_duration=float('inf'), max_input_len=float('inf'),
min_duration=0.0, min_input_len=0.0,
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0,
stride_ms=10.0, stride_ms=10.0,
window_ms=20.0, window_ms=20.0,
n_fft=None, n_fft=None,
@ -66,8 +71,12 @@ class ManifestDataset(Dataset):
mean_std_filepath (str): mean and std file path, which suffix is *.npy mean_std_filepath (str): mean and std file path, which suffix is *.npy
spm_model_prefix (str): spm model prefix, need if `unit_type` is spm. spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
augmentation_config (str, optional): augmentation json str. Defaults to '{}'. augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
max_duration (float, optional): audio length in seconds must less than this. Defaults to float('inf'). max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
min_duration (float, optional): audio length is seconds must greater than this. Defaults to 0.0. min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
stride_ms (float, optional): stride size in ms. Defaults to 10.0. stride_ms (float, optional): stride size in ms. Defaults to 10.0.
window_ms (float, optional): window size in ms. Defaults to 20.0. window_ms (float, optional): window size in ms. Defaults to 20.0.
n_fft (int, optional): fft points for rfft. Defaults to None. n_fft (int, optional): fft points for rfft. Defaults to None.
@ -82,9 +91,13 @@ class ManifestDataset(Dataset):
keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
""" """
super().__init__() super().__init__()
self._max_input_len = max_input_len,
self._min_input_len = min_input_len,
self._max_output_len = max_output_len,
self._min_output_len = min_output_len,
self._max_output_input_ratio = max_output_input_ratio,
self._min_output_input_ratio = min_output_input_ratio,
self._max_duration = max_duration
self._min_duration = min_duration
self._normalizer = FeatureNormalizer(mean_std_filepath) self._normalizer = FeatureNormalizer(mean_std_filepath)
self._audio_augmentation_pipeline = AugmentationPipeline( self._audio_augmentation_pipeline = AugmentationPipeline(
augmentation_config=augmentation_config, random_seed=random_seed) augmentation_config=augmentation_config, random_seed=random_seed)
@ -102,6 +115,7 @@ class ManifestDataset(Dataset):
target_sample_rate=target_sample_rate, target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization, use_dB_normalization=use_dB_normalization,
target_dB=target_dB) target_dB=target_dB)
self._rng = random.Random(random_seed) self._rng = random.Random(random_seed)
self._keep_transcription_text = keep_transcription_text self._keep_transcription_text = keep_transcription_text
# for caching tar files info # for caching tar files info
@ -112,9 +126,58 @@ class ManifestDataset(Dataset):
# read manifest # read manifest
self._manifest = read_manifest( self._manifest = read_manifest(
manifest_path=manifest_path, manifest_path=manifest_path,
max_duration=self._max_duration, max_input_len=max_input_len,
min_duration=self._min_duration) min_input_len=min_input_len,
self._manifest.sort(key=lambda x: x["duration"]) max_output_len=max_output_len,
min_output_len=min_output_len,
max_output_input_ratio=max_output_input_ratio,
min_output_input_ratio=min_output_input_ratio)
self._manifest.sort(key=lambda x: x["feat_shape"][0])
@classmethod
def from_config(cls, config):
"""Build a ManifestDataset object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
ManifestDataset: dataet object.
"""
assert manifest in config.data
assert keep_transcription_text in config.data
if isinstance(config.data.augmentation_config, (str, bytes)):
aug_file = io.open(
config.data.augmentation_config, mode='r', encoding='utf8')
else:
aug_file = config.data.augmentation_config
assert isinstance(aug_file, io.StringIO)
dataset = cls(
manifest_path=config.data.manifest,
unit_type=config.data.unit_type,
vocab_filepath=config.data.vocab_filepath,
mean_std_filepath=config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config=aug_file.read(),
max_input_len=config.data.max_input_len,
min_input_len=config.data.min_input_len,
max_output_len=config.data.max_output_len,
min_output_len=config.data.min_output_len,
max_output_input_ratio=config.data.max_output_input_ratio,
min_output_input_ratio=config.data.min_output_input_ratio,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delat_delta,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=config.data.keep_transcription_text)
return dataset
@property @property
def manifest(self): def manifest(self):

@ -1,146 +0,0 @@
#!/bin/bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0.
# 2014 David Snyder
# This script combines the data from multiple source directories into
# a single destination directory.
# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information
# about what these directories contain.
# Begin configuration section.
extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..."
skip_fix=false # skip the fix_data_dir.sh in the end
# End configuration section.
echo "$0 $@" # Print the command line for logging
if [ -f path.sh ]; then . ./path.sh; fi
if [ -f parse_options.sh ]; then . parse_options.sh || exit 1; fi
if [ $# -lt 2 ]; then
echo "Usage: combine_data.sh [--extra-files 'file1 file2'] <dest-data-dir> <src-data-dir1> <src-data-dir2> ..."
echo "Note, files that don't appear in all source dirs will not be combined,"
echo "with the exception of utt2uniq and segments, which are created where necessary."
exit 1
fi
dest=$1;
shift;
first_src=$1;
rm -r $dest 2>/dev/null
mkdir -p $dest;
export LC_ALL=C
for dir in $*; do
if [ ! -f $dir/utt2spk ]; then
echo "$0: no such file $dir/utt2spk"
exit 1;
fi
done
# Check that frame_shift are compatible, where present together with features.
dir_with_frame_shift=
for dir in $*; do
if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then
if [[ $dir_with_frame_shift ]] &&
! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then
echo "$0:error: different frame_shift in directories $dir and " \
"$dir_with_frame_shift. Cannot combine features."
exit 1;
fi
dir_with_frame_shift=$dir
fi
done
# W.r.t. utt2uniq file the script has different behavior compared to other files
# it is not compulsary for it to exist in src directories, but if it exists in
# even one it should exist in all. We will create the files where necessary
has_utt2uniq=false
for in_dir in $*; do
if [ -f $in_dir/utt2uniq ]; then
has_utt2uniq=true
break
fi
done
if $has_utt2uniq; then
# we are going to create an utt2uniq file in the destdir
for in_dir in $*; do
if [ ! -f $in_dir/utt2uniq ]; then
# we assume that utt2uniq is a one to one mapping
cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}'
else
cat $in_dir/utt2uniq
fi
done | sort -k1 > $dest/utt2uniq
echo "$0: combined utt2uniq"
else
echo "$0 [info]: not combining utt2uniq as it does not exist"
fi
# some of the old scripts might provide utt2uniq as an extrafile, so just remove it
extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g")
# segments are treated similarly to utt2uniq. If it exists in some, but not all
# src directories, then we generate segments where necessary.
has_segments=false
for in_dir in $*; do
if [ -f $in_dir/segments ]; then
has_segments=true
break
fi
done
if $has_segments; then
for in_dir in $*; do
if [ ! -f $in_dir/segments ]; then
echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
utils/data/get_segments_for_data.sh $in_dir
else
cat $in_dir/segments
fi
done | sort -k1 > $dest/segments
echo "$0: combined segments"
else
echo "$0 [info]: not combining segments as it does not exist"
fi
for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do
exists_somewhere=false
absent_somewhere=false
for d in $*; do
if [ -f $d/$file ]; then
exists_somewhere=true
else
absent_somewhere=true
fi
done
if ! $absent_somewhere; then
set -o pipefail
( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1;
set +o pipefail
echo "$0: combined $file"
else
if ! $exists_somewhere; then
echo "$0 [info]: not combining $file as it does not exist"
else
echo "$0 [info]: **not combining $file as it does not exist everywhere**"
fi
fi
done
tools/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
if [[ $dir_with_frame_shift ]]; then
cp $dir_with_frame_shift/frame_shift $dest
fi
if ! $skip_fix ; then
tools/fix_data_dir.sh $dest || exit 1;
fi
exit 0

@ -1,104 +0,0 @@
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
$ignore_oov = 0;
for($x = 0; $x < 2; $x++) {
if ($ARGV[0] eq "--map-oov") {
shift @ARGV;
$map_oov = shift @ARGV;
if ($map_oov eq "-f" || $map_oov =~ m/words\.txt$/ || $map_oov eq "") {
# disallow '-f', the empty string and anything ending in words.txt as the
# OOV symbol because these are likely command-line errors.
die "the --map-oov option requires an argument";
}
}
if ($ARGV[0] eq "-f") {
shift @ARGV;
$field_spec = shift @ARGV;
if ($field_spec =~ m/^\d+$/) {
$field_begin = $field_spec - 1; $field_end = $field_spec - 1;
}
if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10)
if ($1 ne "") {
$field_begin = $1 - 1; # Change to zero-based indexing.
}
if ($2 ne "") {
$field_end = $2 - 1; # Change to zero-based indexing.
}
}
if (!defined $field_begin && !defined $field_end) {
die "Bad argument to -f option: $field_spec";
}
}
}
$symtab = shift @ARGV;
if (!defined $symtab) {
print STDERR "Usage: sym2int.pl [options] symtab [input transcriptions] > output transcriptions\n" .
"options: [--map-oov <oov-symbol> ] [-f <field-range> ]\n" .
"note: <field-range> can look like 4-5, or 4-, or 5-, or 1.\n";
}
open(F, "<$symtab") || die "Error opening symbol table file $symtab";
while(<F>) {
@A = split(" ", $_);
@A == 2 || die "bad line in symbol table file: $_";
$sym2int{$A[0]} = $A[1] + 0;
}
if (defined $map_oov && $map_oov !~ m/^\d+$/) { # not numeric-> look it up
if (!defined $sym2int{$map_oov}) { die "OOV symbol $map_oov not defined."; }
$map_oov = $sym2int{$map_oov};
}
$num_warning = 0;
$max_warning = 20;
while (<>) {
@A = split(" ", $_);
@B = ();
for ($n = 0; $n < @A; $n++) {
$a = $A[$n];
if ( (!defined $field_begin || $n >= $field_begin)
&& (!defined $field_end || $n <= $field_end)) {
$i = $sym2int{$a};
if (!defined ($i)) {
if (defined $map_oov) {
if ($num_warning++ < $max_warning) {
print STDERR "sym2int.pl: replacing $a with $map_oov\n";
if ($num_warning == $max_warning) {
print STDERR "sym2int.pl: not warning for OOVs any more times\n";
}
}
$i = $map_oov;
} else {
$pos = $n+1;
die "sym2int.pl: undefined symbol $a (in position $pos)\n";
}
}
$a = $i;
}
push @B, $a;
}
print join(" ", @B);
print "\n";
}
if ($num_warning > 0) {
print STDERR "** Replaced $num_warning instances of OOVs with $map_oov\n";
}
exit(0);
Loading…
Cancel
Save