diff --git a/data_utils/dataset.py b/data_utils/dataset.py
index 667b6fbe5..67c1b57ee 100644
--- a/data_utils/dataset.py
+++ b/data_utils/dataset.py
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import random
 import tarfile
 import numpy as np
 import paddle
 from paddle.io import Dataset
 from paddle.io import DataLoader
+from paddle.io import BatchSampler
 from paddle.io import DistributedBatchSampler
 from collections import namedtuple
 from functools import partial
@@ -170,7 +172,7 @@ class DeepSpeech2Dataset(Dataset):
                          instance["text"])
 
 
-class DeepSpeech2BatchSampler(DistributedBatchSampler):
+class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
     def __init__(self,
                  dataset,
                  batch_size,
@@ -279,6 +281,179 @@ class DeepSpeech2BatchSampler(DistributedBatchSampler):
         return num_samples // self.batch_size
 
 
+class DeepSpeech2BatchSampler(BatchSampler):
+    def __init__(self,
+                 dataset,
+                 batch_size,
+                 shuffle=False,
+                 drop_last=False,
+                 sortagrad=False,
+                 shuffle_method="batch_shuffle",
+                 num_replicas=1,
+                 rank=0):
+        self.dataset = dataset
+
+        assert isinstance(batch_size, int) and batch_size > 0, \
+            "batch_size should be a positive integer"
+        self.batch_size = batch_size
+        assert isinstance(shuffle, bool), \
+            "shuffle should be a boolean value"
+        self.shuffle = shuffle
+        assert isinstance(drop_last, bool), \
+            "drop_last should be a boolean value"
+
+        if num_replicas is not None:
+            assert isinstance(num_replicas, int) and num_replicas > 0, \
+                "num_replicas should be a positive integer"
+            self.nranks = num_replicas
+        else:
+            self.nranks = 1
+
+        if rank is not None:
+            assert isinstance(rank, int) and rank >= 0, \
+                "rank should be a non-negative integer"
+            self.local_rank = rank
+        else:
+            self.local_rank = 0
+
+        self.drop_last = drop_last
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
+        self.total_size = self.num_samples * self.nranks
+        self._sortagrad = sortagrad
+        self._shuffle_method = shuffle_method
+
+    def _batch_shuffle(self, manifest, batch_size, clipped=False):
+        """Put similarly-sized instances into minibatches for better efficiency
+        and make a batch-wise shuffle.
+
+        1. Sort the audio clips by duration.
+        2. Generate a random number `k`, k in [0, batch_size).
+        3. Randomly shift `k` instances in order to create different batches
+           for different epochs. Create minibatches.
+        4. Shuffle the minibatches.
+
+        :param manifest: Manifest contents. List of dict.
+        :type manifest: list
+        :param batch_size: Batch size. This size is also used to generate
+                           a random number for batch shuffle.
+        :type batch_size: int
+        :param clipped: Whether to clip the heading (small shift) and trailing
+                        (incomplete batch) instances.
+        :type clipped: bool
+        :return: Batch-shuffled manifest.
+        :rtype: list
+        """
+        rng = np.random.RandomState(self.epoch)
+        manifest.sort(key=lambda x: x["duration"])
+        shift_len = rng.randint(0, batch_size)
+        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
+        rng.shuffle(batch_manifest)
+        batch_manifest = [item for batch in batch_manifest for item in batch]
+        if not clipped:
+            res_len = len(manifest) - shift_len - len(batch_manifest)
+            batch_manifest.extend(manifest[len(manifest) - res_len:])
+            batch_manifest.extend(manifest[0:shift_len])
+        return batch_manifest
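
(Aside: the shift/regroup arithmetic in `_batch_shuffle` is easy to misread. The following self-contained sketch, with purely illustrative values that are not part of the diff, replays the same sort → shift → batch → shuffle → reattach steps on plain integers.)

```python
import numpy as np

# Illustrative stand-ins: ten "utterances" already sorted by duration.
items = list(range(10))
batch_size = 3
rng = np.random.RandomState(0)  # the sampler seeds this with the epoch number

shift_len = rng.randint(0, batch_size)  # k in [0, batch_size)
# Group the shifted list into full minibatches, then shuffle whole batches,
# so each batch still holds similarly-sized clips (less padding waste).
batches = list(zip(*[iter(items[shift_len:])] * batch_size))
rng.shuffle(batches)
shuffled = [x for batch in batches for x in batch]
# Reattach the clipped head and the incomplete tail so nothing is dropped.
res_len = len(items) - shift_len - len(shuffled)
shuffled.extend(items[len(items) - res_len:])  # empty slice when res_len == 0
shuffled.extend(items[:shift_len])
assert sorted(shuffled) == items  # every index survives exactly once
```

The random shift is what moves batch boundaries between epochs; without it, every epoch would shuffle the same fixed minibatches.
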
+    def __iter__(self):
+        num_samples = len(self.dataset)
+        indices = np.arange(num_samples).tolist()
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # sort (by duration) or batch-wise shuffle the manifest
+        if self.shuffle:
+            if self.epoch == 0 and self._sortagrad:
+                pass
+            else:
+                if self._shuffle_method == "batch_shuffle":
+                    indices = self._batch_shuffle(
+                        indices, self.batch_size, clipped=False)
+                elif self._shuffle_method == "instance_shuffle":
+                    np.random.RandomState(self.epoch).shuffle(indices)
+                else:
+                    raise ValueError("Unknown shuffle method %s." %
+                                     self._shuffle_method)
+        assert len(indices) == self.total_size
+        self.epoch += 1
+
+        # subsample
+        def _get_indices_by_batch_size(indices):
+            subsampled_indices = []
+            last_batch_size = self.total_size % (self.batch_size * self.nranks)
+            assert last_batch_size % self.nranks == 0
+            last_local_batch_size = last_batch_size // self.nranks
+
+            for i in range(self.local_rank * self.batch_size,
+                           len(indices) - last_batch_size,
+                           self.batch_size * self.nranks):
+                subsampled_indices.extend(indices[i:i + self.batch_size])
+
+            indices = indices[len(indices) - last_batch_size:]
+            subsampled_indices.extend(
+                indices[self.local_rank * last_local_batch_size:(
+                    self.local_rank + 1) * last_local_batch_size])
+            return subsampled_indices
+
+        if self.nranks > 1:
+            indices = _get_indices_by_batch_size(indices)
+
+        assert len(indices) == self.num_samples
+        _sample_iter = iter(indices)
+
+        batch_indices = []
+        for idx in _sample_iter:
+            batch_indices.append(idx)
+            if len(batch_indices) == self.batch_size:
+                yield batch_indices
+                batch_indices = []
+        if not self.drop_last and len(batch_indices) > 0:
+            yield batch_indices
+
+    def __len__(self):
+        num_samples = self.num_samples
+        num_samples += int(not self.drop_last) * (self.batch_size - 1)
+        return num_samples // self.batch_size
+
+    def set_epoch(self, epoch):
+        """
+        Sets the epoch number. When :attr:`shuffle=True`, this number is used
+        as the seed for random numbers. By default, users need not set this;
+        all replicas (workers) use a different random ordering for each epoch.
+        If the same number is set at each epoch, this sampler will yield the
+        same ordering at all epochs.
+        Arguments:
+            epoch (int): Epoch number.
+        Examples:
+            .. code-block:: python
+
+                import numpy as np
+
+                from paddle.io import Dataset, DistributedBatchSampler
+
+                # init with dataset
+                class RandomDataset(Dataset):
+                    def __init__(self, num_samples):
+                        self.num_samples = num_samples
+
+                    def __getitem__(self, idx):
+                        image = np.random.random([784]).astype('float32')
+                        label = np.random.randint(0, 9, (1, )).astype('int64')
+                        return image, label
+
+                    def __len__(self):
+                        return self.num_samples
+
+                dataset = RandomDataset(100)
+                sampler = DistributedBatchSampler(dataset, batch_size=64)
+
+                for epoch in range(10):
+                    sampler.set_epoch(epoch)
+        """
+        self.epoch = epoch
+
+
 def create_dataloader(manifest_path,
                       vocab_filepath,
                       mean_std_filepath,
@@ -296,7 +471,8 @@ def create_dataloader(manifest_path,
                       batch_size=1,
                       num_workers=0,
                       sortagrad=False,
-                      shuffle_method=None):
+                      shuffle_method=None,
+                      dist=False):
 
     dataset = DeepSpeech2Dataset(
         manifest_path,
@@ -313,15 +489,24 @@
         random_seed=random_seed,
         keep_transcription_text=keep_transcription_text)
 
-    batch_sampler = DeepSpeech2BatchSampler(
-        dataset,
-        batch_size,
-        num_replicas=None,
-        rank=None,
-        shuffle=is_training,
-        drop_last=is_training,
-        sortagrad=is_training,
-        shuffle_method=shuffle_method)
+    if dist:
+        batch_sampler = DeepSpeech2DistributedBatchSampler(
+            dataset,
+            batch_size,
+            num_replicas=None,
+            rank=None,
+            shuffle=is_training,
+            drop_last=is_training,
+            sortagrad=is_training,
+            shuffle_method=shuffle_method)
+    else:
+        batch_sampler = DeepSpeech2BatchSampler(
+            dataset,
+            shuffle=is_training,
+            batch_size=batch_size,
+            drop_last=is_training,
+            sortagrad=is_training,
+            shuffle_method=shuffle_method)
 
     def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
         """
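
Before moving on to infer.py, a hedged sketch of how the new `dist` switch is meant to be consumed. The sketch assumes a `dataset` built like the `DeepSpeech2Dataset` above; the `padding_batch` collate function is elided.

```python
from paddle.io import DataLoader

is_training = True
dist = False  # single-process run, e.g. infer.py

# Mirrors the branch added to create_dataloader above.
if dist:
    sampler = DeepSpeech2DistributedBatchSampler(
        dataset, batch_size=32,
        num_replicas=None, rank=None,
        shuffle=is_training, drop_last=is_training,
        sortagrad=is_training, shuffle_method="batch_shuffle")
else:
    sampler = DeepSpeech2BatchSampler(
        dataset, batch_size=32,
        shuffle=is_training, drop_last=is_training,
        sortagrad=is_training, shuffle_method="batch_shuffle")

loader = DataLoader(dataset, batch_sampler=sampler)  # collate_fn elided
for epoch in range(3):
    sampler.set_epoch(epoch)  # reseed the batch shuffle each epoch
    for batch in loader:
        pass
```
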
diff --git a/infer.py b/infer.py
index 3c9171566..11a4ad7ab 100644
--- a/infer.py
+++ b/infer.py
@@ -77,9 +77,9 @@ def infer():
     """Inference for DeepSpeech2."""
 
     # check if set use_gpu=True in paddlepaddle cpu version
-    #check_cuda(args.use_gpu)
+    check_cuda(args.use_gpu)
     # check if paddlepaddle version is satisfied
-    #check_version()
+    check_version()
 
     # data_generator = DataGenerator(
     #     vocab_filepath=args.vocab_path,
@@ -114,16 +114,32 @@
         sortagrad=False,
         shuffle_method=None)
 
-    for audio, text, audio_len, text_len in batch_reader:
-        print(audio.shape)
-        print(text.shape)
-        print(audio_len)
-        print(text_len)
-        break
+    #for audio, text, audio_len, text_len in batch_reader:
+    #    print(audio.shape)
+    #    print(text.shape)
+    #    print(audio_len)
+    #    print(text_len)
+    #    break
 
-    infer_data = batch_reader()
+    reader = batch_reader()
+    infer_data = next(reader)
     print(infer_data)
 
+    from model_utils.network2 import DeepSpeech2
+    feat_dim = 161
+    model = DeepSpeech2(
+        feat_size=feat_dim,
+        dict_size=batch_reader.dataset.vocab_size,
+        num_conv_layers=args.num_conv_layers,
+        num_rnn_layers=args.num_rnn_layers,
+        #rnn_size=1024,
+        use_gru=args.use_gru,
+        share_rnn_weights=args.share_rnn_weights,
+    )
+
+    output = model(*infer_data)
+    print(output)
+
     # ds2_model = DeepSpeech2Model(
     #     vocab_size=data_generator.vocab_size,
     #     num_conv_layers=args.num_conv_layers,
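
As a reading aid for the new debug path: the fetched batch is assumed to unpack as below. The field names come from the commented-out loop above and the shape comments in network2.py, so treat the shapes as inferred rather than guaranteed.

```python
# Inferred batch layout; not part of the diff.
audio, text, audio_len, text_len = infer_data
# audio:     padded feature batch, shape [B, D, T_max] (D = feat_dim = 161)
# text:      padded token ids, shape [B, L_max]
# audio_len: valid frame count per utterance, shape [B]
# text_len:  valid token count per utterance, shape [B]
logits = model(audio, text, audio_len, text_len)  # same as model(*infer_data)
```
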
diff --git a/model_utils/network2.py b/model_utils/network2.py
index 8cbbbf818..bab97a3cc 100644
--- a/model_utils/network2.py
+++ b/model_utils/network2.py
@@ -497,8 +497,6 @@ class DeepSpeech2(nn.Layer):
             share_rnn_weights=share_rnn_weights)
         self.fc = nn.Linear(rnn_size * 2, dict_size + 1)
 
-        self.loss = nn.CTCLoss(blank=dict_size, reduction='none')
-
     def predict(self, audio, audio_len):
         # [B, D, T] -> [B, C=1, D, T]
         audio = audio.unsqueeze(1)
@@ -534,14 +532,24 @@ class DeepSpeech2(nn.Layer):
             text_len: shape [B]
         """
         logits, probs = self.predict(audio, audio_len)
-        # warp-ctc do softmax on activations
-        # warp-ctc need activation with shape [T, B, V + 1]
-        logits = logits.transpose([1, 0, 2])
 
         print(logits.shape)
         print(text.shape)
         print(audio_len.shape)
         print(text_len.shape)
 
+        return logits
+
+
+class DeepSpeechLoss(nn.Layer):
+    def __init__(self, vocab_size):
+        super().__init__()
+        self.loss = nn.CTCLoss(blank=vocab_size, reduction='none')
+
+    def forward(self, logits, text, audio_len, text_len):
+        # warp-ctc does softmax on activations and
+        # needs activations with shape [T, B, V + 1]
+        logits = logits.transpose([1, 0, 2])
+        ctc_loss = self.loss(logits, text, audio_len, text_len)
         ctc_loss /= text_len  # norm_by_times
         ctc_loss = ctc_loss.sum()
-        return probs, ctc_loss
+        return ctc_loss
diff --git a/requirements.txt b/requirements.txt
index 8c57208a6..af2993b6d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 scipy==1.2.1
-resampy==0.1.5
+resampy==0.2.2
 SoundFile==0.9.0.post1
 python_speech_features
diff --git a/setup.sh b/setup.sh
index 3827dc1b3..21d9c19ec 100644
--- a/setup.sh
+++ b/setup.sh
@@ -6,6 +6,10 @@ else
     SUDO='sudo'
 fi
 
+if [ -e /etc/lsb-release ]; then
+    ${SUDO} apt-get install -y pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
+fi
+
 # install python dependencies
 if [ -f "requirements.txt" ]; then
     pip3 install -r requirements.txt
@@ -18,9 +22,6 @@ fi
 # install package libsndfile
 python3 -c "import soundfile"
 if [ $? != 0 ]; then
-    if [ -e /etc/lsb-release ];then
-        ${SUDO} apt-get install -y pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
-    fi
     echo "Install package libsndfile into default system path."
     wget "http://www.mega-nerd.com/libsndfile/files/libsndfile-1.0.28.tar.gz"
     if [ $? != 0 ]; then
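
With nn.CTCLoss moved out of DeepSpeech2 and into the new DeepSpeechLoss layer, training code now composes the two layers itself. A minimal, hedged sketch of that wiring, assuming `model` and the batch tensors are built exactly as in infer.py above:

```python
from model_utils.network2 import DeepSpeech2, DeepSpeechLoss

# The blank id must match the model: the fc layer emits dict_size + 1 classes,
# with index dict_size (== vocab_size) reserved for the CTC blank.
loss_fn = DeepSpeechLoss(vocab_size=batch_reader.dataset.vocab_size)

logits = model(audio, text, audio_len, text_len)   # [B, T, V + 1]
loss = loss_fn(logits, text, audio_len, text_len)  # length-normalized, summed
loss.backward()
```
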