diff --git a/data_utils/dataset.py b/data_utils/dataset.py
index 2fc42cc46..667b6fbe5 100644
--- a/data_utils/dataset.py
+++ b/data_utils/dataset.py
@@ -18,6 +18,9 @@ import numpy as np
 import paddle
 from paddle.io import Dataset
 from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+from collections import namedtuple
+from functools import partial
 
 from data_utils.utility import read_manifest
 from data_utils.augmentor.augmentation import AugmentationPipeline
@@ -27,20 +30,20 @@ from data_utils.normalizer import FeatureNormalizer
 
 
 class DeepSpeech2Dataset(Dataset):
-    def __init__(self,
-            manifest_path,
-            vocab_filepath,
-            mean_std_filepath,
-            augmentation_config='{}',
-            max_duration=float('inf'),
-            min_duration=0.0,
-            stride_ms=10.0,
-            window_ms=20.0,
-            max_freq=None,
-            specgram_type='linear',
-            use_dB_normalization=True,
-            random_seed=0,
-            keep_transcription_text=False):
+    def __init__(self,
+                 manifest_path,
+                 vocab_filepath,
+                 mean_std_filepath,
+                 augmentation_config='{}',
+                 max_duration=float('inf'),
+                 min_duration=0.0,
+                 stride_ms=10.0,
+                 window_ms=20.0,
+                 max_freq=None,
+                 specgram_type='linear',
+                 use_dB_normalization=True,
+                 random_seed=0,
+                 keep_transcription_text=False):
         super().__init__()
 
         self._max_duration = max_duration
@@ -58,7 +61,7 @@ class DeepSpeech2Dataset(Dataset):
         self._rng = random.Random(random_seed)
         self._keep_transcription_text = keep_transcription_text
         # for caching tar files info
-        self._local_data = local()
+        self._local_data = namedtuple('local_data', ['tar2info', 'tar2object'])
         self._local_data.tar2info = {}
         self._local_data.tar2object = {}
 
@@ -160,25 +163,25 @@ class DeepSpeech2Dataset(Dataset):
 
     def __len__(self):
         return len(self._manifest)
-    
+
     def __getitem__(self, idx):
         instance = self._manifest[idx]
-        return self.process_utterance(instance["audio_filepath"], instance["text"])
+        return self.process_utterance(instance["audio_filepath"],
+                                      instance["text"])
 
 
 class DeepSpeech2BatchSampler(DistributedBatchSampler):
     def __init__(self,
-            dataset,
-            batch_size,
-            num_replicas=None,
-            rank=None,
-            shuffle=False,
-            drop_last=False,
-            sortagrad=False,
-            shuffle_method="batch_shuffle"):
-        super().__init__(
-            dataset, batch_size, num_replicas, rank, shuffle, drop_last
-        )
+                 dataset,
+                 batch_size,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=False,
+                 drop_last=False,
+                 sortagrad=False,
+                 shuffle_method="batch_shuffle"):
+        super().__init__(dataset, batch_size, num_replicas, rank, shuffle,
+                         drop_last)
         self._sortagrad = sortagrad
         self._shuffle_method = shuffle_method
 
@@ -203,7 +206,7 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
         :return: Batch shuffled mainifest.
         :rtype: list
         """
-        rng = np.random.RandomState(self.epoch).
+        rng = np.random.RandomState(self.epoch)
         manifest.sort(key=lambda x: x["duration"])
         shift_len = rng.randint(0, batch_size - 1)
         batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size))
@@ -250,8 +253,8 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
                 subsampled_indices.extend(indices[i:i + self.batch_size])
 
             indices = indices[len(indices) - last_batch_size:]
-            subsampled_indices.extend(indices[
-                self.local_rank * last_local_batch_size:(
+            subsampled_indices.extend(
+                indices[self.local_rank * last_local_batch_size:(
                     self.local_rank + 1) * last_local_batch_size])
             return subsampled_indices
 
@@ -276,39 +279,39 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
         return num_samples // self.batch_size
 
 
-def create_dataloader(
-        manifest_path,
-        vocab_filepath,
-        mean_std_filepath,
-        augmentation_config='{}',
-        max_duration=float('inf'),
-        min_duration=0.0,
-        stride_ms=10.0,
-        window_ms=20.0,
-        max_freq=None,
-        specgram_type='linear',
-        use_dB_normalization=True,
-        random_seed=0,
-        keep_transcription_text=False,
-        is_training=False,
-        batch_size=args.num_samples,
-        sortagrad=False,
-        shuffle_method=None):
+def create_dataloader(manifest_path,
+                      vocab_filepath,
+                      mean_std_filepath,
+                      augmentation_config='{}',
+                      max_duration=float('inf'),
+                      min_duration=0.0,
+                      stride_ms=10.0,
+                      window_ms=20.0,
+                      max_freq=None,
+                      specgram_type='linear',
+                      use_dB_normalization=True,
+                      random_seed=0,
+                      keep_transcription_text=False,
+                      is_training=False,
+                      batch_size=1,
+                      num_workers=0,
+                      sortagrad=False,
+                      shuffle_method=None):
     dataset = DeepSpeech2Dataset(
-            manifest_path,
-            vocab_filepath,
-            mean_std_filepath,
-            augmentation_config=augmentation_config,
-            max_duration=max_duration,
-            min_duration=min_duration,
-            stride_ms=stride_ms,
-            window_ms=window_ms,
-            max_freq=max_freq,
-            specgram_type=specgram_type,
-            use_dB_normalization=use_dB_normalization,
-            random_seed=random_seed,
-            keep_transcription_text=keep_transcription_text)
+        manifest_path,
+        vocab_filepath,
+        mean_std_filepath,
+        augmentation_config=augmentation_config,
+        max_duration=max_duration,
+        min_duration=min_duration,
+        stride_ms=stride_ms,
+        window_ms=window_ms,
+        max_freq=max_freq,
+        specgram_type=specgram_type,
+        use_dB_normalization=use_dB_normalization,
+        random_seed=random_seed,
+        keep_transcription_text=keep_transcription_text)
 
     batch_sampler = DeepSpeech2BatchSampler(
         dataset,
@@ -320,7 +323,7 @@ def create_dataloader(
         sortagrad=is_training,
         shuffle_method=shuffle_method)
 
-    def padding_batch(self, batch, padding_to=-1, flatten=False):
+    def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
         """
         Padding audio features with zeros to make them have the same shape (or
         a user-defined shape) within one bach.
@@ -351,24 +354,20 @@ def create_dataloader(
                 padded_audio = padded_audio.flatten()
             padded_audios.append(padded_audio)
             audio_lens.append(audio.shape[1])
-            if self._is_training:
-                padded_text = np.zeros([max_text_length])
-                padded_text[:len(text)] = text
-                texts.append(padded_text)
-            else:
-                texts.append(text)
+            padded_text = np.zeros([max_text_length])
+            padded_text[:len(text)] = text
+            texts.append(padded_text)
             text_lens.append(len(text))
 
         padded_audios = np.array(padded_audios).astype('float32')
         audio_lens = np.array(audio_lens).astype('int64')
-        if self._is_training:
-            texts = np.array(texts).astype('int32')
-            text_lens = np.array(text_lens).astype('int64')
+        texts = np.array(texts).astype('int32')
+        text_lens = np.array(text_lens).astype('int64')
         return padded_audios, texts, audio_lens, text_lens
 
-    loader = DataLoader(dataset,
+    loader = DataLoader(
+        dataset,
         batch_sampler=batch_sampler,
-        collate_fn=padding_batch,
-        num_workers=2,
-        )
-    return loader
\ No newline at end of file
+        collate_fn=partial(padding_batch, is_training=is_training),
+        num_workers=num_workers, )
+    return loader
diff --git a/infer.py b/infer.py
index 07da84451..3c9171566 100644
--- a/infer.py
+++ b/infer.py
@@ -77,9 +77,9 @@ def infer():
     """Inference for DeepSpeech2."""
 
     # check if set use_gpu=True in paddlepaddle cpu version
-    check_cuda(args.use_gpu)
+    #check_cuda(args.use_gpu)
     # check if paddlepaddle version is satisfied
-    check_version()
+    #check_version()
 
     # data_generator = DataGenerator(
     #     vocab_filepath=args.vocab_path,
@@ -114,7 +114,14 @@ def infer():
         sortagrad=False,
         shuffle_method=None)
 
-    infer_data = next(batch_reader())
+    for audio, text, audio_len, text_len in batch_reader:
+        print(audio.shape)
+        print(text.shape)
+        print(audio_len)
+        print(text_len)
+        break
+
+    infer_data = batch_reader()
     print(infer_data)
 
     # ds2_model = DeepSpeech2Model(
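
A minimal usage sketch of the reworked loader (not part of the diff), assuming
the module layout above. num_workers is new in this diff, batch_size's default
changes from args.num_samples to 1, and is_training is now routed into
padding_batch through functools.partial. The manifest, vocab, and mean/std
paths below are placeholders, not files from the repository.

# Hypothetical driver script; paths are placeholder values.
from data_utils.dataset import create_dataloader

loader = create_dataloader(
    manifest_path='data/manifest.dev',      # placeholder path
    vocab_filepath='data/vocab.txt',        # placeholder path
    mean_std_filepath='data/mean_std.npz',  # placeholder path
    batch_size=4,
    num_workers=0,           # parameter introduced by this diff
    is_training=False,       # forwarded to padding_batch via partial
    shuffle_method=None)

# Each batch is the 4-tuple returned by padding_batch, mirroring the
# debug loop added to infer.py above.
for padded_audios, texts, audio_lens, text_lens in loader:
    print(padded_audios.shape, texts.shape, audio_lens, text_lens)
    break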