test dataloader

pull/521/head
Hui Zhang 5 years ago
parent a01dc81474
commit 006504c4e7

@@ -18,6 +18,9 @@ import numpy as np
import paddle
from paddle.io import Dataset
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from collections import namedtuple
from functools import partial
from data_utils.utility import read_manifest
from data_utils.augmentor.augmentation import AugmentationPipeline
@@ -58,7 +61,7 @@ class DeepSpeech2Dataset(Dataset):
self._rng = random.Random(random_seed)
self._keep_transcription_text = keep_transcription_text
# for caching tar files info
self._local_data = local()
self._local_data = namedtuple('local_data', ['tar2info', 'tar2object'])
self._local_data.tar2info = {}
self._local_data.tar2object = {}
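
A minimal standalone sketch of the caching container this hunk swaps in: the namedtuple class is used as a plain attribute namespace for the two tar-file caches. The SimpleNamespace variant is an assumed equivalent shown only for comparison; it is not part of the commit.

from collections import namedtuple
from types import SimpleNamespace

# Pattern from the hunk above: the namedtuple *class* acts as an attribute
# container, and the two cache dicts are attached to it afterwards.
local_data = namedtuple('local_data', ['tar2info', 'tar2object'])
local_data.tar2info = {}
local_data.tar2object = {}

# Assumed equivalent container (illustration only, not from the commit).
local_data_ns = SimpleNamespace(tar2info={}, tar2object={})

print(local_data.tar2info, local_data_ns.tar2object)  # {} {}
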
@@ -163,7 +166,8 @@ class DeepSpeech2Dataset(Dataset):
def __getitem__(self, idx):
instance = self._manifest[idx]
return self.process_utterance(instance["audio_filepath"], instance["text"])
return self.process_utterance(instance["audio_filepath"],
instance["text"])
class DeepSpeech2BatchSampler(DistributedBatchSampler):
@@ -176,9 +180,8 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
drop_last=False,
sortagrad=False,
shuffle_method="batch_shuffle"):
super().__init__(
dataset, batch_size, num_replicas, rank, shuffle, drop_last
)
super().__init__(dataset, batch_size, num_replicas, rank, shuffle,
drop_last)
self._sortagrad = sortagrad
self._shuffle_method = shuffle_method
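
For context, a minimal sketch of how a sampler like the one in this hunk forwards its arguments to paddle.io.DistributedBatchSampler; the toy dataset and class names below are assumptions for illustration, not part of the commit.

from paddle.io import Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    # Assumed stand-in for DeepSpeech2Dataset; returns only the index.
    def __getitem__(self, idx):
        return idx

    def __len__(self):
        return 10

class ToyBatchSampler(DistributedBatchSampler):
    # Mirrors the __init__ wiring shown in the hunk above.
    def __init__(self, dataset, batch_size, num_replicas=None, rank=None,
                 shuffle=False, drop_last=False, sortagrad=False,
                 shuffle_method="batch_shuffle"):
        super().__init__(dataset, batch_size, num_replicas, rank, shuffle,
                         drop_last)
        self._sortagrad = sortagrad
        self._shuffle_method = shuffle_method

sampler = ToyBatchSampler(ToyDataset(), batch_size=4)
for batch_indices in sampler:
    print(batch_indices)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]
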
@@ -203,7 +206,7 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
:return: Batch shuffled manifest.
:rtype: list
"""
rng = np.random.RandomState(self.epoch).
rng = np.random.RandomState(self.epoch)
manifest.sort(key=lambda x: x["duration"])
shift_len = rng.randint(0, batch_size - 1)
batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size))
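
A self-contained sketch of the batch-shuffle idea this hunk touches (the removed line had a stray trailing dot, a syntax error). Only the first lines of the method are visible in the diff; the batch-level shuffle and the flattening below are assumptions based on the docstring.

import numpy as np

def batch_shuffle(manifest, batch_size, epoch):
    # Seed the RNG with the epoch so every trainer shuffles identically.
    rng = np.random.RandomState(epoch)
    # Sort utterances by duration, then shift the batch boundary randomly
    # so the batches differ between epochs.
    manifest = sorted(manifest, key=lambda x: x["duration"])
    shift_len = rng.randint(0, batch_size - 1)
    # Group into fixed-size batches and shuffle whole batches (assumed step).
    batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
    rng.shuffle(batch_manifest)
    return [item for batch in batch_manifest for item in batch]

demo = [{"duration": d} for d in (3.0, 1.0, 2.0, 5.0, 4.0, 6.0, 7.0, 8.0)]
print(batch_shuffle(demo, batch_size=2, epoch=0))
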
@@ -250,8 +253,8 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
subsampled_indices.extend(indices[i:i + self.batch_size])
indices = indices[len(indices) - last_batch_size:]
subsampled_indices.extend(indices[
self.local_rank * last_local_batch_size:(
subsampled_indices.extend(
indices[self.local_rank * last_local_batch_size:(
self.local_rank + 1) * last_local_batch_size])
return subsampled_indices
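
A small standalone illustration of the tail slicing reformatted in this hunk: the final short batch is split evenly across trainers, each local_rank taking one contiguous slice. The concrete numbers are made up for the example.

# Leftover indices of the last, short batch (made-up numbers).
indices = list(range(10))
num_replicas = 2
last_local_batch_size = len(indices) // num_replicas

for local_rank in range(num_replicas):
    shard = indices[local_rank * last_local_batch_size:
                    (local_rank + 1) * last_local_batch_size]
    print(local_rank, shard)  # rank 0 -> [0..4], rank 1 -> [5..9]
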
@@ -276,8 +279,7 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
return num_samples // self.batch_size
def create_dataloader(
manifest_path,
def create_dataloader(manifest_path,
vocab_filepath,
mean_std_filepath,
augmentation_config='{}',
@@ -291,7 +293,8 @@ def create_dataloader(
random_seed=0,
keep_transcription_text=False,
is_training=False,
batch_size=args.num_samples,
batch_size=1,
num_workers=0,
sortagrad=False,
shuffle_method=None):
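
Taking the new defaults at face value, a call to this function might look like the sketch below; it is hypothetical, uses only the keyword arguments visible in this diff, leaves any parameters not shown here at their defaults, and the file paths are placeholders.

loader = create_dataloader(
    manifest_path='data/manifest.dev',      # placeholder path
    vocab_filepath='data/vocab.txt',        # placeholder path
    mean_std_filepath='data/mean_std.npz',  # placeholder path
    augmentation_config='{}',
    keep_transcription_text=False,
    is_training=False,
    batch_size=1,
    num_workers=0,
    sortagrad=False,
    shuffle_method=None)
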
@@ -320,7 +323,7 @@ def create_dataloader(
sortagrad=is_training,
shuffle_method=shuffle_method)
def padding_batch(self, batch, padding_to=-1, flatten=False):
def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
@@ -351,24 +354,20 @@ def create_dataloader(
padded_audio = padded_audio.flatten()
padded_audios.append(padded_audio)
audio_lens.append(audio.shape[1])
if self._is_training:
padded_text = np.zeros([max_text_length])
padded_text[:len(text)] = text
texts.append(padded_text)
else:
texts.append(text)
text_lens.append(len(text))
padded_audios = np.array(padded_audios).astype('float32')
audio_lens = np.array(audio_lens).astype('int64')
if self._is_training:
texts = np.array(texts).astype('int32')
text_lens = np.array(text_lens).astype('int64')
return padded_audios, texts, audio_lens, text_lens
loader = DataLoader(dataset,
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=padding_batch,
num_workers=2,
)
collate_fn=partial(padding_batch, is_training=is_training),
num_workers=num_workers, )
return loader
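
For context, a minimal runnable sketch of the collate wiring this hunk moves to: a module-level padding function bound with functools.partial and handed to paddle.io.DataLoader together with a batch sampler. The toy dataset, the feature shapes, and the use of paddle.io.BatchSampler are assumptions for illustration, not the project's real pipeline.

import numpy as np
from functools import partial
from paddle.io import Dataset, DataLoader, BatchSampler

class ToyAudioDataset(Dataset):
    # Assumed stand-in: (freq, time) float features and int32 token ids.
    def __getitem__(self, idx):
        audio = np.random.rand(161, 50 + idx).astype('float32')
        text = np.arange(idx + 1, dtype='int32')
        return audio, text

    def __len__(self):
        return 8

def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
    # Zero-pad audio features (and texts) to the longest item in the batch.
    # `padding_to` and `is_training` are kept only to mirror the signature
    # in the diff; this sketch ignores them.
    max_audio_len = max(audio.shape[1] for audio, _ in batch)
    max_text_len = max(len(text) for _, text in batch)
    padded_audios, audio_lens, texts, text_lens = [], [], [], []
    for audio, text in batch:
        padded_audio = np.zeros([audio.shape[0], max_audio_len], 'float32')
        padded_audio[:, :audio.shape[1]] = audio
        padded_audios.append(padded_audio.flatten() if flatten else padded_audio)
        audio_lens.append(audio.shape[1])
        padded_text = np.zeros([max_text_len], 'int32')
        padded_text[:len(text)] = text
        texts.append(padded_text)
        text_lens.append(len(text))
    return (np.array(padded_audios, 'float32'), np.array(texts, 'int32'),
            np.array(audio_lens, 'int64'), np.array(text_lens, 'int64'))

dataset = ToyAudioDataset()
batch_sampler = BatchSampler(dataset, batch_size=4, drop_last=False)
loader = DataLoader(dataset,
                    batch_sampler=batch_sampler,
                    collate_fn=partial(padding_batch, is_training=False),
                    num_workers=0)
for audios, texts, audio_lens, text_lens in loader:
    print(audios.shape, texts.shape, audio_lens, text_lens)
    break
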

@@ -77,9 +77,9 @@ def infer():
"""Inference for DeepSpeech2."""
# check if set use_gpu=True in paddlepaddle cpu version
check_cuda(args.use_gpu)
#check_cuda(args.use_gpu)
# check if paddlepaddle version is satisfied
check_version()
#check_version()
# data_generator = DataGenerator(
# vocab_filepath=args.vocab_path,
@@ -114,7 +114,14 @@ def infer():
sortagrad=False,
shuffle_method=None)
infer_data = next(batch_reader())
for audio, text, audio_len, text_len in batch_reader:
print(audio.shape)
print(text.shape)
print(audio_len)
print(text_len)
break
infer_data = batch_reader()
print(infer_data)
# ds2_model = DeepSpeech2Model(
