add more comments and update train.py

pull/2/head
dangqingqing 7 years ago
parent bf73540067
commit 9c27b1d14e

@ -247,25 +247,34 @@ class DataGenerator(object):
new_batch.append((padded_audio, text)) new_batch.append((padded_audio, text))
return new_batch return new_batch
def __batch_shuffle__(self, manifest, batch_shuffle_size): def __batch_shuffle__(self, manifest, batch_size):
""" """
The instances have different lengths and they cannot be
combined into a single matrix multiplication. It usually
sorts the training examples by length and combines only
similarly-sized instances into minibatches, pads with
silence when necessary so that all instances in a batch
have the same length. This batch shuffle function is used
to make similarly-sized instances into minibatches and
make a batch-wise shuffle.
1. Sort the audio clips by duration. 1. Sort the audio clips by duration.
2. Generate a random number `k`, k in [0, batch_shuffle_size). 2. Generate a random number `k`, k in [0, batch_size).
3. Randomly remove `k` instances in order to make different mini-batches, 3. Randomly remove `k` instances in order to make different mini-batches,
then make minibatches and each minibatch size is batch_shuffle_size. then make minibatches and each minibatch size is batch_size.
4. Shuffle the minibatches. 4. Shuffle the minibatches.
:param manifest: manifest file. :param manifest: manifest file.
:type manifest: list :type manifest: list
:param batch_shuffle_size: This size is uesed to generate a random number, :param batch_size: Batch size. This size is also used to generate
it usually equals to batch size. a random number for batch shuffle.
:type batch_shuffle_size: int :type batch_size: int
:return: batch shuffled manifest. :return: batch shuffled manifest.
:rtype: list :rtype: list
""" """
manifest.sort(key=lambda x: x["duration"]) manifest.sort(key=lambda x: x["duration"])
shift_len = self.__random__.randint(0, batch_shuffle_size - 1) shift_len = self.__random__.randint(0, batch_size - 1)
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_shuffle_size) batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
self.__random__.shuffle(batch_manifest) self.__random__.shuffle(batch_manifest)
batch_manifest = list(sum(batch_manifest, ())) batch_manifest = list(sum(batch_manifest, ()))
res_len = len(manifest) - shift_len - len(batch_manifest) res_len = len(manifest) - shift_len - len(batch_manifest)
@ -327,8 +336,9 @@ class DataGenerator(object):
if set True. if set True.
:type sortagrad: bool :type sortagrad: bool
:param batch_shuffle: Shuffle the audio clips if set True. It is :param batch_shuffle: Shuffle the audio clips if set True. It is
not a thorough instance-wise shuffle, not a thorough instance-wise shuffle, but a
but a specific batch-wise shuffle. specific batch-wise shuffle. For more details,
please see `__batch_shuffle__` function.
:type batch_shuffle: bool :type batch_shuffle: bool
:return: Batch reader function, producing batches of data when called. :return: Batch reader function, producing batches of data when called.
:rtype: callable :rtype: callable

@ -143,12 +143,12 @@ def train():
train_batch_reader = train_generator.batch_reader_creator( train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path, manifest_path=args.train_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
sortagrad=True, sortagrad=True if args.init_model_path is None else False,
shuffle=True) batch_shuffle=True)
test_batch_reader = test_generator.batch_reader_creator( test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path, manifest_path=args.dev_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=False) batch_shuffle=False)
feeding = train_generator.data_name_feeding() feeding = train_generator.data_name_feeding()
# create event handler # create event handler

Loading…
Cancel
Save