one proc dataloader test pass

pull/521/head
Hui Zhang 5 years ago
parent 006504c4e7
commit 20e5bea192

@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import random
 import tarfile
 import numpy as np
 import paddle
 from paddle.io import Dataset
 from paddle.io import DataLoader
+from paddle.io import BatchSampler
 from paddle.io import DistributedBatchSampler
 from collections import namedtuple
 from functools import partial
@@ -170,7 +172,7 @@ class DeepSpeech2Dataset(Dataset):
                          instance["text"])
 
-class DeepSpeech2BatchSampler(DistributedBatchSampler):
+class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
     def __init__(self,
                  dataset,
                  batch_size,
@@ -279,6 +281,179 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
         return num_samples // self.batch_size
 
+
+class DeepSpeech2BatchSampler(BatchSampler):
+    def __init__(self,
+                 dataset,
+                 batch_size,
+                 shuffle=False,
+                 drop_last=False,
+                 sortagrad=False,
+                 shuffle_method="batch_shuffle",
+                 num_replicas=1,
+                 rank=0):
+        self.dataset = dataset
+
+        assert isinstance(batch_size, int) and batch_size > 0, \
+            "batch_size should be a positive integer"
+        self.batch_size = batch_size
+        assert isinstance(shuffle, bool), \
+            "shuffle should be a boolean value"
+        self.shuffle = shuffle
+        assert isinstance(drop_last, bool), \
+            "drop_last should be a boolean value"
+
+        if num_replicas is not None:
+            assert isinstance(num_replicas, int) and num_replicas > 0, \
+                "num_replicas should be a positive integer"
+            self.nranks = num_replicas
+        else:
+            self.nranks = 1  # fall back to single-process
+        if rank is not None:
+            assert isinstance(rank, int) and rank >= 0, \
+                "rank should be a non-negative integer"
+            self.local_rank = rank
+        else:
+            self.local_rank = 0  # fall back to single-process
+
+        self.drop_last = drop_last
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
+        self.total_size = self.num_samples * self.nranks
+        self._sortagrad = sortagrad
+        self._shuffle_method = shuffle_method
+
+    def _batch_shuffle(self, indices, batch_size, clipped=False):
+        """Put similarly-sized instances into minibatches for better efficiency
+        and make a batch-wise shuffle.
+
+        1. Sort the audio clips by duration (assumed already done by the
+           dataset, so consecutive indices point at similarly-sized clips).
+        2. Generate a random number `k`, k in [0, batch_size - 1).
+        3. Randomly shift `k` instances in order to create different batches
+           for different epochs. Create minibatches.
+        4. Shuffle the minibatches.
+
+        :param indices: Indices into the duration-sorted manifest.
+        :type indices: list
+        :param batch_size: Batch size. This size is also used to generate
+                           a random number for batch shuffle.
+        :type batch_size: int
+        :param clipped: Whether to clip the heading (small shift) and trailing
+                        (incomplete batch) instances.
+        :type clipped: bool
+        :return: Batch-shuffled indices.
+        :rtype: list
+        """
+        rng = np.random.RandomState(self.epoch)
+        shift_len = rng.randint(0, batch_size - 1)
+        # chunk the shifted list into batch_size-sized tuples,
+        # dropping the incomplete tail
+        batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
+        rng.shuffle(batch_indices)
+        batch_indices = [item for batch in batch_indices for item in batch]
+        if not clipped:
+            res_len = len(indices) - shift_len - len(batch_indices)
+            # guard: indices[-0:] would return the whole list
+            if res_len != 0:
+                batch_indices.extend(indices[-res_len:])
+            batch_indices.extend(indices[0:shift_len])
+        return batch_indices
+
+    def __iter__(self):
+        num_samples = len(self.dataset)
+        indices = np.arange(num_samples).tolist()
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # sort (by duration) or batch-wise shuffle the manifest
+        if self.shuffle:
+            if self.epoch == 0 and self._sortagrad:
+                pass  # first epoch: keep the duration-sorted order
+            else:
+                if self._shuffle_method == "batch_shuffle":
+                    indices = self._batch_shuffle(
+                        indices, self.batch_size, clipped=False)
+                elif self._shuffle_method == "instance_shuffle":
+                    np.random.RandomState(self.epoch).shuffle(indices)
+                else:
+                    raise ValueError("Unknown shuffle method %s." %
+                                     self._shuffle_method)
+        assert len(indices) == self.total_size
+        self.epoch += 1
+
+        # subsample
+        def _get_indices_by_batch_size(indices):
+            subsampled_indices = []
+            last_batch_size = self.total_size % (self.batch_size * self.nranks)
+            assert last_batch_size % self.nranks == 0
+            last_local_batch_size = last_batch_size // self.nranks
+
+            for i in range(self.local_rank * self.batch_size,
+                           len(indices) - last_batch_size,
+                           self.batch_size * self.nranks):
+                subsampled_indices.extend(indices[i:i + self.batch_size])
+
+            indices = indices[len(indices) - last_batch_size:]
+            subsampled_indices.extend(
+                indices[self.local_rank * last_local_batch_size:(
+                    self.local_rank + 1) * last_local_batch_size])
+            return subsampled_indices
+
+        if self.nranks > 1:
+            indices = _get_indices_by_batch_size(indices)
+
+        assert len(indices) == self.num_samples
+        _sample_iter = iter(indices)
+
+        batch_indices = []
+        for idx in _sample_iter:
+            batch_indices.append(idx)
+            if len(batch_indices) == self.batch_size:
+                yield batch_indices
+                batch_indices = []
+        if not self.drop_last and len(batch_indices) > 0:
+            yield batch_indices
+
+    def __len__(self):
+        num_samples = self.num_samples
+        num_samples += int(not self.drop_last) * (self.batch_size - 1)
+        return num_samples // self.batch_size
+
+    def set_epoch(self, epoch):
+        """
+        Sets the epoch number. When :attr:`shuffle=True`, this number is used
+        as the seed of the random ordering. By default, users do not need to
+        set this: each epoch, all replicas (workers) use a different random
+        ordering. If the same number is set at every epoch, this sampler will
+        yield the same ordering at all epochs.
+
+        Arguments:
+            epoch (int): Epoch number.
+
+        Examples:
+            .. code-block:: python
+
+                import numpy as np
+                from paddle.io import Dataset
+
+                # init with dataset
+                class RandomDataset(Dataset):
+                    def __init__(self, num_samples):
+                        self.num_samples = num_samples
+
+                    def __getitem__(self, idx):
+                        image = np.random.random([784]).astype('float32')
+                        label = np.random.randint(0, 9, (1, )).astype('int64')
+                        return image, label
+
+                    def __len__(self):
+                        return self.num_samples
+
+                dataset = RandomDataset(100)
+                sampler = DeepSpeech2BatchSampler(dataset, batch_size=64)
+
+                for epoch in range(10):
+                    sampler.set_epoch(epoch)
+        """
+        self.epoch = epoch
 
 def create_dataloader(manifest_path,
                       vocab_filepath,
                       mean_std_filepath,
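
The `_batch_shuffle` docstring above describes a shift-chunk-shuffle scheme. The sketch below is a minimal standalone illustration of that idea, with a toy index list and illustrative `batch_size` and `epoch` values; it is not the sampler itself. The core trick is `zip(*[iter(seq)] * n)`, which chunks a sequence into n-sized tuples and drops the incomplete tail:

import numpy as np

def batch_shuffle_sketch(indices, batch_size, epoch):
    # seed by epoch so every caller with the same epoch draws the same ordering
    rng = np.random.RandomState(epoch)
    # random head shift so batch boundaries differ across epochs
    shift_len = rng.randint(0, batch_size - 1)
    # chunk the shifted list into batch_size-sized tuples, drop incomplete tail
    batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
    rng.shuffle(batches)  # shuffle whole minibatches, not single instances
    flat = [i for batch in batches for i in batch]
    # re-attach the clipped tail and head so no index is lost
    res_len = len(indices) - shift_len - len(flat)
    if res_len != 0:
        flat.extend(indices[-res_len:])
    flat.extend(indices[:shift_len])
    return flat

print(batch_shuffle_sketch(list(range(10)), batch_size=3, epoch=0))

Seeding the RNG with the epoch means every replica that calls this with the same epoch draws the same ordering, which is what `set_epoch` relies on.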
@@ -296,7 +471,8 @@ def create_dataloader(manifest_path,
                       batch_size=1,
                       num_workers=0,
                       sortagrad=False,
-                      shuffle_method=None):
+                      shuffle_method=None,
+                      dist=False):
 
     dataset = DeepSpeech2Dataset(
         manifest_path,
@@ -313,15 +489,24 @@ def create_dataloader(manifest_path,
         random_seed=random_seed,
         keep_transcription_text=keep_transcription_text)
 
-    batch_sampler = DeepSpeech2BatchSampler(
-        dataset,
-        batch_size,
-        num_replicas=None,
-        rank=None,
-        shuffle=is_training,
-        drop_last=is_training,
-        sortagrad=is_training,
-        shuffle_method=shuffle_method)
+    if dist:
+        batch_sampler = DeepSpeech2DistributedBatchSampler(
+            dataset,
+            batch_size,
+            num_replicas=None,
+            rank=None,
+            shuffle=is_training,
+            drop_last=is_training,
+            sortagrad=is_training,
+            shuffle_method=shuffle_method)
+    else:
+        batch_sampler = DeepSpeech2BatchSampler(
+            dataset,
+            shuffle=is_training,
+            batch_size=batch_size,
+            drop_last=is_training,
+            sortagrad=is_training,
+            shuffle_method=shuffle_method)
 
     def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
         """

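With the new `dist` flag, one entry point serves both the single-process test path and distributed training. The call below is a hedged sketch, not taken from the repository: the paths are placeholders, and the real `create_dataloader` signature has more parameters (vocabulary, feature, and augmentation options) that are assumed here to have usable defaults:

# single-process test run: plain DeepSpeech2BatchSampler under the hood
loader = create_dataloader(
    manifest_path='data/manifest.dev',      # placeholder path
    vocab_filepath='data/vocab.txt',        # placeholder path
    mean_std_filepath='data/mean_std.npz',  # placeholder path
    batch_size=4,
    is_training=True,
    dist=False)

for audio, text, audio_len, text_len in loader:
    print(audio.shape, text.shape)  # padded batch tensors
    break

# distributed run: DeepSpeech2DistributedBatchSampler splits batches per rank
# loader = create_dataloader(..., dist=True)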
@@ -77,9 +77,9 @@ def infer():
     """Inference for DeepSpeech2."""
     # check if set use_gpu=True in paddlepaddle cpu version
-    #check_cuda(args.use_gpu)
+    check_cuda(args.use_gpu)
     # check if paddlepaddle version is satisfied
-    #check_version()
+    check_version()
 
     # data_generator = DataGenerator(
     #     vocab_filepath=args.vocab_path,
@@ -114,16 +114,32 @@ def infer():
         sortagrad=False,
         shuffle_method=None)
 
-    for audio, text, audio_len, text_len in batch_reader:
-        print(audio.shape)
-        print(text.shape)
-        print(audio_len)
-        print(text_len)
-        break
+    #for audio, text, audio_len, text_len in batch_reader:
+    #    print(audio.shape)
+    #    print(text.shape)
+    #    print(audio_len)
+    #    print(text_len)
+    #    break
 
-    infer_data = batch_reader()
+    reader = batch_reader()
+    infer_data = next(reader)  # Python 3: use next(reader), not reader.next()
     print(infer_data)
 
+    from model_utils.network2 import DeepSpeech2
+    feat_dim = 161
+    model = DeepSpeech2(
+        feat_size=feat_dim,
+        dict_size=batch_reader.dataset.vocab_size,
+        num_conv_layers=args.num_conv_layers,
+        num_rnn_layers=args.num_rnn_layers,
+        #rnn_size=1024,
+        use_gru=args.use_gru,
+        share_rnn_weights=args.share_rnn_weights,
+    )
+    output = model(*infer_data)
+    print(output)
+
     # ds2_model = DeepSpeech2Model(
     #     vocab_size=data_generator.vocab_size,
     #     num_conv_layers=args.num_conv_layers,

@@ -497,8 +497,6 @@ class DeepSpeech2(nn.Layer):
             share_rnn_weights=share_rnn_weights)
         self.fc = nn.Linear(rnn_size * 2, dict_size + 1)
 
-        self.loss = nn.CTCLoss(blank=dict_size, reduction='none')
-
     def predict(self, audio, audio_len):
         # [B, D, T] -> [B, C=1, D, T]
         audio = audio.unsqueeze(1)
@@ -534,14 +532,24 @@ class DeepSpeech2(nn.Layer):
             text_len: shape [B]
         """
         logits, probs = self.predict(audio, audio_len)
-        # warp-ctc do softmax on activations
-        # warp-ctc need activation with shape [T, B, V + 1]
-        logits = logits.transpose([1, 0, 2])
         print(logits.shape)
         print(text.shape)
         print(audio_len.shape)
         print(text_len.shape)
+        return logits
+
+
+class DeepSpeechLoss(nn.Layer):
+    def __init__(self, vocab_size):
+        super().__init__()
+        self.loss = nn.CTCLoss(blank=vocab_size, reduction='none')
+
+    def forward(self, logits, text, audio_len, text_len):
+        # warp-ctc does softmax on the activations itself
+        # warp-ctc needs activations with shape [T, B, V + 1]
+        logits = logits.transpose([1, 0, 2])
         ctc_loss = self.loss(logits, text, audio_len, text_len)
         ctc_loss /= text_len  # norm_by_times
         ctc_loss = ctc_loss.sum()
-        return probs, ctc_loss
+        return ctc_loss

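Moving the CTC loss out of `DeepSpeech2.forward` and into `DeepSpeechLoss` keeps the warp-ctc conventions in one place: the op applies softmax to the raw activations itself and expects them shaped [T, B, V + 1]. The snippet below is a toy smoke test of that wiring using `paddle.nn.CTCLoss` directly; the batch, time, and vocab sizes are assumptions, not values from the repository:

import paddle
import paddle.nn as nn

B, T, V = 4, 50, 28  # assumed toy sizes: batch, time steps, vocab size
logits = paddle.randn([B, T, V + 1])                 # model output, [B, T, V+1]
text = paddle.randint(0, V, [B, 10], dtype='int32')  # label ids, [B, Lmax]
audio_len = paddle.full([B], T, dtype='int64')       # logit length per utterance
text_len = paddle.full([B], 10, dtype='int64')       # label length per utterance

loss_fn = nn.CTCLoss(blank=V, reduction='none')      # blank id = vocab size
ctc_loss = loss_fn(logits.transpose([1, 0, 2]),      # [B,T,V+1] -> [T,B,V+1]
                   text, audio_len, text_len)
ctc_loss = (ctc_loss.reshape([-1]) / text_len.astype('float32')).sum()
print(float(ctc_loss))

The per-sample losses are divided by label length before summing, matching the `norm_by_times` comment in `forward`.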
@@ -1,4 +1,4 @@
 scipy==1.2.1
-resampy==0.1.5
+resampy==0.2.2
 SoundFile==0.9.0.post1
 python_speech_features

@@ -6,6 +6,10 @@ else
     SUDO='sudo'
 fi
 
+if [ -e /etc/lsb-release ];then
+    ${SUDO} apt-get install -y pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
+fi
+
 # install python dependencies
 if [ -f "requirements.txt" ]; then
     pip3 install -r requirements.txt
@@ -18,9 +22,6 @@ fi
 # install package libsndfile
 python3 -c "import soundfile"
 if [ $? != 0 ]; then
-    if [ -e /etc/lsb-release ];then
-        ${SUDO} apt-get install -y pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
-    fi
     echo "Install package libsndfile into default system path."
     wget "http://www.mega-nerd.com/libsndfile/files/libsndfile-1.0.28.tar.gz"
     if [ $? != 0 ]; then
