diff --git a/deepspeech/io/batchfy.py b/deepspeech/io/batchfy.py index 36c1ec31..54c6f0e1 100644 --- a/deepspeech/io/batchfy.py +++ b/deepspeech/io/batchfy.py @@ -421,7 +421,7 @@ def make_batchset( key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]), reverse=not shortest_first, ) logger.info("# utts: " + str(len(sorted_data))) - + if count == "seq": batches = batchfy_by_seq( sorted_data, @@ -466,4 +466,4 @@ def make_batchset( logger.info("# minibatches: " + str(len(batches))) # batch: List[List[Tuple[str, dict]]] - return batches \ No newline at end of file + return batches diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 2ef11966..4900350e 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -23,7 +23,7 @@ from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer from deepspeech.frontend.normalizer import FeatureNormalizer from deepspeech.frontend.speech import SpeechSegment from deepspeech.frontend.utility import IGNORE_ID -from deepspeech.io.utility import pad_sequence +from deepspeech.io.utility import pad_list from deepspeech.utils.log import Log __all__ = ["SpeechCollator"] @@ -286,13 +286,12 @@ class SpeechCollator(): texts.append(tokens) text_lens.append(tokens.shape[0]) - padded_audios = pad_sequence( - audios, padding_value=0.0).astype(np.float32) #[B, T, D] - audio_lens = np.array(audio_lens).astype(np.int64) - padded_texts = pad_sequence( - texts, padding_value=IGNORE_ID).astype(np.int64) - text_lens = np.array(text_lens).astype(np.int64) - return utts, padded_audios, audio_lens, padded_texts, text_lens + #[B, T, D] + xs_pad = pad_list(audios, 0.0).astype(np.float32) + ilens = np.array(audio_lens).astype(np.int64) + ys_pad = pad_list(texts, IGNORE_ID).astype(np.int64) + olens = np.array(text_lens).astype(np.int64) + return utts, xs_pad, ilens, ys_pad, olens @property def manifest(self): diff --git a/deepspeech/io/dataloader.py b/deepspeech/io/dataloader.py index 
0c5034ca..2e6b6a02 100644 --- a/deepspeech/io/dataloader.py +++ b/deepspeech/io/dataloader.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np from paddle.io import DataLoader from deepspeech.frontend.utility import read_manifest @@ -30,11 +31,11 @@ class CustomConverter(): Args: subsampling_factor (int): The subsampling factor. - dtype (paddle.dtype): Data type to convert. - + dtype (np.dtype): Data type to convert. + """ - def __init__(self, subsampling_factor=1, dtype=paddle.float32): + def __init__(self, subsampling_factor=1, dtype=np.float32): """Construct a CustomConverter object.""" self.subsampling_factor = subsampling_factor self.ignore_id = -1 @@ -52,7 +53,7 @@ class CustomConverter(): """ # batch should be located in list assert len(batch) == 1 - xs, ys = batch[0] + (xs, ys), utts = batch[0] # perform subsampling if self.subsampling_factor > 1: @@ -74,15 +75,14 @@ class CustomConverter(): else: xs_pad = pad_list(xs, 0).astype(self.dtype) - ilens = paddle.to_tensor(ilens) - # NOTE: this is for multi-output (e.g., speech translation) ys_pad = pad_list( [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys], self.ignore_id) - olens = np.array([y.shape[0] for y in ys]) - return xs_pad, ilens, ys_pad, olens + olens = np.array( + [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys]) + return utts, xs_pad, ilens, ys_pad, olens class BatchDataLoader(): @@ -166,7 +166,7 @@ class BatchDataLoader(): # we used an empty collate function instead which returns list self.train_loader = DataLoader( dataset=TransformDataset( - self.data, lambda data: self.converter([self.load(data)])), + self.data, lambda data: self.converter([self.load(data, return_uttid=True)])), batch_size=1, shuffle=not use_sortagrad if train_mode else False, collate_fn=lambda x: x[0], diff --git 
class SoundHDF5File:
    """Store encoded sound files as entries of a single HDF5 file.

    Each entry is written by encoding ``(array, rate)`` with ``soundfile``
    into an in-memory buffer and saving the resulting bytes as one HDF5
    dataset. Reading an entry decodes those bytes back with ``soundfile``.

    >>> f = SoundHDF5File('a.flac.h5', mode='a')
    >>> array = np.random.randint(0, 100, 100, dtype=np.int16)
    >>> f['id'] = (array, 16000)
    >>> array, rate = f['id']

    Args:
        filepath (str): Path of the HDF5 file to open.
        mode (str): Mode passed through to ``h5py.File``.
        format (str): Audio container used when *saving* (flac, nist,
            htk, etc.). When ``None`` it is derived from the second
            extension of ``filepath`` (``a.flac.h5`` -> ``flac``) and
            falls back to ``"flac"`` if soundfile does not list it.
        dtype (str): dtype passed to ``soundfile.read`` when loading.
        **kwargs: Extra keyword arguments forwarded to ``h5py.File``.
    """

    def __init__(self,
                 filepath,
                 mode="r+",
                 format=None,
                 dtype="int16",
                 **kwargs):
        # Imported here so the class works even when the enclosing module
        # does not import these names at the top of the file.
        import os

        import h5py

        self.filepath = filepath
        self.mode = mode
        self.dtype = dtype

        self.file = h5py.File(filepath, mode, **kwargs)
        if format is None:
            # filepath = a.flac.h5 -> format = flac
            stem = os.path.splitext(filepath)[0]
            second_ext = os.path.splitext(stem)[1]
            format = second_ext[1:]
            if format.upper() not in soundfile.available_formats():
                # If not found, flac is selected
                format = "flac"

        # This format affects only saving
        self.format = format

    def __repr__(self):
        """Return a short description of the file, mode, format and dtype."""
        return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
            self.filepath, self.mode, self.format, self.dtype)

    def create_dataset(self, name, shape=None, data=None, **kwds):
        """Encode ``data`` and store it under ``name``.

        Args:
            name (str): Name of the dataset inside the HDF5 file.
            shape: Forwarded unchanged to ``h5py``'s ``create_dataset``.
            data (tuple): ``(array, rate)`` pair handed to
                ``soundfile.write``.
            **kwds: Extra keyword arguments forwarded to ``h5py``'s
                ``create_dataset``.
        """
        import io

        buffer = io.BytesIO()
        array, rate = data
        soundfile.write(buffer, array, rate, format=self.format)
        # np.void wraps the encoded bytes so they are stored as one opaque
        # value rather than interpreted as a string.
        self.file.create_dataset(
            name, shape=shape, data=np.void(buffer.getvalue()), **kwds)

    def __setitem__(self, name, data):
        """Store ``data`` (an ``(array, rate)`` pair) under ``name``."""
        self.create_dataset(name, data=data)

    def __getitem__(self, key):
        """Decode the entry ``key`` and return ``(array, rate)``."""
        import io

        # ``[()]`` reads the whole dataset as a single value.
        raw = self.file[key][()]
        buffer = io.BytesIO(raw.tobytes())
        array, rate = soundfile.read(buffer, dtype=self.dtype)
        return array, rate

    def keys(self):
        """Return the keys of the underlying HDF5 file."""
        return self.file.keys()

    def values(self):
        """Yield the decoded ``(array, rate)`` pair of every entry."""
        for key in self.file:
            yield self[key]

    def items(self):
        """Yield ``(key, (array, rate))`` for every entry."""
        for key in self.file:
            yield key, self[key]

    def __iter__(self):
        return iter(self.file)

    def __contains__(self, item):
        return item in self.file

    def __len__(self):
        # The original signature took an extra ``item`` parameter, which
        # made ``len(obj)`` raise TypeError.
        return len(self.file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying HDF5 file."""
        self.file.close()