fix a deep speech 2 speed bug

pull/2/head
lispczz 7 years ago
parent e909396f91
commit 0dcf13f0fc

@ -327,7 +327,7 @@ class DataGenerator(object):
shift_len = self._rng.randint(0, batch_size - 1)
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
self._rng.shuffle(batch_manifest)
batch_manifest = list(sum(batch_manifest, ()))
batch_manifest = [item for batch in batch_manifest for item in batch]
if not clipped:
res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[-res_len:])

Loading…
Cancel
Save