create dataloader

pull/521/head
zhanghui41 5 years ago
parent 920c14c0d8
commit a01dc81474

@@ -187,6 +187,9 @@ class DataGenerator():
            manifest_path=manifest_path,
            max_duration=self._max_duration,
            min_duration=self._min_duration)
        # sort (by duration) or batch-wise shuffle the manifest
        if self._epoch == 0 and sortagrad:
            manifest.sort(key=lambda x: x["duration"])

@@ -12,11 +12,363 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import tarfile
from threading import local

import numpy as np
import paddle
from paddle.io import Dataset
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler

from data_utils.utility import read_manifest
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.speech import SpeechSegment
from data_utils.normalizer import FeatureNormalizer
class DeepSpeech2Dataset(Dataset):
    def __init__(self,
                 manifest_path,
                 vocab_filepath,
                 mean_std_filepath,
                 augmentation_config='{}',
                 max_duration=float('inf'),
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 specgram_type='linear',
                 use_dB_normalization=True,
                 random_seed=0,
                 keep_transcription_text=False):
        super().__init__()
        self._max_duration = max_duration
        self._min_duration = min_duration
        self._normalizer = FeatureNormalizer(mean_std_filepath)
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=augmentation_config, random_seed=random_seed)
        self._speech_featurizer = SpeechFeaturizer(
            vocab_filepath=vocab_filepath,
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            use_dB_normalization=use_dB_normalization)
        self._rng = random.Random(random_seed)
        self._keep_transcription_text = keep_transcription_text
        # thread-local cache for tar file handles and member info
        self._local_data = local()
        self._local_data.tar2info = {}
        self._local_data.tar2object = {}
        # read the manifest and sort it by duration
        self._manifest = read_manifest(
            manifest_path=manifest_path,
            max_duration=self._max_duration,
            min_duration=self._min_duration)
        self._manifest.sort(key=lambda x: x["duration"])
    @property
    def manifest(self):
        return self._manifest

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._speech_featurizer.vocab_list
    def _parse_tar(self, file):
        """Parse a tar file to get a tarfile object
        and a map from member names to tarinfo objects.
        """
        result = {}
        f = tarfile.open(file)
        for tarinfo in f.getmembers():
            result[tarinfo.name] = tarinfo
        return f, result

    def _subfile_from_tar(self, file):
        """Get a subfile object from a tar archive.

        Returns a file object for the requested member and caches the
        tar file info to speed up subsequent reads from the same archive.
        """
        tarpath, filename = file.split(':', 1)[1].split('#', 1)
        if 'tar2info' not in self._local_data.__dict__:
            self._local_data.tar2info = {}
        if 'tar2object' not in self._local_data.__dict__:
            self._local_data.tar2object = {}
        if tarpath not in self._local_data.tar2info:
            tar_object, infos = self._parse_tar(tarpath)
            self._local_data.tar2info[tarpath] = infos
            self._local_data.tar2object[tarpath] = tar_object
        return self._local_data.tar2object[tarpath].extractfile(
            self._local_data.tar2info[tarpath][filename])
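    # A minimal sketch of the "tar:" URI that _subfile_from_tar expects;
    # the archive path and member name below are hypothetical examples:
    #
    #   uri = 'tar:/data/train_audio.tar#utt-0001.wav'
    #   tarpath, filename = uri.split(':', 1)[1].split('#', 1)
    #   # tarpath  -> '/data/train_audio.tar'
    #   # filename -> 'utt-0001.wav'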
    def process_utterance(self, audio_file, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param transcript: Transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of transcription part,
                 where transcription part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), transcript)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, transcript)
        self._augmentation_pipeline.transform_audio(speech_segment)
        specgram, transcript_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        specgram = self._normalizer.apply(specgram)
        return specgram, transcript_part
    def _instance_reader_creator(self, manifest):
        """
        Instance reader creator. Create a callable function to produce
        instances of data.

        Instance: a tuple of ndarray of audio spectrogram and a list of
        token indices for transcript.
        """

        def reader():
            for instance in manifest:
                inst = self.process_utterance(instance["audio_filepath"],
                                              instance["text"])
                yield inst

        return reader

    def __len__(self):
        return len(self._manifest)

    def __getitem__(self, idx):
        instance = self._manifest[idx]
        return self.process_utterance(instance["audio_filepath"],
                                      instance["text"])
class DeepSpeech2BatchSampler(DistributedBatchSampler):
    def __init__(self,
                 dataset,
                 batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=False,
                 drop_last=False,
                 sortagrad=False,
                 shuffle_method="batch_shuffle"):
        super().__init__(dataset, batch_size, num_replicas, rank, shuffle,
                         drop_last)
        self._sortagrad = sortagrad
        self._shuffle_method = shuffle_method
    def _batch_shuffle(self, indices, batch_size, clipped=False):
        """Put similarly-sized instances into minibatches for better efficiency
        and make a batch-wise shuffle.

        1. Sort the audio clips by duration (the manifest is already sorted
           by duration, so sorting the indices preserves that order).
        2. Generate a random number `k`, k in [0, batch_size).
        3. Randomly shift `k` instances in order to create different batches
           for different epochs. Create minibatches.
        4. Shuffle the minibatches.

        :param indices: Indices into the duration-sorted manifest.
        :type indices: list
        :param batch_size: Batch size. This size is also used to generate
                           a random number for batch shuffle.
        :type batch_size: int
        :param clipped: Whether to clip the heading (small shift) and trailing
                        (incomplete batch) instances.
        :type clipped: bool
        :return: Batch-shuffled indices.
        :rtype: list
        """
        rng = np.random.RandomState(self.epoch)
        indices = sorted(indices)
        shift_len = rng.randint(0, batch_size)
        batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
        rng.shuffle(batch_indices)
        batch_indices = [item for batch in batch_indices for item in batch]
        if not clipped:
            res_len = len(indices) - shift_len - len(batch_indices)
            batch_indices.extend(indices[len(indices) - res_len:])
            batch_indices.extend(indices[0:shift_len])
        return batch_indices
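    # A worked example of the batch shuffle (hypothetical numbers): with
    # indices [0..9], batch_size=3 and a drawn shift k=1, indices[1:] is
    # grouped into (1,2,3), (4,5,6), (7,8,9); the batches are shuffled,
    # e.g. to (7,8,9), (1,2,3), (4,5,6); with clipped=False the leftover
    # tail and the head [0] are re-appended, so every index still appears
    # exactly once.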
    def __iter__(self):
        num_samples = len(self.dataset)
        indices = np.arange(num_samples).tolist()
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # sort (by duration) or batch-wise shuffle the manifest
        if self.shuffle:
            if self.epoch == 0 and self._sortagrad:
                pass  # first epoch: keep the duration-sorted order
            else:
                if self._shuffle_method == "batch_shuffle":
                    indices = self._batch_shuffle(
                        indices, self.batch_size, clipped=False)
                elif self._shuffle_method == "instance_shuffle":
                    np.random.RandomState(self.epoch).shuffle(indices)
                else:
                    raise ValueError("Unknown shuffle method %s." %
                                     self._shuffle_method)
        assert len(indices) == self.total_size
        self.epoch += 1

        # subsample indices for the current rank
        def _get_indices_by_batch_size(indices):
            subsampled_indices = []
            last_batch_size = self.total_size % (self.batch_size * self.nranks)
            assert last_batch_size % self.nranks == 0
            last_local_batch_size = last_batch_size // self.nranks
            for i in range(self.local_rank * self.batch_size,
                           len(indices) - last_batch_size,
                           self.batch_size * self.nranks):
                subsampled_indices.extend(indices[i:i + self.batch_size])
            indices = indices[len(indices) - last_batch_size:]
            subsampled_indices.extend(indices[
                self.local_rank * last_local_batch_size:(
                    self.local_rank + 1) * last_local_batch_size])
            return subsampled_indices

        if self.nranks > 1:
            indices = _get_indices_by_batch_size(indices)
        assert len(indices) == self.num_samples

        batch_indices = []
        for idx in indices:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices
    def __len__(self):
        num_samples = self.num_samples
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size
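    # Length arithmetic, for illustration: with num_samples=10, batch_size=4
    # and drop_last=False this gives (10 + 3) // 4 = 3 batches; with
    # drop_last=True it gives 10 // 4 = 2.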
def create_dataloader(manifest_path,
                      vocab_filepath,
                      mean_std_filepath,
                      augmentation_config='{}',
                      max_duration=float('inf'),
                      min_duration=0.0,
                      stride_ms=10.0,
                      window_ms=20.0,
                      max_freq=None,
                      specgram_type='linear',
                      use_dB_normalization=True,
                      random_seed=0,
                      keep_transcription_text=False,
                      is_training=False,
                      batch_size=1,  # callers should pass an explicit size
                      sortagrad=False,
                      shuffle_method=None):
    dataset = DeepSpeech2Dataset(
        manifest_path,
        vocab_filepath,
        mean_std_filepath,
        augmentation_config=augmentation_config,
        max_duration=max_duration,
        min_duration=min_duration,
        stride_ms=stride_ms,
        window_ms=window_ms,
        max_freq=max_freq,
        specgram_type=specgram_type,
        use_dB_normalization=use_dB_normalization,
        random_seed=random_seed,
        keep_transcription_text=keep_transcription_text)

    batch_sampler = DeepSpeech2BatchSampler(
        dataset,
        batch_size,
        num_replicas=None,
        rank=None,
        shuffle=is_training,
        drop_last=is_training,
        sortagrad=is_training,
        shuffle_method=shuffle_method)
    def padding_batch(batch, padding_to=-1, flatten=False):
        """
        Pad audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.

        If ``padding_to`` is -1, the maximum shape in the batch will be used
        as the target shape for padding. Otherwise, ``padding_to`` will be
        the target shape (only refers to the second axis).

        If ``flatten`` is True, features will be flattened to a 1-D array.
        """
        # get target shape
        max_length = max([audio.shape[1] for audio, text in batch])
        if padding_to != -1:
            if padding_to < max_length:
                raise ValueError("If padding_to is not -1, it should be larger "
                                 "than any instance's shape in the batch")
            max_length = padding_to
        max_text_length = max([len(text) for audio, text in batch])
        # padding
        padded_audios = []
        audio_lens = []
        texts, text_lens = [], []
        for audio, text in batch:
            padded_audio = np.zeros([audio.shape[0], max_length])
            padded_audio[:, :audio.shape[1]] = audio
            if flatten:
                padded_audio = padded_audio.flatten()
            padded_audios.append(padded_audio)
            audio_lens.append(audio.shape[1])
            if is_training:
                padded_text = np.zeros([max_text_length])
                padded_text[:len(text)] = text
                texts.append(padded_text)
            else:
                texts.append(text)
            text_lens.append(len(text))
        padded_audios = np.array(padded_audios).astype('float32')
        audio_lens = np.array(audio_lens).astype('int64')
        if is_training:
            texts = np.array(texts).astype('int32')
        text_lens = np.array(text_lens).astype('int64')
        return padded_audios, texts, audio_lens, text_lens
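    # Shape sketch (hypothetical numbers): for two utterances with feature
    # shapes (161, 300) and (161, 420), padding_batch returns padded_audios
    # of shape (2, 161, 420) and audio_lens [300, 420]; when is_training is
    # True, transcripts are zero-padded to the longest one in the batch.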
    loader = DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=padding_batch,
        num_workers=2)
    return loader
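A minimal usage sketch for create_dataloader, assuming hypothetical manifest,
vocab and mean-std paths; the loader yields padded, ready-to-feed batches:

    loader = create_dataloader(
        manifest_path='data/manifest.train',   # hypothetical paths
        vocab_filepath='data/vocab.txt',
        mean_std_filepath='data/mean_std.npz',
        specgram_type='linear',
        is_training=True,
        batch_size=4,
        shuffle_method='batch_shuffle')
    for padded_audios, texts, audio_lens, text_lens in loader:
        print(padded_audios.shape, audio_lens)
        break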

@@ -18,6 +18,7 @@ import argparse
import functools
import paddle.fluid as fluid
from data_utils.data import DataGenerator
from data_utils.dataset import create_dataloader
from model_utils.model import DeepSpeech2Model
from model_utils.model_check import check_cuda, check_version
from utils.error_rate import wer, cer
@@ -80,75 +81,91 @@ def infer():
    # check if paddlepaddle version is satisfied
    check_version()

    # data_generator = DataGenerator(
    #     vocab_filepath=args.vocab_path,
    #     mean_std_filepath=args.mean_std_path,
    #     augmentation_config='{}',
    #     specgram_type=args.specgram_type,
    #     keep_transcription_text=True,
    #     place=place,
    #     is_training=False)
    # batch_reader = data_generator.batch_reader_creator(
    #     manifest_path=args.infer_manifest,
    #     batch_size=args.num_samples,
    #     sortagrad=False,
    #     shuffle_method=None)

    batch_reader = create_dataloader(
        manifest_path=args.infer_manifest,
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        max_duration=float('inf'),
        min_duration=0.0,
        stride_ms=10.0,
        window_ms=20.0,
        max_freq=None,
        specgram_type=args.specgram_type,
        use_dB_normalization=True,
        random_seed=0,
        keep_transcription_text=False,
        is_training=False,
        batch_size=args.num_samples,
        sortagrad=False,
        shuffle_method=None)

    infer_data = next(iter(batch_reader))
    print(infer_data)

    # ds2_model = DeepSpeech2Model(
    #     vocab_size=data_generator.vocab_size,
    #     num_conv_layers=args.num_conv_layers,
    #     num_rnn_layers=args.num_rnn_layers,
    #     rnn_layer_size=args.rnn_layer_size,
    #     use_gru=args.use_gru,
    #     share_rnn_weights=args.share_rnn_weights,
    #     place=place,
    #     init_from_pretrained_model=args.model_path)

    # # decoders only accept string encoded in utf-8
    # vocab_list = [chars for chars in data_generator.vocab_list]

    # if args.decoding_method == "ctc_greedy":
    #     ds2_model.logger.info("start inference ...")
    #     probs_split = ds2_model.infer_batch_probs(
    #         infer_data=infer_data,
    #         feeding_dict=data_generator.feeding)

    #     result_transcripts = ds2_model.decode_batch_greedy(
    #         probs_split=probs_split,
    #         vocab_list=vocab_list)
    # else:
    #     ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
    #                               vocab_list)
    #     ds2_model.logger.info("start inference ...")
    #     probs_split = ds2_model.infer_batch_probs(
    #         infer_data=infer_data,
    #         feeding_dict=data_generator.feeding)

    #     result_transcripts = ds2_model.decode_batch_beam_search(
    #         probs_split=probs_split,
    #         beam_alpha=args.alpha,
    #         beam_beta=args.beta,
    #         beam_size=args.beam_size,
    #         cutoff_prob=args.cutoff_prob,
    #         cutoff_top_n=args.cutoff_top_n,
    #         vocab_list=vocab_list,
    #         num_processes=args.num_proc_bsearch)

    # error_rate_func = cer if args.error_rate_type == 'cer' else wer
    # target_transcripts = infer_data[1]
    # for target, result in zip(target_transcripts, result_transcripts):
    #     print("\nTarget Transcription: %s\nOutput Transcription: %s" %
    #           (target, result))
    #     print("Current error rate [%s] = %f" %
    #           (args.error_rate_type, error_rate_func(target, result)))

    # ds2_model.logger.info("finish inference")
def main():
    print_arguments(args)
