From 5fbced8b52ad602bd6572b4216e2e5ec086d6abb Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Mon, 16 Aug 2021 11:44:48 +0000
Subject: [PATCH] more data utils

---
 deepspeech/io/batchfy.py    |   1 -
 deepspeech/io/dataloader.py | 177 +++++++++++++++++++++
 deepspeech/io/dataset.py    |  26 ++-
 deepspeech/io/utility.py    | 305 +++++++++++++++++++++++++++++++++++-
 4 files changed, 506 insertions(+), 3 deletions(-)
 create mode 100644 deepspeech/io/dataloader.py

diff --git a/deepspeech/io/batchfy.py b/deepspeech/io/batchfy.py
index 31fa2392b..d237eb749 100644
--- a/deepspeech/io/batchfy.py
+++ b/deepspeech/io/batchfy.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import itertools
 
-import logger
 import numpy as np
 
 from deepspeech.utils.log import Log

diff --git a/deepspeech/io/dataloader.py b/deepspeech/io/dataloader.py
new file mode 100644
index 000000000..0c5034caa
--- /dev/null
+++ b/deepspeech/io/dataloader.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import paddle
+from paddle.io import DataLoader
+
+from deepspeech.frontend.utility import read_manifest
+from deepspeech.io.batchfy import make_batchset
+from deepspeech.io.dataset import TransformDataset
+from deepspeech.io.utility import LoadInputsAndTargets
+from deepspeech.io.utility import pad_list
+from deepspeech.utils.log import Log
+
+__all__ = ["CustomConverter", "BatchDataLoader"]
+
+logger = Log(__name__).getlog()
+
+
+class CustomConverter():
+    """Custom batch converter.
+
+    Args:
+        subsampling_factor (int): The subsampling factor.
+        dtype (paddle.dtype): Data type to convert.
+    """
+
+    def __init__(self, subsampling_factor=1, dtype=paddle.float32):
+        """Construct a CustomConverter object."""
+        self.subsampling_factor = subsampling_factor
+        self.ignore_id = -1
+        self.dtype = dtype
+
+    def __call__(self, batch):
+        """Transform a batch and send it to a device.
+
+        Args:
+            batch (list): The batch to transform.
+
+        Returns:
+            tuple(xs_pad, ilens, ys_pad, olens)
+        """
+        # batch should be a list containing a single minibatch
+        assert len(batch) == 1
+        xs, ys = batch[0]
+
+        # perform subsampling
+        if self.subsampling_factor > 1:
+            xs = [x[::self.subsampling_factor, :] for x in xs]
+
+        # get batch of lengths of input sequences
+        ilens = np.array([x.shape[0] for x in xs])
+
+        # perform padding and convert to tensor
+        # currently only real numbers are supported
+        if xs[0].dtype.kind == "c":
+            xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)
+            xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)
+            # Note(kamo):
+            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
+            # Don't create a ComplexTensor and give it to E2E here,
+            # because DataParallel can't handle it.
+            xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
+        else:
+            xs_pad = pad_list(xs, 0).astype(self.dtype)
+
+        ilens = paddle.to_tensor(ilens)
+
+        # NOTE: this is for multi-output models (e.g., speech translation)
+        ys_pad = pad_list(
+            [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
+            self.ignore_id)
+
+        olens = np.array([y.shape[0] for y in ys])
+        return xs_pad, ilens, ys_pad, olens
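
For reviewers, a minimal sketch of what the converter does to a toy batch (the feature dimension 83 and the token ids are made up; np.float32 is passed explicitly here because pad_list operates on numpy arrays):

    import numpy as np
    from deepspeech.io.dataloader import CustomConverter

    # two utterances of different lengths; 83 is an arbitrary feature dim
    xs = [np.random.randn(10, 83).astype(np.float32),
          np.random.randn(7, 83).astype(np.float32)]
    ys = [np.array([2, 5, 9]), np.array([4, 1])]

    converter = CustomConverter(subsampling_factor=1, dtype=np.float32)
    xs_pad, ilens, ys_pad, olens = converter([(xs, ys)])
    # xs_pad: (2, 10, 83), zero-padded; ilens: Tensor([10, 7])
    # ys_pad: (2, 3), padded with ignore_id == -1; olens: [3, 2]
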
+
+
+class BatchDataLoader():
+    def __init__(self,
+                 json_file: str,
+                 train_mode: bool,
+                 sortagrad: int=0,
+                 batch_size: int=0,
+                 maxlen_in: float=float('inf'),
+                 maxlen_out: float=float('inf'),
+                 minibatches: int=0,
+                 mini_batch_size: int=1,
+                 batch_count: str='auto',
+                 batch_bins: int=0,
+                 batch_frames_in: int=0,
+                 batch_frames_out: int=0,
+                 batch_frames_inout: int=0,
+                 preprocess_conf=None,
+                 n_iter_processes: int=1,
+                 subsampling_factor: int=1,
+                 num_encs: int=1):
+        self.json_file = json_file
+        self.train_mode = train_mode
+
+        self.use_sortagrad = sortagrad == -1 or sortagrad > 0
+        self.batch_size = batch_size
+        self.maxlen_in = maxlen_in
+        self.maxlen_out = maxlen_out
+        self.batch_count = batch_count
+        self.batch_bins = batch_bins
+        self.batch_frames_in = batch_frames_in
+        self.batch_frames_out = batch_frames_out
+        self.batch_frames_inout = batch_frames_inout
+
+        self.subsampling_factor = subsampling_factor
+        self.num_encs = num_encs
+        self.preprocess_conf = preprocess_conf
+
+        self.n_iter_processes = n_iter_processes
+
+        # read json data
+        data_json = read_manifest(json_file)
+        logger.info(f"loaded {json_file}.")
+
+        # make minibatch list (variable length)
+        self.data = make_batchset(
+            data_json,
+            batch_size,
+            maxlen_in,
+            maxlen_out,
+            minibatches,  # for debug
+            min_batch_size=mini_batch_size,
+            shortest_first=self.use_sortagrad,
+            count=batch_count,
+            batch_bins=batch_bins,
+            batch_frames_in=batch_frames_in,
+            batch_frames_out=batch_frames_out,
+            batch_frames_inout=batch_frames_inout,
+            iaxis=0,
+            oaxis=0, )
+        logger.info(f"batchfy data {json_file}: {len(self.data)} batches.")
+
+        self.load = LoadInputsAndTargets(
+            mode="asr",
+            load_output=True,
+            preprocess_conf=preprocess_conf,
+            preprocess_args={"train":
+                             train_mode},  # Switch the mode of preprocessing
+        )
+
+        # Setup a converter
+        if num_encs == 1:
+            self.converter = CustomConverter(
+                subsampling_factor=subsampling_factor, dtype=paddle.float32)
+        else:
+            raise NotImplementedError("CustomConverterMulEnc is not implemented.")
+
+        # Hack: keep the DataLoader batch_size at 1; the actual batch size is
+        # already encoded in each element of self.data. The default collate
+        # function would convert numpy arrays to paddle tensors, so we use an
+        # identity-like collate function that just unwraps the singleton list.
+        self.train_loader = DataLoader(
+            dataset=TransformDataset(
+                self.data, lambda data: self.converter([self.load(data)])),
+            batch_size=1,
+            shuffle=not self.use_sortagrad if train_mode else False,
+            collate_fn=lambda x: x[0],
+            num_workers=n_iter_processes, )
+        logger.info(f"dataloader for {json_file}.")
+
+    def __repr__(self):
+        return f"DataLoader {self.json_file}-{self.train_mode}-{self.use_sortagrad}"
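
A sketch of the intended call pattern (the manifest path is hypothetical, and preprocess_conf is left as None so no augmentation is applied):

    from deepspeech.io.dataloader import BatchDataLoader

    loader = BatchDataLoader(
        json_file="data/manifest.train",  # hypothetical manifest path
        train_mode=True,
        sortagrad=0,
        batch_size=32,
        batch_count="auto",
        n_iter_processes=2)

    for xs_pad, ilens, ys_pad, olens in loader.train_loader:
        pass  # feed the padded batch to the model
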
diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py
index ac7be1f9e..e2db93404 100644
--- a/deepspeech/io/dataset.py
+++ b/deepspeech/io/dataset.py
@@ -19,7 +19,7 @@ from yacs.config import CfgNode
 from deepspeech.frontend.utility import read_manifest
 from deepspeech.utils.log import Log
 
-__all__ = ["ManifestDataset", "TripletManifestDataset"]
+__all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"]
 
 logger = Log(__name__).getlog()
 
@@ -116,3 +116,27 @@ class TripletManifestDataset(ManifestDataset):
         instance = self._manifest[idx]
         return instance["utt"], instance["feat"], instance["text"], instance[
             "text1"]
+
+
+class TransformDataset(Dataset):
+    """Transform Dataset.
+
+    Args:
+        data: list object from make_batchset
+        transform: transform function
+    """
+
+    def __init__(self, data, transform):
+        """Init function."""
+        super().__init__()
+        self.data = data
+        self.transform = transform
+
+    def __len__(self):
+        """Len function."""
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        """[] operator."""
+        return self.transform(self.data[idx])
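
TransformDataset applies the callable lazily on item access; a minimal sketch with toy data:

    from deepspeech.io.dataset import TransformDataset

    ds = TransformDataset([1, 2, 3], lambda x: x * 10)
    assert len(ds) == 3
    assert ds[1] == 20  # transform runs on access, not at construction
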
diff --git a/deepspeech/io/utility.py b/deepspeech/io/utility.py
index 0cd37428b..915813f3a 100644
--- a/deepspeech/io/utility.py
+++ b/deepspeech/io/utility.py
@@ -11,17 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
 from typing import List
 
+import h5py
+import kaldiio
 import numpy as np
+import soundfile
 
+from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
 from deepspeech.utils.log import Log
 
-__all__ = ["pad_sequence"]
+__all__ = ["pad_list", "pad_sequence", "LoadInputsAndTargets"]
 
 logger = Log(__name__).getlog()
 
 
+def pad_list(sequences: List[np.ndarray],
+             padding_value: float=0.0) -> np.ndarray:
+    return pad_sequence(sequences, True, padding_value)
+
+
 def pad_sequence(sequences: List[np.ndarray],
                  batch_first: bool=True,
                  padding_value: float=0.0) -> np.ndarray:
@@ -80,3 +87,299 @@ def pad_sequence(sequences: List[np.ndarray],
         out_tensor[:length, i, ...] = tensor
 
     return out_tensor
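
pad_list is a batch-first convenience wrapper over the existing pad_sequence; a quick sketch of its behavior:

    import numpy as np
    from deepspeech.io.utility import pad_list

    out = pad_list([np.array([1, 2, 3]), np.array([4, 5])], padding_value=0.0)
    # out -> [[1, 2, 3],
    #         [4, 5, 0]]   shape (2, 3), batch-first
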
+
+
+class LoadInputsAndTargets():
+    """Create a mini-batch from a list of dicts
+
+    >>> batch = [('utt1',
+    ...           dict(input=[dict(feat='some.ark:123',
+    ...                            filetype='mat',
+    ...                            name='input1',
+    ...                            shape=[100, 80])],
+    ...                output=[dict(tokenid='1 2 3 4',
+    ...                             name='target1',
+    ...                             shape=[4, 31])]))]
+    >>> l = LoadInputsAndTargets()
+    >>> feat, target = l(batch)
+
+    :param: str mode: Specify the task mode, currently only "asr"
+    :param: str preprocess_conf: The path of a json file for pre-processing
+    :param: bool load_input: If False, not to load the input data
+    :param: bool load_output: If False, not to load the output data
+    :param: bool sort_in_input_length: Sort the mini-batch in descending order
+        of the input length
+    :param: dict preprocess_args: Set some optional arguments for preprocessing
+    :param: bool keep_all_data_on_mem: Cache loaded arrays in memory
+    """
+
+    def __init__(
+            self,
+            mode="asr",
+            preprocess_conf=None,
+            load_input=True,
+            load_output=True,
+            sort_in_input_length=True,
+            preprocess_args=None,
+            keep_all_data_on_mem=False, ):
+        self._loaders = {}
+
+        if mode not in ["asr"]:
+            raise ValueError("Only asr is allowed: mode={}".format(mode))
+
+        if preprocess_conf is not None:
+            self.preprocessing = AugmentationPipeline(preprocess_conf)
+            logger.warning(
+                "[Experimental feature] Some preprocessing will be done "
+                "for the mini-batch creation using {}".format(
+                    self.preprocessing))
+        else:
+            # If no preprocessing conf is given, the data is left untouched.
+            self.preprocessing = None
+
+        self.mode = mode
+        self.load_output = load_output
+        self.load_input = load_input
+        self.sort_in_input_length = sort_in_input_length
+        if preprocess_args is None:
+            self.preprocess_args = {}
+        else:
+            assert isinstance(preprocess_args, dict), type(preprocess_args)
+            self.preprocess_args = dict(preprocess_args)
+
+        self.keep_all_data_on_mem = keep_all_data_on_mem
+
+    def __call__(self, batch, return_uttid=False):
+        """Function to load inputs and targets from a list of dicts
+
+        :param List[Tuple[str, dict]] batch: list of dict which is subset of
+            loaded data.json
+        :param bool return_uttid: return utterance ID information for
+            visualization
+        :return: list of input feature sequences
+            [(T_1, D), (T_2, D), ..., (T_B, D)]
+        :rtype: list of float ndarray
+        :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]
+        :rtype: list of int ndarray
+        """
+        x_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]
+        y_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]
+        uttid_list = []  # List[str]
+
+        for uttid, info in batch:
+            uttid_list.append(uttid)
+
+            if self.load_input:
+                # Note(kamo): This for-loop is for multiple inputs
+                for idx, inp in enumerate(info["input"]):
+                    # {"input":
+                    #  [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
+                    #    "filetype": "hdf5",
+                    #    "name": "input1", ...}], ...}
+                    x = self._get_from_loader(
+                        filepath=inp["feat"],
+                        filetype=inp.get("filetype", "mat"))
+                    x_feats_dict.setdefault(inp["name"], []).append(x)
+
+            if self.load_output:
+                for idx, inp in enumerate(info["output"]):
+                    if "tokenid" in inp:
+                        # ======= Legacy format for output =======
+                        # {"output": [{"tokenid": "1 2 3 4"}])
+                        x = np.fromiter(
+                            map(int, inp["tokenid"].split()), dtype=np.int64)
+                    else:
+                        # ======= New format =======
+                        # {"output":
+                        #  [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
+                        #    "filetype": "hdf5",
+                        #    "name": "target1", ...}], ...}
+                        x = self._get_from_loader(
+                            filepath=inp["feat"],
+                            filetype=inp.get("filetype", "mat"))
+
+                    y_feats_dict.setdefault(inp["name"], []).append(x)
+
+        if self.mode == "asr":
+            return_batch, uttid_list = self._create_batch_asr(
+                x_feats_dict, y_feats_dict, uttid_list)
+        else:
+            raise NotImplementedError(self.mode)
+
+        if self.preprocessing is not None:
+            # Apply pre-processing to all input features
+            for x_name in return_batch.keys():
+                if x_name.startswith("input"):
+                    return_batch[x_name] = self.preprocessing(
+                        return_batch[x_name], uttid_list,
+                        **self.preprocess_args)
+
+        if return_uttid:
+            return tuple(return_batch.values()), uttid_list
+
+        # Doesn't return the names now.
+        return tuple(return_batch.values())
+
+    def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):
+        """Create an OrderedDict for the mini-batch
+
+        :param OrderedDict x_feats_dict:
+            e.g. {"input1": [ndarray, ndarray, ...],
+                  "input2": [ndarray, ndarray, ...]}
+        :param OrderedDict y_feats_dict:
+            e.g. {"target1": [ndarray, ndarray, ...],
+                  "target2": [ndarray, ndarray, ...]}
+        :param: List[str] uttid_list:
+            Give uttid_list to sort in the same order as the mini-batch
+        :return: batch, uttid_list
+        :rtype: Tuple[OrderedDict, List[str]]
+        """
+        # handle single-input and multi-input (parallel) asr mode
+        xs = list(x_feats_dict.values())
+
+        if self.load_output:
+            ys = list(y_feats_dict.values())
+            assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))
+
+            # get index of non-zero length samples
+            nonzero_idx = list(
+                filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))
+            for n in range(1, len(y_feats_dict)):
+                # bind n via a default argument and keep nonzero_idx a list,
+                # so the filters don't all close over the final value of n
+                nonzero_idx = list(
+                    filter(lambda i, n=n: len(ys[n][i]) > 0, nonzero_idx))
+        else:
+            # Note(kamo): Be careful not to make nonzero_idx a generator
+            nonzero_idx = list(range(len(xs[0])))
+
+        if self.sort_in_input_length:
+            # sort in input lengths based on the first input
+            nonzero_sorted_idx = sorted(
+                nonzero_idx, key=lambda i: -len(xs[0][i]))
+        else:
+            nonzero_sorted_idx = nonzero_idx
+
+        if len(nonzero_sorted_idx) != len(xs[0]):
+            logger.warning(
+                "Target sequences include empty tokenid (batch {} -> {}).".
+                format(len(xs[0]), len(nonzero_sorted_idx)))
+
+        # remove zero-length samples
+        xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]
+        uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]
+
+        x_names = list(x_feats_dict.keys())
+        if self.load_output:
+            ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]
+            y_names = list(y_feats_dict.keys())
+
+            # Keeping x_name and y_name, e.g. input1, for future extension
+            return_batch = OrderedDict([
+                * [(x_name, x) for x_name, x in zip(x_names, xs)],
+                * [(y_name, y) for y_name, y in zip(y_names, ys)],
+            ])
+        else:
+            return_batch = OrderedDict(
+                [(x_name, x) for x_name, x in zip(x_names, xs)])
+        return return_batch, uttid_list
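
A runnable sketch of a single-utterance call, using the whole-file "npy" branch so no Kaldi artifacts are needed (paths and shapes are made up):

    import numpy as np
    from deepspeech.io.utility import LoadInputsAndTargets

    np.save("/tmp/feat1.npy", np.random.randn(100, 80).astype(np.float32))

    batch = [("utt1",
              {"input": [{"feat": "/tmp/feat1.npy",
                          "filetype": "npy",
                          "name": "input1"}],
               "output": [{"tokenid": "1 2 3 4",
                           "name": "target1"}]})]

    loader = LoadInputsAndTargets(mode="asr")
    (feats, tokens), uttids = loader(batch, return_uttid=True)
    # feats[0].shape == (100, 80); tokens[0] -> array([1, 2, 3, 4]); uttids == ["utt1"]
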
+
+    def _get_from_loader(self, filepath, filetype):
+        """Return ndarray
+
+        To open each file descriptor only on first access, the created
+        loaders are cached in self._loaders.
+
+        >>> ndarray = loader._get_from_loader(
+        ...     'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')
+
+        :param: str filepath:
+        :param: str filetype:
+        :return:
+        :rtype: np.ndarray
+        """
+        if filetype == "hdf5":
+            # e.g.
+            # {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
+            #             "filetype": "hdf5",
+            # -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
+            filepath, key = filepath.split(":", 1)
+
+            loader = self._loaders.get(filepath)
+            if loader is None:
+                # To avoid disk access, create the loader only the first time
+                loader = h5py.File(filepath, "r")
+                self._loaders[filepath] = loader
+            return loader[key][()]
+        elif filetype == "sound.hdf5":
+            # e.g.
+            # {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
+            #             "filetype": "sound.hdf5",
+            # -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
+            filepath, key = filepath.split(":", 1)
+
+            loader = self._loaders.get(filepath)
+            if loader is None:
+                # NOTE: assumes a SoundHDF5File helper is importable here
+                # (in ESPnet it lives in espnet.utils.io_utils)
+                loader = SoundHDF5File(filepath, "r", dtype="int16")
+                self._loaders[filepath] = loader
+            array, rate = loader[key]
+            return array
+        elif filetype == "sound":
+            # e.g.
+            # {"input": [{"feat": "some/path.wav",
+            #             "filetype": "sound"},
+            # Assume PCM16
+            if not self.keep_all_data_on_mem:
+                array, _ = soundfile.read(filepath, dtype="int16")
+                return array
+            if filepath not in self._loaders:
+                array, _ = soundfile.read(filepath, dtype="int16")
+                self._loaders[filepath] = array
+            return self._loaders[filepath]
+        elif filetype == "npz":
+            # e.g.
+            # {"input": [{"feat": "some/path.npz:F01_050C0101_PED_REAL",
+            #             "filetype": "npz",
+            filepath, key = filepath.split(":", 1)
+
+            loader = self._loaders.get(filepath)
+            if loader is None:
+                # To avoid disk access, create the loader only the first time
+                loader = np.load(filepath)
+                self._loaders[filepath] = loader
+            return loader[key]
+        elif filetype == "npy":
+            # e.g.
+            # {"input": [{"feat": "some/path.npy",
+            #             "filetype": "npy"},
+            if not self.keep_all_data_on_mem:
+                return np.load(filepath)
+            if filepath not in self._loaders:
+                self._loaders[filepath] = np.load(filepath)
+            return self._loaders[filepath]
+        elif filetype in ["mat", "vec"]:
+            # e.g.
+            # {"input": [{"feat": "some/path.ark:123",
+            #             "filetype": "mat"}]},
+            # In this case, "123" is the byte offset where the matrix starts
+            # in the ark file; load_mat can load both matrices and vectors
+            if not self.keep_all_data_on_mem:
+                return kaldiio.load_mat(filepath)
+            if filepath not in self._loaders:
+                self._loaders[filepath] = kaldiio.load_mat(filepath)
+            return self._loaders[filepath]
+        elif filetype == "scp":
+            # e.g.
+            # {"input": [{"feat": "some/path.scp:F01_050C0101_PED_REAL",
+            #             "filetype": "scp",
+            filepath, key = filepath.split(":", 1)
+            loader = self._loaders.get(filepath)
+            if loader is None:
+                # To avoid disk access, create the loader only the first time
+                loader = kaldiio.load_scp(filepath)
+                self._loaders[filepath] = loader
+            return loader[key]
+        else:
+            raise NotImplementedError(
+                "Not supported: loader_type={}".format(filetype))
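
For reference, a sketch of the "feat" path conventions each filetype expects (file names and keys are placeholders):

    # filetype -> example "feat" value accepted by _get_from_loader
    examples = {
        "mat":   "feats.ark:123",    # Kaldi ark, byte offset after ":"
        "scp":   "feats.scp:utt1",   # Kaldi scp, utterance key after ":"
        "hdf5":  "feats.h5:utt1",    # HDF5 dataset key after ":"
        "npz":   "feats.npz:utt1",   # array name inside the npz after ":"
        "npy":   "feats/utt1.npy",   # whole-file formats take no key
        "sound": "wavs/utt1.wav",    # read as PCM16 via soundfile
    }
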