test dataloader

pull/521/head
Hui Zhang 5 years ago
parent a01dc81474
commit 006504c4e7

@ -18,6 +18,9 @@ import numpy as np
import paddle import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from collections import namedtuple
from functools import partial
from data_utils.utility import read_manifest from data_utils.utility import read_manifest
from data_utils.augmentor.augmentation import AugmentationPipeline from data_utils.augmentor.augmentation import AugmentationPipeline
@ -27,20 +30,20 @@ from data_utils.normalizer import FeatureNormalizer
class DeepSpeech2Dataset(Dataset): class DeepSpeech2Dataset(Dataset):
def __init__(self, def __init__(self,
manifest_path, manifest_path,
vocab_filepath, vocab_filepath,
mean_std_filepath, mean_std_filepath,
augmentation_config='{}', augmentation_config='{}',
max_duration=float('inf'), max_duration=float('inf'),
min_duration=0.0, min_duration=0.0,
stride_ms=10.0, stride_ms=10.0,
window_ms=20.0, window_ms=20.0,
max_freq=None, max_freq=None,
specgram_type='linear', specgram_type='linear',
use_dB_normalization=True, use_dB_normalization=True,
random_seed=0, random_seed=0,
keep_transcription_text=False): keep_transcription_text=False):
super().__init__() super().__init__()
self._max_duration = max_duration self._max_duration = max_duration
@ -58,7 +61,7 @@ class DeepSpeech2Dataset(Dataset):
self._rng = random.Random(random_seed) self._rng = random.Random(random_seed)
self._keep_transcription_text = keep_transcription_text self._keep_transcription_text = keep_transcription_text
# for caching tar files info # for caching tar files info
self._local_data = local() self._local_data = namedtuple('local_data', ['tar2info', 'tar2object'])
self._local_data.tar2info = {} self._local_data.tar2info = {}
self._local_data.tar2object = {} self._local_data.tar2object = {}
@ -160,25 +163,25 @@ class DeepSpeech2Dataset(Dataset):
def __len__(self): def __len__(self):
return len(self._manifest) return len(self._manifest)
def __getitem__(self, idx): def __getitem__(self, idx):
instance = self._manifest[idx] instance = self._manifest[idx]
return self.process_utterance(instance["audio_filepath"], instance["text"]) return self.process_utterance(instance["audio_filepath"],
instance["text"])
class DeepSpeech2BatchSampler(DistributedBatchSampler): class DeepSpeech2BatchSampler(DistributedBatchSampler):
def __init__(self, def __init__(self,
dataset, dataset,
batch_size, batch_size,
num_replicas=None, num_replicas=None,
rank=None, rank=None,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
sortagrad=False, sortagrad=False,
shuffle_method="batch_shuffle"): shuffle_method="batch_shuffle"):
super().__init__( super().__init__(dataset, batch_size, num_replicas, rank, shuffle,
dataset, batch_size, num_replicas, rank, shuffle, drop_last drop_last)
)
self._sortagrad = sortagrad self._sortagrad = sortagrad
self._shuffle_method = shuffle_method self._shuffle_method = shuffle_method
@ -203,7 +206,7 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
:return: Batch shuffled manifest. :return: Batch shuffled manifest.
:rtype: list :rtype: list
""" """
rng = np.random.RandomState(self.epoch). rng = np.random.RandomState(self.epoch)
manifest.sort(key=lambda x: x["duration"]) manifest.sort(key=lambda x: x["duration"])
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size)) batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size))
@ -250,8 +253,8 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
subsampled_indices.extend(indices[i:i + self.batch_size]) subsampled_indices.extend(indices[i:i + self.batch_size])
indices = indices[len(indices) - last_batch_size:] indices = indices[len(indices) - last_batch_size:]
subsampled_indices.extend(indices[ subsampled_indices.extend(
self.local_rank * last_local_batch_size:( indices[self.local_rank * last_local_batch_size:(
self.local_rank + 1) * last_local_batch_size]) self.local_rank + 1) * last_local_batch_size])
return subsampled_indices return subsampled_indices
@ -276,39 +279,39 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
return num_samples // self.batch_size return num_samples // self.batch_size
def create_dataloader( def create_dataloader(manifest_path,
manifest_path, vocab_filepath,
vocab_filepath, mean_std_filepath,
mean_std_filepath, augmentation_config='{}',
augmentation_config='{}', max_duration=float('inf'),
max_duration=float('inf'), min_duration=0.0,
min_duration=0.0, stride_ms=10.0,
stride_ms=10.0, window_ms=20.0,
window_ms=20.0, max_freq=None,
max_freq=None, specgram_type='linear',
specgram_type='linear', use_dB_normalization=True,
use_dB_normalization=True, random_seed=0,
random_seed=0, keep_transcription_text=False,
keep_transcription_text=False, is_training=False,
is_training=False, batch_size=1,
batch_size=args.num_samples, num_workers=0,
sortagrad=False, sortagrad=False,
shuffle_method=None): shuffle_method=None):
dataset = DeepSpeech2Dataset( dataset = DeepSpeech2Dataset(
manifest_path, manifest_path,
vocab_filepath, vocab_filepath,
mean_std_filepath, mean_std_filepath,
augmentation_config=augmentation_config, augmentation_config=augmentation_config,
max_duration=max_duration, max_duration=max_duration,
min_duration=min_duration, min_duration=min_duration,
stride_ms=stride_ms, stride_ms=stride_ms,
window_ms=window_ms, window_ms=window_ms,
max_freq=max_freq, max_freq=max_freq,
specgram_type=specgram_type, specgram_type=specgram_type,
use_dB_normalization=use_dB_normalization, use_dB_normalization=use_dB_normalization,
random_seed=random_seed, random_seed=random_seed,
keep_transcription_text=keep_transcription_text) keep_transcription_text=keep_transcription_text)
batch_sampler = DeepSpeech2BatchSampler( batch_sampler = DeepSpeech2BatchSampler(
dataset, dataset,
@ -320,7 +323,7 @@ def create_dataloader(
sortagrad=is_training, sortagrad=is_training,
shuffle_method=shuffle_method) shuffle_method=shuffle_method)
def padding_batch(self, batch, padding_to=-1, flatten=False): def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
""" """
Padding audio features with zeros to make them have the same shape (or Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch. a user-defined shape) within one batch.
@ -351,24 +354,20 @@ def create_dataloader(
padded_audio = padded_audio.flatten() padded_audio = padded_audio.flatten()
padded_audios.append(padded_audio) padded_audios.append(padded_audio)
audio_lens.append(audio.shape[1]) audio_lens.append(audio.shape[1])
if self._is_training: padded_text = np.zeros([max_text_length])
padded_text = np.zeros([max_text_length]) padded_text[:len(text)] = text
padded_text[:len(text)] = text texts.append(padded_text)
texts.append(padded_text)
else:
texts.append(text)
text_lens.append(len(text)) text_lens.append(len(text))
padded_audios = np.array(padded_audios).astype('float32') padded_audios = np.array(padded_audios).astype('float32')
audio_lens = np.array(audio_lens).astype('int64') audio_lens = np.array(audio_lens).astype('int64')
if self._is_training: texts = np.array(texts).astype('int32')
texts = np.array(texts).astype('int32') text_lens = np.array(text_lens).astype('int64')
text_lens = np.array(text_lens).astype('int64')
return padded_audios, texts, audio_lens, text_lens return padded_audios, texts, audio_lens, text_lens
loader = DataLoader(dataset, loader = DataLoader(
dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
collate_fn=padding_batch, collate_fn=partial(padding_batch, is_training=is_training),
num_workers=2, num_workers=num_workers, )
) return loader
return loader

@ -77,9 +77,9 @@ def infer():
"""Inference for DeepSpeech2.""" """Inference for DeepSpeech2."""
# check if set use_gpu=True in paddlepaddle cpu version # check if set use_gpu=True in paddlepaddle cpu version
check_cuda(args.use_gpu) #check_cuda(args.use_gpu)
# check if paddlepaddle version is satisfied # check if paddlepaddle version is satisfied
check_version() #check_version()
# data_generator = DataGenerator( # data_generator = DataGenerator(
# vocab_filepath=args.vocab_path, # vocab_filepath=args.vocab_path,
@ -114,7 +114,14 @@ def infer():
sortagrad=False, sortagrad=False,
shuffle_method=None) shuffle_method=None)
infer_data = next(batch_reader()) for audio, text, audio_len, text_len in batch_reader:
print(audio.shape)
print(text.shape)
print(audio_len)
print(text_len)
break
infer_data = batch_reader()
print(infer_data) print(infer_data)
# ds2_model = DeepSpeech2Model( # ds2_model = DeepSpeech2Model(

Loading…
Cancel
Save