@@ -18,6 +18,9 @@ import numpy as np
 import paddle
 from paddle.io import Dataset
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
 from collections import namedtuple
+from functools import partial
 
 from data_utils.utility import read_manifest
 from data_utils.augmentor.augmentation import AugmentationPipeline
@@ -28,19 +31,19 @@ from data_utils.normalizer import FeatureNormalizer
 
 
 class DeepSpeech2Dataset(Dataset):
     def __init__(self,
-            manifest_path,
-            vocab_filepath,
-            mean_std_filepath,
-            augmentation_config='{}',
-            max_duration=float('inf'),
-            min_duration=0.0,
-            stride_ms=10.0,
-            window_ms=20.0,
-            max_freq=None,
-            specgram_type='linear',
-            use_dB_normalization=True,
-            random_seed=0,
-            keep_transcription_text=False):
+                 manifest_path,
+                 vocab_filepath,
+                 mean_std_filepath,
+                 augmentation_config='{}',
+                 max_duration=float('inf'),
+                 min_duration=0.0,
+                 stride_ms=10.0,
+                 window_ms=20.0,
+                 max_freq=None,
+                 specgram_type='linear',
+                 use_dB_normalization=True,
+                 random_seed=0,
+                 keep_transcription_text=False):
         super().__init__()
         self._max_duration = max_duration
@@ -58,7 +61,7 @@ class DeepSpeech2Dataset(Dataset):
         self._rng = random.Random(random_seed)
         self._keep_transcription_text = keep_transcription_text
         # for caching tar files info
-        self._local_data = local()
+        self._local_data = namedtuple('local_data', ['tar2info', 'tar2object'])
         self._local_data.tar2info = {}
         self._local_data.tar2object = {}
 
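For reviewers: the `tar2info` / `tar2object` pair is a cache, so utterances packed in tar archives are indexed and opened once instead of on every read. A minimal standalone sketch of that caching pattern (the module-level dicts and the `read_member` helper are illustrative, not code from this PR):

```python
import io
import tarfile

# Hypothetical stand-in for the PR's tar2info/tar2object idea: open each
# archive once, index its members once, then serve reads from the cache.
_tar2object = {}  # archive path -> open TarFile handle
_tar2info = {}    # archive path -> {member name: TarInfo}

def read_member(tar_path, name):
    if tar_path not in _tar2object:
        f = tarfile.open(tar_path, 'r')
        _tar2object[tar_path] = f
        _tar2info[tar_path] = {m.name: m for m in f.getmembers()}
    f = _tar2object[tar_path]
    return f.extractfile(_tar2info[tar_path][name]).read()

# Demo: pack one member, then read it back through the cache twice.
with tarfile.open('demo.tar', 'w') as f:
    payload = b'fake audio bytes'
    info = tarfile.TarInfo('utt1.wav')
    info.size = len(payload)
    f.addfile(info, io.BytesIO(payload))

assert read_member('demo.tar', 'utt1.wav') == payload  # opens and indexes
assert read_member('demo.tar', 'utt1.wav') == payload  # served from cache
```

The PR keeps these dicts on a per-instance holder rather than at module level, so each DataLoader worker process builds its own set of open handles.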
@@ -163,22 +166,22 @@ class DeepSpeech2Dataset(Dataset):
 
     def __getitem__(self, idx):
         instance = self._manifest[idx]
-        return self.process_utterance(instance["audio_filepath"], instance["text"])
+        return self.process_utterance(instance["audio_filepath"],
+                                      instance["text"])
 
 
 class DeepSpeech2BatchSampler(DistributedBatchSampler):
     def __init__(self,
-            dataset,
-            batch_size,
-            num_replicas=None,
-            rank=None,
-            shuffle=False,
-            drop_last=False,
-            sortagrad=False,
-            shuffle_method="batch_shuffle"):
-        super().__init__(
-            dataset, batch_size, num_replicas, rank, shuffle, drop_last
-        )
+                 dataset,
+                 batch_size,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=False,
+                 drop_last=False,
+                 sortagrad=False,
+                 shuffle_method="batch_shuffle"):
+        super().__init__(dataset, batch_size, num_replicas, rank, shuffle,
+                         drop_last)
         self._sortagrad = sortagrad
         self._shuffle_method = shuffle_method
 
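`sortagrad` refers to the SortaGrad curriculum from the Deep Speech 2 paper: the first epoch iterates utterances sorted by duration (shortest first), and later epochs shuffle as usual. The gating logic itself sits outside this hunk; a rough standalone sketch of the intended ordering (the `order_epoch` helper is hypothetical):

```python
import random

def order_epoch(manifest, batch_size, epoch, sortagrad=True):
    """Epoch 0 with sortagrad on: shortest utterances first.
    Later epochs: plain random order."""
    if sortagrad and epoch == 0:
        items = sorted(manifest, key=lambda x: x["duration"])
    else:
        items = random.sample(manifest, len(manifest))
    # chop into fixed-size batches, dropping the ragged tail
    return [items[i:i + batch_size]
            for i in range(0, len(items) - batch_size + 1, batch_size)]

manifest = [{"utt": i, "duration": d}
            for i, d in enumerate([9.1, 2.3, 7.7, 1.2, 5.0, 3.4])]
print([u["utt"] for b in order_epoch(manifest, 2, epoch=0) for u in b])
# [3, 1, 5, 4, 2, 0] -- sorted by duration
print([u["utt"] for b in order_epoch(manifest, 2, epoch=1) for u in b])
# a random permutation
```

Starting on short utterances keeps early CTC gradients well conditioned, which is why the caller below passes `sortagrad=is_training`.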
@@ -203,7 +206,7 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
         :return: Batch shuffled manifest.
         :rtype: list
         """
-        rng = np.random.RandomState(self.epoch).
+        rng = np.random.RandomState(self.epoch)
         manifest.sort(key=lambda x: x["duration"])
         shift_len = rng.randint(0, batch_size - 1)
         batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size))
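The last line above uses the standard `zip(*[iter(seq)] * n)` idiom to chop the duration-sorted, randomly shifted manifest into fixed-size batches; per the docstring, it is this list of batches that then gets shuffled. A quick illustration:

```python
items = list(range(10))

# One iterator object repeated 3 times: zip draws 3 consecutive
# items per tuple and silently drops the ragged tail.
print(list(zip(*[iter(items)] * 3)))       # [(0, 1, 2), (3, 4, 5), (6, 7, 8)]

# Shifting the start by 2, as the random shift_len does, moves
# every batch boundary for the next epoch.
print(list(zip(*[iter(items[2:])] * 3)))   # [(2, 3, 4), (5, 6, 7)]
```

Because the shift is seeded with `self.epoch`, each epoch cuts batch boundaries at a different offset while every batch still holds utterances of similar duration.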
@@ -250,8 +253,8 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
             subsampled_indices.extend(indices[i:i + self.batch_size])
 
         indices = indices[len(indices) - last_batch_size:]
-        subsampled_indices.extend(indices[
-            self.local_rank * last_local_batch_size:(
+        subsampled_indices.extend(
+            indices[self.local_rank * last_local_batch_size:(
                 self.local_rank + 1) * last_local_batch_size])
         return subsampled_indices
 
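The reformatted call above deals the leftover indices of the final short batch out evenly: each rank takes one `last_local_batch_size`-wide slice, so every replica keeps the same sample count. A toy check of the slicing arithmetic (values made up, names mirror the diff):

```python
indices = list(range(100, 110))   # 10 leftover indices in the last batch
num_replicas = 2
last_batch_size = len(indices)
last_local_batch_size = last_batch_size // num_replicas  # 5 per rank

for local_rank in range(num_replicas):
    share = indices[local_rank * last_local_batch_size:
                    (local_rank + 1) * last_local_batch_size]
    print(local_rank, share)
# 0 [100, 101, 102, 103, 104]
# 1 [105, 106, 107, 108, 109]
```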
@@ -276,39 +279,39 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
         return num_samples // self.batch_size
 
 
-def create_dataloader(
-        manifest_path,
-        vocab_filepath,
-        mean_std_filepath,
-        augmentation_config='{}',
-        max_duration=float('inf'),
-        min_duration=0.0,
-        stride_ms=10.0,
-        window_ms=20.0,
-        max_freq=None,
-        specgram_type='linear',
-        use_dB_normalization=True,
-        random_seed=0,
-        keep_transcription_text=False,
-        is_training=False,
-        batch_size=args.num_samples,
-        sortagrad=False,
-        shuffle_method=None):
+def create_dataloader(manifest_path,
+                      vocab_filepath,
+                      mean_std_filepath,
+                      augmentation_config='{}',
+                      max_duration=float('inf'),
+                      min_duration=0.0,
+                      stride_ms=10.0,
+                      window_ms=20.0,
+                      max_freq=None,
+                      specgram_type='linear',
+                      use_dB_normalization=True,
+                      random_seed=0,
+                      keep_transcription_text=False,
+                      is_training=False,
+                      batch_size=1,
+                      num_workers=0,
+                      sortagrad=False,
+                      shuffle_method=None):
 
     dataset = DeepSpeech2Dataset(
-                                 manifest_path,
-                                 vocab_filepath,
-                                 mean_std_filepath,
-                                 augmentation_config=augmentation_config,
-                                 max_duration=max_duration,
-                                 min_duration=min_duration,
-                                 stride_ms=stride_ms,
-                                 window_ms=window_ms,
-                                 max_freq=max_freq,
-                                 specgram_type=specgram_type,
-                                 use_dB_normalization=use_dB_normalization,
-                                 random_seed=random_seed,
-                                 keep_transcription_text=keep_transcription_text)
+        manifest_path,
+        vocab_filepath,
+        mean_std_filepath,
+        augmentation_config=augmentation_config,
+        max_duration=max_duration,
+        min_duration=min_duration,
+        stride_ms=stride_ms,
+        window_ms=window_ms,
+        max_freq=max_freq,
+        specgram_type=specgram_type,
+        use_dB_normalization=use_dB_normalization,
+        random_seed=random_seed,
+        keep_transcription_text=keep_transcription_text)
 
     batch_sampler = DeepSpeech2BatchSampler(
         dataset,
@@ -320,7 +323,7 @@ def create_dataloader(
         sortagrad=is_training,
         shuffle_method=shuffle_method)
 
-    def padding_batch(self, batch, padding_to=-1, flatten=False):
+    def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
         """
         Padding audio features with zeros to make them have the same shape (or
         a user-defined shape) within one batch.
@@ -351,24 +354,20 @@ def create_dataloader(
                 padded_audio = padded_audio.flatten()
             padded_audios.append(padded_audio)
             audio_lens.append(audio.shape[1])
-            if self._is_training:
-                padded_text = np.zeros([max_text_length])
-                padded_text[:len(text)] = text
-                texts.append(padded_text)
-            else:
-                texts.append(text)
+            padded_text = np.zeros([max_text_length])
+            padded_text[:len(text)] = text
+            texts.append(padded_text)
             text_lens.append(len(text))
 
         padded_audios = np.array(padded_audios).astype('float32')
         audio_lens = np.array(audio_lens).astype('int64')
-        if self._is_training:
-            texts = np.array(texts).astype('int32')
-            text_lens = np.array(text_lens).astype('int64')
+        texts = np.array(texts).astype('int32')
+        text_lens = np.array(text_lens).astype('int64')
         return padded_audios, texts, audio_lens, text_lens
 
-    loader = DataLoader(dataset,
+    loader = DataLoader(
+        dataset,
         batch_sampler=batch_sampler,
-        collate_fn=padding_batch,
-        num_workers=2,
-        )
+        collate_fn=partial(padding_batch, is_training=is_training),
+        num_workers=num_workers, )
     return loader
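Net effect of the last two changes: `padding_batch` becomes a plain function that zero-pads audio and token sequences to the batch maximum, and `functools.partial` bakes `is_training` into it so it still fits `DataLoader`'s single-argument `collate_fn` slot. A self-contained sketch of both pieces on dummy spectrogram-shaped data (names and shapes are illustrative, not this PR's exact code):

```python
from functools import partial
import numpy as np

def pad_batch(batch, is_training=True):
    # batch: list of (spectrogram [freq, time], token_ids) pairs.
    # is_training is accepted only to mirror the PR's partial() binding.
    max_time = max(audio.shape[1] for audio, _ in batch)
    max_text = max(len(text) for _, text in batch)
    audios, audio_lens, texts, text_lens = [], [], [], []
    for audio, text in batch:
        padded = np.zeros((audio.shape[0], max_time), dtype='float32')
        padded[:, :audio.shape[1]] = audio      # zero-pad the time axis
        audios.append(padded)
        audio_lens.append(audio.shape[1])
        padded_text = np.zeros(max_text, dtype='int32')
        padded_text[:len(text)] = text          # zero-pad the token ids
        texts.append(padded_text)
        text_lens.append(len(text))
    return (np.stack(audios), np.array(texts),
            np.array(audio_lens, 'int64'), np.array(text_lens, 'int64'))

batch = [(np.ones((161, 50)), [5, 9]), (np.ones((161, 80)), [7, 2, 4])]
audios, texts, audio_lens, text_lens = partial(pad_batch, is_training=True)(batch)
print(audios.shape, texts.shape, audio_lens, text_lens)
# (2, 161, 80) (2, 3) [50 80] [2 3]
```

One practical reason to prefer `partial` over a lambda here is that, with a top-level collate function, the resulting callable stays picklable for multi-process loading once `num_workers > 0`.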