From 1d8cc4a5a9bfd9eff50a9a971411333e9050ff83 Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Tue, 20 Jun 2017 17:06:53 +0800 Subject: [PATCH] Add multi-threading support for DS2 data generator. --- data_utils/data.py | 14 +++++++++++--- data_utils/speech.py | 2 +- infer.py | 8 +++++++- train.py | 22 +++++++++++++++++++++- 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/data_utils/data.py b/data_utils/data.py index 424343a4..8391dacc 100644 --- a/data_utils/data.py +++ b/data_utils/data.py @@ -44,6 +44,8 @@ class DataGenerator(object): :types max_freq: None|float :param specgram_type: Specgram feature type. Options: 'linear'. :type specgram_type: str + :param num_threads: Number of CPU threads for processing data. + :type num_threads: int :param random_seed: Random seed. :type random_seed: int """ @@ -58,6 +60,7 @@ class DataGenerator(object): window_ms=20.0, max_freq=None, specgram_type='linear', + num_threads=12, random_seed=0): self._max_duration = max_duration self._min_duration = min_duration @@ -70,6 +73,7 @@ class DataGenerator(object): stride_ms=stride_ms, window_ms=window_ms, max_freq=max_freq) + self._num_threads = num_threads self._rng = random.Random(random_seed) self._epoch = 0 @@ -207,10 +211,14 @@ class DataGenerator(object): def reader(): for instance in manifest: - yield self._process_utterance(instance["audio_filepath"], - instance["text"]) + yield instance - return reader + def mapper(instance): + return self._process_utterance(instance["audio_filepath"], + instance["text"]) + + return paddle.reader.xmap_readers( + mapper, reader, self._num_threads, 1024, order=True) def _padding_batch(self, batch, padding_to=-1, flatten=False): """ diff --git a/data_utils/speech.py b/data_utils/speech.py index fc031ff4..568e4443 100644 --- a/data_utils/speech.py +++ b/data_utils/speech.py @@ -94,7 +94,7 @@ class SpeechSegment(AudioSegment): return cls(samples, sample_rate, transcripts) @classmethod - def slice_from_file(cls, filepath, start=None, end=None, transcript): + def slice_from_file(cls, filepath, transcript, start=None, end=None): """Loads a small section of an speech without having to load the entire file into the memory which can be incredibly wasteful. diff --git a/infer.py b/infer.py index 06449ab0..7fc84829 100644 --- a/infer.py +++ b/infer.py @@ -38,6 +38,11 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=12, + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -67,7 +72,8 @@ def infer(): data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, - augmentation_config='{}') + augmentation_config='{}', + num_threads=args.num_threads_data) # create network config # paddle.data_type.dense_array is used for variable batch input. diff --git a/train.py b/train.py index c60a039b..2c3b8ce7 100644 --- a/train.py +++ b/train.py @@ -52,6 +52,18 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use sortagrad or not. (default: %(default)s)") +parser.add_argument( + "--max_duration", + default=100.0, + type=float, + help="Audios with duration larger than this will be discarded. " + "(default: %(default)s)") +parser.add_argument( + "--min_duration", + default=0.0, + type=float, + help="Audios with duration smaller than this will be discarded. " + "(default: %(default)s)") parser.add_argument( "--shuffle_method", default='instance_shuffle', @@ -63,6 +75,11 @@ parser.add_argument( default=4, type=int, help="Trainer number. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=12, + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -107,7 +124,10 @@ def train(): return DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, - augmentation_config=args.augmentation_config) + augmentation_config=args.augmentation_config, + max_duration=args.max_duration, + min_duration=args.min_duration, + num_threads=args.num_threads_data) train_generator = data_generator() test_generator = data_generator()