diff --git a/deepspeech/decoders/swig/.gitignore b/deepspeech/decoders/swig/.gitignore new file mode 100644 index 000000000..0b1046ae8 --- /dev/null +++ b/deepspeech/decoders/swig/.gitignore @@ -0,0 +1,9 @@ +ThreadPool/ +build/ +dist/ +kenlm/ +openfst-1.6.3/ +openfst-1.6.3.tar.gz +swig_decoders.egg-info/ +decoders_wrap.cxx +swig_decoders.py diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index 455f5b6c1..af635a68d 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -13,6 +13,7 @@ # limitations under the License. from yacs.config import CfgNode as CN +from deepspeech.models.deepspeech2 import DeepSpeech2Model _C = CN() _C.data = CN( @@ -50,6 +51,8 @@ _C.model = CN( share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported. )) +DeepSpeech2Model.params(_C.model) + _C.training = CN( dict( lr=5e-4, # learning rate diff --git a/deepspeech/exps/deepspeech2/dataset.py b/deepspeech/exps/deepspeech2/dataset.py deleted file mode 100644 index 72e3d840d..000000000 --- a/deepspeech/exps/deepspeech2/dataset.py +++ /dev/null @@ -1,584 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
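Reviewer note: the config.py hunk above registers the model's defaults into the global config via `DeepSpeech2Model.params(_C.model)`. A minimal, self-contained sketch of that registration pattern; `ToyModel` and its keys are hypothetical stand-ins, assuming the standard yacs `CfgNode.merge_from_other_cfg` semantics (the defaults overwrite matching keys on the node passed in):

```python
from yacs.config import CfgNode as CN

class ToyModel:
    @classmethod
    def params(cls, config=None):
        # the model owns its default hyperparameters ...
        default = CN(dict(num_conv_layers=2, num_rnn_layers=3))
        if config is not None:
            # ... and pushes them into the experiment-level node,
            # mirroring DeepSpeech2Model.params(_C.model) above
            config.merge_from_other_cfg(default)
        return default

_C = CN()
_C.model = CN(dict(num_conv_layers=1, num_rnn_layers=1))  # stale values
ToyModel.params(_C.model)
assert _C.model.num_rnn_layers == 3  # the model's defaults now win
```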
- -import math -import random -import tarfile -import logging -import numpy as np -from collections import namedtuple -from functools import partial - -import paddle -from paddle.io import Dataset -from paddle.io import DataLoader -from paddle.io import BatchSampler -from paddle.io import DistributedBatchSampler -from paddle import distributed as dist - -from deepspeech.frontend.utility import read_manifest -from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline -from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer -from deepspeech.frontend.speech import SpeechSegment -from deepspeech.frontend.normalizer import FeatureNormalizer - -logger = logging.getLogger(__name__) - -__all__ = [ - "DeepSpeech2Dataset", - "DeepSpeech2DistributedBatchSampler", - "DeepSpeech2BatchSampler", - "SpeechCollator", -] - - -class DeepSpeech2Dataset(Dataset): - def __init__(self, - manifest_path, - vocab_filepath, - mean_std_filepath, - augmentation_config='{}', - max_duration=float('inf'), - min_duration=0.0, - stride_ms=10.0, - window_ms=20.0, - n_fft=None, - max_freq=None, - target_sample_rate=16000, - specgram_type='linear', - use_dB_normalization=True, - target_dB=-20, - random_seed=0, - keep_transcription_text=False): - super().__init__() - - self._max_duration = max_duration - self._min_duration = min_duration - self._normalizer = FeatureNormalizer(mean_std_filepath) - self._augmentation_pipeline = AugmentationPipeline( - augmentation_config=augmentation_config, random_seed=random_seed) - self._speech_featurizer = SpeechFeaturizer( - vocab_filepath=vocab_filepath, - specgram_type=specgram_type, - stride_ms=stride_ms, - window_ms=window_ms, - n_fft=n_fft, - max_freq=max_freq, - target_sample_rate=target_sample_rate, - use_dB_normalization=use_dB_normalization, - target_dB=target_dB) - self._rng = random.Random(random_seed) - self._keep_transcription_text = keep_transcription_text - # for caching tar files info - self._local_data = namedtuple('local_data', ['tar2info', 'tar2object']) - self._local_data.tar2info = {} - self._local_data.tar2object = {} - - # read manifest - self._manifest = read_manifest( - manifest_path=manifest_path, - max_duration=self._max_duration, - min_duration=self._min_duration) - self._manifest.sort(key=lambda x: x["duration"]) - - @property - def manifest(self): - return self._manifest - - @property - def vocab_size(self): - """Return the vocabulary size. - - :return: Vocabulary size. - :rtype: int - """ - return self._speech_featurizer.vocab_size - - @property - def vocab_list(self): - """Return the vocabulary in list. - - :return: Vocabulary in list. - :rtype: list - """ - return self._speech_featurizer.vocab_list - - @property - def feature_size(self): - return self._speech_featurizer.feature_size - - def _parse_tar(self, file): - """Parse a tar file to get a tarfile object - and a map containing tarinfoes - """ - result = {} - f = tarfile.open(file) - for tarinfo in f.getmembers(): - result[tarinfo.name] = tarinfo - return f, result - - def _subfile_from_tar(self, file): - """Get subfile object from tar. - - It will return a subfile object from tar file - and cached tar file info for next reading request. 
- """ - tarpath, filename = file.split(':', 1)[1].split('#', 1) - if 'tar2info' not in self._local_data.__dict__: - self._local_data.tar2info = {} - if 'tar2object' not in self._local_data.__dict__: - self._local_data.tar2object = {} - if tarpath not in self._local_data.tar2info: - object, infoes = self._parse_tar(tarpath) - self._local_data.tar2info[tarpath] = infoes - self._local_data.tar2object[tarpath] = object - return self._local_data.tar2object[tarpath].extractfile( - self._local_data.tar2info[tarpath][filename]) - - def process_utterance(self, audio_file, transcript): - """Load, augment, featurize and normalize for speech data. - - :param audio_file: Filepath or file object of audio file. - :type audio_file: str | file - :param transcript: Transcription text. - :type transcript: str - :return: Tuple of audio feature tensor and data of transcription part, - where transcription part could be token ids or text. - :rtype: tuple of (2darray, list) - """ - if isinstance(audio_file, str) and audio_file.startswith('tar:'): - speech_segment = SpeechSegment.from_file( - self._subfile_from_tar(audio_file), transcript) - else: - speech_segment = SpeechSegment.from_file(audio_file, transcript) - self._augmentation_pipeline.transform_audio(speech_segment) - specgram, transcript_part = self._speech_featurizer.featurize( - speech_segment, self._keep_transcription_text) - specgram = self._normalizer.apply(specgram) - return specgram, transcript_part - - def _instance_reader_creator(self, manifest): - """ - Instance reader creator. Create a callable function to produce - instances of data. - - Instance: a tuple of ndarray of audio spectrogram and a list of - token indices for transcript. - """ - - def reader(): - for instance in manifest: - inst = self.process_utterance(instance["audio_filepath"], - instance["text"]) - yield inst - - return reader - - def __len__(self): - return len(self._manifest) - - def __getitem__(self, idx): - instance = self._manifest[idx] - return self.process_utterance(instance["audio_filepath"], - instance["text"]) - - -class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler): - def __init__(self, - dataset, - batch_size, - num_replicas=None, - rank=None, - shuffle=False, - drop_last=False, - sortagrad=False, - shuffle_method="batch_shuffle"): - super().__init__(dataset, batch_size, num_replicas, rank, shuffle, - drop_last) - self._sortagrad = sortagrad - self._shuffle_method = shuffle_method - - def _batch_shuffle(self, indices, batch_size, clipped=False): - """Put similarly-sized instances into minibatches for better efficiency - and make a batch-wise shuffle. - - 1. Sort the audio clips by duration. - 2. Generate a random number `k`, k in [0, batch_size). - 3. Randomly shift `k` instances in order to create different batches - for different epochs. Create minibatches. - 4. Shuffle the minibatches. - - :param indices: indexes. List of int. - :type indices: list - :param batch_size: Batch size. This size is also used for generate - a random number for batch shuffle. - :type batch_size: int - :param clipped: Whether to clip the heading (small shift) and trailing - (incomplete batch) instances. - :type clipped: bool - :return: Batch shuffled mainifest. 
- :rtype: list - """ - rng = np.random.RandomState(self.epoch) - shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) - rng.shuffle(batch_indices) - batch_indices = [item for batch in batch_indices for item in batch] - assert (clipped == False) - if not clipped: - res_len = len(indices) - shift_len - len(batch_indices) - # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:]) - if res_len != 0: - batch_indices.extend(indices[-res_len:]) - batch_indices.extend(indices[0:shift_len]) - assert len(indices) == len( - batch_indices - ), f"_batch_shuffle: {len(indices)} : {len(batch_indices)} : {res_len} - {shift_len}" - return batch_indices - - def __iter__(self): - num_samples = len(self.dataset) - indices = np.arange(num_samples).tolist() - indices += indices[:(self.total_size - len(indices))] - assert len(indices) == self.total_size - - # sort (by duration) or batch-wise shuffle the manifest - if self.shuffle: - if self.epoch == 0 and self._sortagrad: - logger.info( - f'rank: {dist.get_rank()} dataset sortagrad! epoch {self.epoch}' - ) - else: - logger.info( - f'rank: {dist.get_rank()} dataset shuffle! epoch {self.epoch}' - ) - if self._shuffle_method == "batch_shuffle": - indices = self._batch_shuffle( - indices, self.batch_size, clipped=False) - elif self._shuffle_method == "instance_shuffle": - np.random.RandomState(self.epoch).shuffle(indices) - else: - raise ValueError("Unknown shuffle method %s." % - self._shuffle_method) - assert len( - indices - ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}" - - # subsample - def _get_indices_by_batch_size(indices): - subsampled_indices = [] - last_batch_size = self.total_size % (self.batch_size * self.nranks) - assert last_batch_size % self.nranks == 0 - last_local_batch_size = last_batch_size // self.nranks - - for i in range(self.local_rank * self.batch_size, - len(indices) - last_batch_size, - self.batch_size * self.nranks): - subsampled_indices.extend(indices[i:i + self.batch_size]) - - indices = indices[len(indices) - last_batch_size:] - subsampled_indices.extend( - indices[self.local_rank * last_local_batch_size:( - self.local_rank + 1) * last_local_batch_size]) - return subsampled_indices - - if self.nranks > 1: - indices = _get_indices_by_batch_size(indices) - - assert len(indices) == self.num_samples - _sample_iter = iter(indices) - - batch_indices = [] - for idx in _sample_iter: - batch_indices.append(idx) - if len(batch_indices) == self.batch_size: - logger.info( - f"rank: {dist.get_rank()} batch index: {batch_indices} ") - yield batch_indices - batch_indices = [] - if not self.drop_last and len(batch_indices) > 0: - yield batch_indices - - def __len__(self): - num_samples = self.num_samples - num_samples += int(not self.drop_last) * (self.batch_size - 1) - return num_samples // self.batch_size - - -class DeepSpeech2BatchSampler(BatchSampler): - def __init__(self, - dataset, - batch_size, - shuffle=False, - drop_last=False, - sortagrad=False, - shuffle_method="batch_shuffle"): - self.dataset = dataset - - assert isinstance(batch_size, int) and batch_size > 0, \ - "batch_size should be a positive integer" - self.batch_size = batch_size - assert isinstance(shuffle, bool), \ - "shuffle should be a boolean value" - self.shuffle = shuffle - assert isinstance(drop_last, bool), \ - "drop_last should be a boolean number" - - self.drop_last = drop_last - self.epoch = 0 - self.num_samples = int(math.ceil(len(self.dataset) * 
1.0)) - self.total_size = self.num_samples - self._sortagrad = sortagrad - self._shuffle_method = shuffle_method - - def _batch_shuffle(self, indices, batch_size, clipped=False): - """Put similarly-sized instances into minibatches for better efficiency - and make a batch-wise shuffle. - - 1. Sort the audio clips by duration. - 2. Generate a random number `k`, k in [0, batch_size). - 3. Randomly shift `k` instances in order to create different batches - for different epochs. Create minibatches. - 4. Shuffle the minibatches. - - :param indices: indexes. List of int. - :type indices: list - :param batch_size: Batch size. This size is also used for generate - a random number for batch shuffle. - :type batch_size: int - :param clipped: Whether to clip the heading (small shift) and trailing - (incomplete batch) instances. - :type clipped: bool - :return: Batch shuffled mainifest. - :rtype: list - """ - rng = np.random.RandomState(self.epoch) - # must shift at leat by one - shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) - rng.shuffle(batch_indices) - batch_indices = [item for batch in batch_indices for item in batch] - assert (clipped == False) - if not clipped: - res_len = len(indices) - shift_len - len(batch_indices) - # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:]) - if res_len != 0: - batch_indices.extend(indices[-res_len:]) - batch_indices.extend(indices[0:shift_len]) - assert len(indices) == len( - batch_indices - ), f"_batch_shuffle: {len(indices)} : {len(batch_indices)} : {res_len} - {shift_len}" - return batch_indices - - def __iter__(self): - num_samples = len(self.dataset) - indices = np.arange(num_samples).tolist() - indices += indices[:(self.total_size - len(indices))] - assert len(indices) == self.total_size - - # sort (by duration) or batch-wise shuffle the manifest - if self.shuffle: - if self.epoch == 0 and self._sortagrad: - logger.info(f'dataset sortagrad! epoch {self.epoch}') - else: - logger.info(f'dataset shuffle! epoch {self.epoch}') - if self._shuffle_method == "batch_shuffle": - indices = self._batch_shuffle( - indices, self.batch_size, clipped=False) - elif self._shuffle_method == "instance_shuffle": - np.random.RandomState(self.epoch).shuffle(indices) - else: - raise ValueError("Unknown shuffle method %s." % - self._shuffle_method) - assert len( - indices - ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}" - - assert len(indices) == self.num_samples - _sample_iter = iter(indices) - - batch_indices = [] - for idx in _sample_iter: - batch_indices.append(idx) - if len(batch_indices) == self.batch_size: - logger.info( - f"rank: {dist.get_rank()} batch index: {batch_indices} ") - yield batch_indices - batch_indices = [] - if not self.drop_last and len(batch_indices) > 0: - yield batch_indices - - self.epoch += 1 - - def __len__(self): - num_samples = self.num_samples - num_samples += int(not self.drop_last) * (self.batch_size - 1) - return num_samples // self.batch_size - - -class SpeechCollator(): - def __init__(self, padding_to=-1, is_training=True): - """ - Padding audio features with zeros to make them have the same shape (or - a user-defined shape) within one bach. - - If ``padding_to`` is -1, the maximun shape in the batch will be used - as the target shape for padding. Otherwise, `padding_to` will be the - target shape (only refers to the second axis). 
- """ - self._padding_to = padding_to - self._is_training = is_training - - def __call__(self, batch): - new_batch = [] - # get target shape - max_length = max([audio.shape[1] for audio, _ in batch]) - if self._padding_to != -1: - if self._padding_to < max_length: - raise ValueError("If padding_to is not -1, it should be larger " - "than any instance's shape in the batch") - max_length = self._padding_to - max_text_length = max([len(text) for _, text in batch]) - # padding - padded_audios = [] - audio_lens = [] - texts, text_lens = [], [] - for audio, text in batch: - # audio - padded_audio = np.zeros([audio.shape[0], max_length]) - padded_audio[:, :audio.shape[1]] = audio - padded_audios.append(padded_audio) - audio_lens.append(audio.shape[1]) - # text - padded_text = np.zeros([max_text_length]) - if self._is_training: - padded_text[:len(text)] = text #ids - else: - padded_text[:len(text)] = [ord(t) for t in text] # string - texts.append(padded_text) - text_lens.append(len(text)) - - padded_audios = np.array(padded_audios).astype('float32') - audio_lens = np.array(audio_lens).astype('int64') - texts = np.array(texts).astype('int32') - text_lens = np.array(text_lens).astype('int64') - return padded_audios, texts, audio_lens, text_lens - - -def create_dataloader(manifest_path, - vocab_filepath, - mean_std_filepath, - augmentation_config='{}', - max_duration=float('inf'), - min_duration=0.0, - stride_ms=10.0, - window_ms=20.0, - max_freq=None, - specgram_type='linear', - use_dB_normalization=True, - random_seed=0, - keep_transcription_text=False, - is_training=False, - batch_size=1, - num_workers=0, - sortagrad=False, - shuffle_method=None, - dist=False): - - dataset = DeepSpeech2Dataset( - manifest_path, - vocab_filepath, - mean_std_filepath, - augmentation_config=augmentation_config, - max_duration=max_duration, - min_duration=min_duration, - stride_ms=stride_ms, - window_ms=window_ms, - max_freq=max_freq, - specgram_type=specgram_type, - use_dB_normalization=use_dB_normalization, - random_seed=random_seed, - keep_transcription_text=keep_transcription_text) - - if dist: - batch_sampler = DeepSpeech2DistributedBatchSampler( - dataset, - batch_size, - num_replicas=None, - rank=None, - shuffle=is_training, - drop_last=is_training, - sortagrad=is_training, - shuffle_method=shuffle_method) - else: - batch_sampler = DeepSpeech2BatchSampler( - dataset, - shuffle=is_training, - batch_size=batch_size, - drop_last=is_training, - sortagrad=is_training, - shuffle_method=shuffle_method) - - def padding_batch(batch, padding_to=-1, flatten=False, is_training=True): - """ - Padding audio features with zeros to make them have the same shape (or - a user-defined shape) within one bach. - - If ``padding_to`` is -1, the maximun shape in the batch will be used - as the target shape for padding. Otherwise, `padding_to` will be the - target shape (only refers to the second axis). - - If `flatten` is True, features will be flatten to 1darray. 
- """ - new_batch = [] - # get target shape - max_length = max([audio.shape[1] for audio, text in batch]) - if padding_to != -1: - if padding_to < max_length: - raise ValueError("If padding_to is not -1, it should be larger " - "than any instance's shape in the batch") - max_length = padding_to - max_text_length = max([len(text) for audio, text in batch]) - # padding - padded_audios = [] - audio_lens = [] - texts, text_lens = [], [] - for audio, text in batch: - padded_audio = np.zeros([audio.shape[0], max_length]) - padded_audio[:, :audio.shape[1]] = audio - if flatten: - padded_audio = padded_audio.flatten() - padded_audios.append(padded_audio) - audio_lens.append(audio.shape[1]) - - padded_text = np.zeros([max_text_length]) - if is_training: - padded_text[:len(text)] = text #ids - else: - padded_text[:len(text)] = [ord(t) for t in text] # string - texts.append(padded_text) - text_lens.append(len(text)) - - padded_audios = np.array(padded_audios).astype('float32') - audio_lens = np.array(audio_lens).astype('int64') - texts = np.array(texts).astype('int32') - text_lens = np.array(text_lens).astype('int64') - return padded_audios, texts, audio_lens, text_lens - - loader = DataLoader( - dataset, - batch_sampler=batch_sampler, - collate_fn=partial(padding_batch, is_training=is_training), - num_workers=num_workers) - return loader diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 633569fcf..6f1d90ca8 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -27,95 +27,28 @@ import paddle from paddle import distributed as dist from paddle.io import DataLoader -from paddle.fluid.dygraph import base as imperative_base -from paddle.fluid import layers -from paddle.fluid import core - from deepspeech.training import Trainer -from deepspeech.utils import mp_tools -from deepspeech.utils.error_rate import char_errors, word_errors, cer, wer +from deepspeech.training.gradclip import MyClipGradByGlobalNorm -from deepspeech.models.network import DeepSpeech2 -from deepspeech.models.network import DeepSpeech2Loss +from deepspeech.utils import mp_tools +from deepspeech.utils.error_rate import char_errors +from deepspeech.utils.error_rate import word_errors +from deepspeech.utils.error_rate import cer +from deepspeech.utils.error_rate import wer +from deepspeech.utils.utility import print_grads +from deepspeech.utils.utility import print_params -from deepspeech.decoders.swig_wrapper import Scorer -from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder -from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.sampler import SortagradDistributedBatchSampler +from deepspeech.io.sampler import SortagradBatchSampler +from deepspeech.io.dataset import ManifestDataset -from deepspeech.exps.deepspeech2.dataset import SpeechCollator -from deepspeech.exps.deepspeech2.dataset import DeepSpeech2Dataset -from deepspeech.exps.deepspeech2.dataset import DeepSpeech2DistributedBatchSampler -from deepspeech.exps.deepspeech2.dataset import DeepSpeech2BatchSampler +from deepspeech.training.loss import CTCLoss +from deepspeech.models.DeepSpeech2 import DeepSpeech2Model logger = logging.getLogger(__name__) -class MyClipGradByGlobalNorm(paddle.nn.ClipGradByGlobalNorm): - def __init__(self, clip_norm): - super().__init__(clip_norm) - - @imperative_base.no_grad - def _dygraph_clip(self, params_grads): - params_and_grads = [] - sum_square_list = [] - for 
p, g in params_grads: - if g is None: - continue - if getattr(p, 'need_clip', True) is False: - continue - merge_grad = g - if g.type == core.VarDesc.VarType.SELECTED_ROWS: - merge_grad = layers.merge_selected_rows(g) - merge_grad = layers.get_tensor_from_selected_rows(merge_grad) - square = layers.square(merge_grad) - sum_square = layers.reduce_sum(square) - logger.info( - f"Grad Before Clip: {p.name}: {float(layers.sqrt(layers.reduce_sum(layers.square(merge_grad))) ) }" - ) - sum_square_list.append(sum_square) - - # all parameters have been filterd out - if len(sum_square_list) == 0: - return params_grads - - global_norm_var = layers.concat(sum_square_list) - global_norm_var = layers.reduce_sum(global_norm_var) - global_norm_var = layers.sqrt(global_norm_var) - logger.info(f"Grad Global Norm: {float(global_norm_var)}!!!!") - max_global_norm = layers.fill_constant( - shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm) - clip_var = layers.elementwise_div( - x=max_global_norm, - y=layers.elementwise_max(x=global_norm_var, y=max_global_norm)) - for p, g in params_grads: - if g is None: - continue - if getattr(p, 'need_clip', True) is False: - params_and_grads.append((p, g)) - continue - new_grad = layers.elementwise_mul(x=g, y=clip_var) - logger.info( - f"Grad After Clip: {p.name}: {float(layers.sqrt(layers.reduce_sum(layers.square(merge_grad))) ) }" - ) - params_and_grads.append((p, new_grad)) - - return params_and_grads - - -def print_grads(model, logger=None): - for n, p in model.named_parameters(): - msg = f"param grad: {n}: shape: {p.shape} grad: {p.grad}" - if logger: - logger.info(msg) - - -def print_params(model, logger=None): - for n, p in model.named_parameters(): - msg = f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}" - if logger: - logger.info(msg) - - class DeepSpeech2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) @@ -193,7 +126,7 @@ class DeepSpeech2Trainer(Trainer): def setup_model(self): config = self.config - model = DeepSpeech2( + model = DeepSpeech2Model( feat_size=self.train_loader.dataset.feature_size, dict_size=self.train_loader.dataset.vocab_size, num_conv_layers=config.model.num_conv_layers, @@ -219,7 +152,7 @@ class DeepSpeech2Trainer(Trainer): config.training.weight_decay), grad_clip=grad_clip) - criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size) + criterion = CTCLoss(self.train_loader.dataset.vocab_size) self.model = model self.optimizer = optimizer @@ -230,7 +163,7 @@ class DeepSpeech2Trainer(Trainer): def setup_dataloader(self): config = self.config - train_dataset = DeepSpeech2Dataset( + train_dataset = ManifestDataset( config.data.train_manifest, config.data.vocab_filepath, config.data.mean_std_filepath, @@ -250,7 +183,7 @@ class DeepSpeech2Trainer(Trainer): random_seed=config.data.random_seed, keep_transcription_text=False) - dev_dataset = DeepSpeech2Dataset( + dev_dataset = ManifestDataset( config.data.dev_manifest, config.data.vocab_filepath, config.data.mean_std_filepath, @@ -269,7 +202,7 @@ class DeepSpeech2Trainer(Trainer): keep_transcription_text=False) if self.parallel: - batch_sampler = DeepSpeech2DistributedBatchSampler( + batch_sampler = SortagradDistributedBatchSampler( train_dataset, batch_size=config.data.batch_size, num_replicas=None, @@ -279,7 +212,7 @@ class DeepSpeech2Trainer(Trainer): sortagrad=config.data.sortagrad, shuffle_method=config.data.shuffle_method) else: - batch_sampler = DeepSpeech2BatchSampler( + batch_sampler = SortagradBatchSampler( train_dataset, 
shuffle=True, batch_size=config.data.batch_size, @@ -461,7 +394,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def setup_model(self): config = self.config - model = DeepSpeech2( + model = DeepSpeech2Model( feat_size=self.test_loader.dataset.feature_size, dict_size=self.test_loader.dataset.vocab_size, num_conv_layers=config.model.num_conv_layers, @@ -473,7 +406,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): if self.parallel: model = paddle.DataParallel(model) - criterion = DeepSpeech2Loss(self.test_loader.dataset.vocab_size) + criterion = CTCLoss(self.test_loader.dataset.vocab_size) self.model = model self.criterion = criterion @@ -482,7 +415,7 @@ def setup_dataloader(self): config = self.config # return raw text - test_dataset = DeepSpeech2Dataset( + test_dataset = ManifestDataset( config.data.test_manifest, config.data.vocab_filepath, config.data.mean_std_filepath, diff --git a/deepspeech/io/__init__.py b/deepspeech/io/__init__.py new file mode 100644 index 000000000..12e1d4d91 --- /dev/null +++ b/deepspeech/io/__init__.py @@ -0,0 +1,131 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import numpy as np +from paddle.io import DataLoader + +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.sampler import SortagradDistributedBatchSampler +from deepspeech.io.sampler import SortagradBatchSampler +from deepspeech.io.dataset import ManifestDataset + + +def create_dataloader(manifest_path, + vocab_filepath, + mean_std_filepath, + augmentation_config='{}', + max_duration=float('inf'), + min_duration=0.0, + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + specgram_type='linear', + use_dB_normalization=True, + random_seed=0, + keep_transcription_text=False, + is_training=False, + batch_size=1, + num_workers=0, + sortagrad=False, + shuffle_method=None, + dist=False): + + dataset = ManifestDataset( + manifest_path, + vocab_filepath, + mean_std_filepath, + augmentation_config=augmentation_config, + max_duration=max_duration, + min_duration=min_duration, + stride_ms=stride_ms, + window_ms=window_ms, + max_freq=max_freq, + specgram_type=specgram_type, + use_dB_normalization=use_dB_normalization, + random_seed=random_seed, + keep_transcription_text=keep_transcription_text) + + if dist: + batch_sampler = SortagradDistributedBatchSampler( + dataset, + batch_size, + num_replicas=None, + rank=None, + shuffle=is_training, + drop_last=is_training, + sortagrad=is_training, + shuffle_method=shuffle_method) + else: + batch_sampler = SortagradBatchSampler( + dataset, + shuffle=is_training, + batch_size=batch_size, + drop_last=is_training, + sortagrad=is_training, + shuffle_method=shuffle_method) + + def padding_batch(batch, padding_to=-1, flatten=False, is_training=True): + """ + Padding audio features with zeros to make them have the same shape (or + a user-defined shape) within one batch.
+ + If ``padding_to`` is -1, the maximum shape in the batch will be used + as the target shape for padding. Otherwise, `padding_to` will be the + target shape (only refers to the second axis). + + If `flatten` is True, features will be flattened to a 1-D array. + """ + new_batch = [] + # get target shape + max_length = max([audio.shape[1] for audio, text in batch]) + if padding_to != -1: + if padding_to < max_length: + raise ValueError("If padding_to is not -1, it should be larger " + "than any instance's shape in the batch") + max_length = padding_to + max_text_length = max([len(text) for audio, text in batch]) + # padding + padded_audios = [] + audio_lens = [] + texts, text_lens = [], [] + for audio, text in batch: + padded_audio = np.zeros([audio.shape[0], max_length]) + padded_audio[:, :audio.shape[1]] = audio + if flatten: + padded_audio = padded_audio.flatten() + padded_audios.append(padded_audio) + audio_lens.append(audio.shape[1]) + + padded_text = np.zeros([max_text_length]) + if is_training: + padded_text[:len(text)] = text #ids + else: + padded_text[:len(text)] = [ord(t) for t in text] # string + texts.append(padded_text) + text_lens.append(len(text)) + + padded_audios = np.array(padded_audios).astype('float32') + audio_lens = np.array(audio_lens).astype('int64') + texts = np.array(texts).astype('int32') + text_lens = np.array(text_lens).astype('int64') + return padded_audios, texts, audio_lens, text_lens + + loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=partial(padding_batch, is_training=is_training), + num_workers=num_workers) + return loader diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py new file mode 100644 index 000000000..9e50e170e --- /dev/null +++ b/deepspeech/io/collator.py @@ -0,0 +1,72 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import numpy as np +from collections import namedtuple + +logger = logging.getLogger(__name__) + +__all__ = [ + "SpeechCollator", +] + + +class SpeechCollator(): + def __init__(self, padding_to=-1, is_training=True): + """ + Padding audio features with zeros to make them have the same shape (or + a user-defined shape) within one batch. + + If ``padding_to`` is -1, the maximum shape in the batch will be used + as the target shape for padding. Otherwise, `padding_to` will be the + target shape (only refers to the second axis).
+ """ + self._padding_to = padding_to + self._is_training = is_training + + def __call__(self, batch): + new_batch = [] + # get target shape + max_length = max([audio.shape[1] for audio, _ in batch]) + if self._padding_to != -1: + if self._padding_to < max_length: + raise ValueError("If padding_to is not -1, it should be larger " + "than any instance's shape in the batch") + max_length = self._padding_to + max_text_length = max([len(text) for _, text in batch]) + # padding + padded_audios = [] + audio_lens = [] + texts, text_lens = [], [] + for audio, text in batch: + # audio + padded_audio = np.zeros([audio.shape[0], max_length]) + padded_audio[:, :audio.shape[1]] = audio + padded_audios.append(padded_audio) + audio_lens.append(audio.shape[1]) + # text + padded_text = np.zeros([max_text_length]) + if self._is_training: + padded_text[:len(text)] = text #ids + else: + padded_text[:len(text)] = [ord(t) for t in text] # string + texts.append(padded_text) + text_lens.append(len(text)) + + padded_audios = np.array(padded_audios).astype('float32') + audio_lens = np.array(audio_lens).astype('int64') + texts = np.array(texts).astype('int32') + text_lens = np.array(text_lens).astype('int64') + return padded_audios, texts, audio_lens, text_lens diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py new file mode 100644 index 000000000..3762f0d93 --- /dev/null +++ b/deepspeech/io/dataset.py @@ -0,0 +1,186 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import random +import tarfile +import logging +import numpy as np +from collections import namedtuple +from functools import partial + +from paddle.io import Dataset + +from deepspeech.frontend.utility import read_manifest +from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline +from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer +from deepspeech.frontend.speech import SpeechSegment +from deepspeech.frontend.normalizer import FeatureNormalizer + +logger = logging.getLogger(__name__) + +__all__ = [ + "ManifestDataset", +] + + +class ManifestDataset(Dataset): + def __init__(self, + manifest_path, + vocab_filepath, + mean_std_filepath, + augmentation_config='{}', + max_duration=float('inf'), + min_duration=0.0, + stride_ms=10.0, + window_ms=20.0, + n_fft=None, + max_freq=None, + target_sample_rate=16000, + specgram_type='linear', + use_dB_normalization=True, + target_dB=-20, + random_seed=0, + keep_transcription_text=False): + super().__init__() + + self._max_duration = max_duration + self._min_duration = min_duration + self._normalizer = FeatureNormalizer(mean_std_filepath) + self._augmentation_pipeline = AugmentationPipeline( + augmentation_config=augmentation_config, random_seed=random_seed) + self._speech_featurizer = SpeechFeaturizer( + vocab_filepath=vocab_filepath, + specgram_type=specgram_type, + stride_ms=stride_ms, + window_ms=window_ms, + n_fft=n_fft, + max_freq=max_freq, + target_sample_rate=target_sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB) + self._rng = random.Random(random_seed) + self._keep_transcription_text = keep_transcription_text + # for caching tar files info + self._local_data = namedtuple('local_data', ['tar2info', 'tar2object']) + self._local_data.tar2info = {} + self._local_data.tar2object = {} + + # read manifest + self._manifest = read_manifest( + manifest_path=manifest_path, + max_duration=self._max_duration, + min_duration=self._min_duration) + self._manifest.sort(key=lambda x: x["duration"]) + + @property + def manifest(self): + return self._manifest + + @property + def vocab_size(self): + """Return the vocabulary size. + + :return: Vocabulary size. + :rtype: int + """ + return self._speech_featurizer.vocab_size + + @property + def vocab_list(self): + """Return the vocabulary in list. + + :return: Vocabulary in list. + :rtype: list + """ + return self._speech_featurizer.vocab_list + + @property + def feature_size(self): + return self._speech_featurizer.feature_size + + def _parse_tar(self, file): + """Parse a tar file to get a tarfile object + and a map containing tarinfoes + """ + result = {} + f = tarfile.open(file) + for tarinfo in f.getmembers(): + result[tarinfo.name] = tarinfo + return f, result + + def _subfile_from_tar(self, file): + """Get subfile object from tar. + + It will return a subfile object from tar file + and cached tar file info for next reading request. 
+ """ + tarpath, filename = file.split(':', 1)[1].split('#', 1) + if 'tar2info' not in self._local_data.__dict__: + self._local_data.tar2info = {} + if 'tar2object' not in self._local_data.__dict__: + self._local_data.tar2object = {} + if tarpath not in self._local_data.tar2info: + object, infoes = self._parse_tar(tarpath) + self._local_data.tar2info[tarpath] = infoes + self._local_data.tar2object[tarpath] = object + return self._local_data.tar2object[tarpath].extractfile( + self._local_data.tar2info[tarpath][filename]) + + def process_utterance(self, audio_file, transcript): + """Load, augment, featurize and normalize for speech data. + + :param audio_file: Filepath or file object of audio file. + :type audio_file: str | file + :param transcript: Transcription text. + :type transcript: str + :return: Tuple of audio feature tensor and data of transcription part, + where transcription part could be token ids or text. + :rtype: tuple of (2darray, list) + """ + if isinstance(audio_file, str) and audio_file.startswith('tar:'): + speech_segment = SpeechSegment.from_file( + self._subfile_from_tar(audio_file), transcript) + else: + speech_segment = SpeechSegment.from_file(audio_file, transcript) + self._augmentation_pipeline.transform_audio(speech_segment) + specgram, transcript_part = self._speech_featurizer.featurize( + speech_segment, self._keep_transcription_text) + specgram = self._normalizer.apply(specgram) + return specgram, transcript_part + + def _instance_reader_creator(self, manifest): + """ + Instance reader creator. Create a callable function to produce + instances of data. + + Instance: a tuple of ndarray of audio spectrogram and a list of + token indices for transcript. + """ + + def reader(): + for instance in manifest: + inst = self.process_utterance(instance["audio_filepath"], + instance["text"]) + yield inst + + return reader + + def __len__(self): + return len(self._manifest) + + def __getitem__(self, idx): + instance = self._manifest[idx] + return self.process_utterance(instance["audio_filepath"], + instance["text"]) diff --git a/deepspeech/io/sampler.py b/deepspeech/io/sampler.py new file mode 100644 index 000000000..a08a3983d --- /dev/null +++ b/deepspeech/io/sampler.py @@ -0,0 +1,266 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import random +import tarfile +import logging +import numpy as np +from collections import namedtuple +from functools import partial + +import paddle +from paddle.io import BatchSampler +from paddle.io import DistributedBatchSampler +from paddle import distributed as dist + +logger = logging.getLogger(__name__) + +__all__ = [ + "SortagradDistributedBatchSampler", + "SortagradBatchSampler", +] + + +class SortagradDistributedBatchSampler(DistributedBatchSampler): + def __init__(self, + dataset, + batch_size, + num_replicas=None, + rank=None, + shuffle=False, + drop_last=False, + sortagrad=False, + shuffle_method="batch_shuffle"): + super().__init__(dataset, batch_size, num_replicas, rank, shuffle, + drop_last) + self._sortagrad = sortagrad + self._shuffle_method = shuffle_method + + def _batch_shuffle(self, indices, batch_size, clipped=False): + """Put similarly-sized instances into minibatches for better efficiency + and make a batch-wise shuffle. + + 1. Sort the audio clips by duration. + 2. Generate a random number `k`, k in [0, batch_size). + 3. Randomly shift `k` instances in order to create different batches + for different epochs. Create minibatches. + 4. Shuffle the minibatches. + + :param indices: indexes. List of int. + :type indices: list + :param batch_size: Batch size. This size is also used for generate + a random number for batch shuffle. + :type batch_size: int + :param clipped: Whether to clip the heading (small shift) and trailing + (incomplete batch) instances. + :type clipped: bool + :return: Batch shuffled mainifest. + :rtype: list + """ + rng = np.random.RandomState(self.epoch) + shift_len = rng.randint(0, batch_size - 1) + batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) + rng.shuffle(batch_indices) + batch_indices = [item for batch in batch_indices for item in batch] + assert (clipped == False) + if not clipped: + res_len = len(indices) - shift_len - len(batch_indices) + # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:]) + if res_len != 0: + batch_indices.extend(indices[-res_len:]) + batch_indices.extend(indices[0:shift_len]) + assert len(indices) == len( + batch_indices + ), f"_batch_shuffle: {len(indices)} : {len(batch_indices)} : {res_len} - {shift_len}" + return batch_indices + + def __iter__(self): + num_samples = len(self.dataset) + indices = np.arange(num_samples).tolist() + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # sort (by duration) or batch-wise shuffle the manifest + if self.shuffle: + if self.epoch == 0 and self._sortagrad: + logger.info( + f'rank: {dist.get_rank()} dataset sortagrad! epoch {self.epoch}' + ) + else: + logger.info( + f'rank: {dist.get_rank()} dataset shuffle! epoch {self.epoch}' + ) + if self._shuffle_method == "batch_shuffle": + indices = self._batch_shuffle( + indices, self.batch_size, clipped=False) + elif self._shuffle_method == "instance_shuffle": + np.random.RandomState(self.epoch).shuffle(indices) + else: + raise ValueError("Unknown shuffle method %s." 
% + self._shuffle_method) + assert len( + indices + ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}" + + # subsample + def _get_indices_by_batch_size(indices): + subsampled_indices = [] + last_batch_size = self.total_size % (self.batch_size * self.nranks) + assert last_batch_size % self.nranks == 0 + last_local_batch_size = last_batch_size // self.nranks + + for i in range(self.local_rank * self.batch_size, + len(indices) - last_batch_size, + self.batch_size * self.nranks): + subsampled_indices.extend(indices[i:i + self.batch_size]) + + indices = indices[len(indices) - last_batch_size:] + subsampled_indices.extend( + indices[self.local_rank * last_local_batch_size:( + self.local_rank + 1) * last_local_batch_size]) + return subsampled_indices + + if self.nranks > 1: + indices = _get_indices_by_batch_size(indices) + + assert len(indices) == self.num_samples + _sample_iter = iter(indices) + + batch_indices = [] + for idx in _sample_iter: + batch_indices.append(idx) + if len(batch_indices) == self.batch_size: + logger.info( + f"rank: {dist.get_rank()} batch index: {batch_indices} ") + yield batch_indices + batch_indices = [] + if not self.drop_last and len(batch_indices) > 0: + yield batch_indices + + def __len__(self): + num_samples = self.num_samples + num_samples += int(not self.drop_last) * (self.batch_size - 1) + return num_samples // self.batch_size + + +class SortagradBatchSampler(BatchSampler): + def __init__(self, + dataset, + batch_size, + shuffle=False, + drop_last=False, + sortagrad=False, + shuffle_method="batch_shuffle"): + self.dataset = dataset + + assert isinstance(batch_size, int) and batch_size > 0, \ + "batch_size should be a positive integer" + self.batch_size = batch_size + assert isinstance(shuffle, bool), \ + "shuffle should be a boolean value" + self.shuffle = shuffle + assert isinstance(drop_last, bool), \ + "drop_last should be a boolean number" + + self.drop_last = drop_last + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0)) + self.total_size = self.num_samples + self._sortagrad = sortagrad + self._shuffle_method = shuffle_method + + def _batch_shuffle(self, indices, batch_size, clipped=False): + """Put similarly-sized instances into minibatches for better efficiency + and make a batch-wise shuffle. + + 1. Sort the audio clips by duration. + 2. Generate a random number `k`, k in [0, batch_size). + 3. Randomly shift `k` instances in order to create different batches + for different epochs. Create minibatches. + 4. Shuffle the minibatches. + + :param indices: indexes. List of int. + :type indices: list + :param batch_size: Batch size. This size is also used for generate + a random number for batch shuffle. + :type batch_size: int + :param clipped: Whether to clip the heading (small shift) and trailing + (incomplete batch) instances. + :type clipped: bool + :return: Batch shuffled mainifest. 
+ :rtype: list + """ + rng = np.random.RandomState(self.epoch) + # must shift at leat by one + shift_len = rng.randint(0, batch_size - 1) + batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) + rng.shuffle(batch_indices) + batch_indices = [item for batch in batch_indices for item in batch] + assert (clipped == False) + if not clipped: + res_len = len(indices) - shift_len - len(batch_indices) + # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:]) + if res_len != 0: + batch_indices.extend(indices[-res_len:]) + batch_indices.extend(indices[0:shift_len]) + assert len(indices) == len( + batch_indices + ), f"_batch_shuffle: {len(indices)} : {len(batch_indices)} : {res_len} - {shift_len}" + return batch_indices + + def __iter__(self): + num_samples = len(self.dataset) + indices = np.arange(num_samples).tolist() + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # sort (by duration) or batch-wise shuffle the manifest + if self.shuffle: + if self.epoch == 0 and self._sortagrad: + logger.info(f'dataset sortagrad! epoch {self.epoch}') + else: + logger.info(f'dataset shuffle! epoch {self.epoch}') + if self._shuffle_method == "batch_shuffle": + indices = self._batch_shuffle( + indices, self.batch_size, clipped=False) + elif self._shuffle_method == "instance_shuffle": + np.random.RandomState(self.epoch).shuffle(indices) + else: + raise ValueError("Unknown shuffle method %s." % + self._shuffle_method) + assert len( + indices + ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}" + + assert len(indices) == self.num_samples + _sample_iter = iter(indices) + + batch_indices = [] + for idx in _sample_iter: + batch_indices.append(idx) + if len(batch_indices) == self.batch_size: + logger.info( + f"rank: {dist.get_rank()} batch index: {batch_indices} ") + yield batch_indices + batch_indices = [] + if not self.drop_last and len(batch_indices) > 0: + yield batch_indices + + self.epoch += 1 + + def __len__(self): + num_samples = self.num_samples + num_samples += int(not self.drop_last) * (self.batch_size - 1) + return num_samples // self.batch_size diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py new file mode 100644 index 000000000..4a746195d --- /dev/null +++ b/deepspeech/models/deepspeech2.py @@ -0,0 +1,304 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
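Reviewer note: both samplers above share the same `_batch_shuffle` trick. A standalone re-implementation of just that step, to show it yields a permutation (no index lost or duplicated):

```python
import numpy as np

def batch_shuffle(indices, batch_size, epoch):
    # epoch-seeded RNG: every rank shuffles identically within an epoch
    rng = np.random.RandomState(epoch)
    shift_len = rng.randint(0, batch_size - 1)
    # chunk the shifted list into full minibatches, shuffle the chunks
    batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
    rng.shuffle(batches)
    shuffled = [idx for batch in batches for idx in batch]
    # re-attach the trailing incomplete batch and the shifted head
    res_len = len(indices) - shift_len - len(shuffled)
    if res_len != 0:
        shuffled.extend(indices[-res_len:])
    shuffled.extend(indices[:shift_len])
    return shuffled

out = batch_shuffle(list(range(10)), batch_size=4, epoch=0)
assert sorted(out) == list(range(10))  # a permutation, as claimed
```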
+ +import math +import collections +import numpy as np +import logging +from typing import Optional +from yacs.config import CfgNode + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I + +from deepspeech.modules.conv import ConvStack +from deepspeech.modules.rnn import RNNStack +from deepspeech.modules.mask import sequence_mask +from deepspeech.modules.activation import brelu +from deepspeech.utils import checkpoint +from deepspeech.decoders.swig_wrapper import Scorer +from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder +from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch + +logger = logging.getLogger(__name__) + +__all__ = ['DeepSpeech2Model'] + + +class DeepSpeech2Model(nn.Layer): + """The DeepSpeech2 network structure. + + :param audio_data: Audio spectrogram data layer. + :type audio_data: Variable + :param text_data: Transcription text data layer. + :type text_data: Variable + :param audio_len: Valid sequence length data layer. + :type audio_len: Variable + :param masks: Masks data layer to reset padding. + :type masks: Variable + :param dict_size: Dictionary size for tokenized transcription. + :type dict_size: int + :param num_conv_layers: Number of stacking convolution layers. + :type num_conv_layers: int + :param num_rnn_layers: Number of stacking RNN layers. + :type num_rnn_layers: int + :param rnn_size: RNN layer size (dimension of RNN cells). + :type rnn_size: int + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :param share_rnn_weights: Whether to share input-hidden weights between + forward and backward direction RNNs. + It is only available when use_gru=False. + :type share_rnn_weights: bool + :return: A tuple of an output unnormalized log probability layer ( + before softmax) and a ctc cost layer. + :rtype: tuple of LayerOutput + """ + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + num_conv_layers=2, #Number of stacking convolution layers. + num_rnn_layers=3, #Number of stacking RNN layers. + rnn_layer_size=1024, #RNN layer size (number of RNN cells). + use_gru=True, #Use gru if set True. Use simple rnn if set False. + share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
+ )) + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True): + super().__init__() + self.feat_size = feat_size # 161 for linear + self.dict_size = dict_size + + self.conv = ConvStack(feat_size, num_conv_layers) + + i_size = self.conv.output_height # H after conv stack + self.rnn = RNNStack( + i_size=i_size, + h_size=rnn_size, + num_stacks=num_rnn_layers, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights) + self.fc = nn.Linear(rnn_size * 2, dict_size + 1) + + self.logger = logging.getLogger(__name__) + self._ext_scorer = None + + def infer(self, audio, audio_len): + # [B, D, T] -> [B, C=1, D, T] + audio = audio.unsqueeze(1) + + # convolution group + x, audio_len = self.conv(audio, audio_len) + #print('conv out', x.shape) + + # convert data from convolution feature map to sequence of vectors + B, C, D, T = paddle.shape(x) + x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] + x = x.reshape([B, T, C * D]) #[B, T, C*D] + #print('rnn input', x.shape) + + # remove padding part + x, audio_len = self.rnn(x, audio_len) #[B, T, D] + #print('rnn output', x.shape) + + logits = self.fc(x) #[B, T, V + 1] + + # the ctc decoder needs probs, not log_probs + probs = F.softmax(logits) + + return logits, probs, audio_len + + def forward(self, audio, text, audio_len, text_len): + """ + audio: shape [B, D, T] + text: shape [B, T] + audio_len: shape [B] + text_len: shape [B] + """ + return self.infer(audio, audio_len) + + @paddle.no_grad() + def predict(self, audio, audio_len): + """ Model infer """ + return self.infer(audio, audio_len) + + def _decode_batch_greedy(self, probs_split, vocab_list): + """Decode by best path for a batch of probs matrix input. + :param probs_split: List of 2-D probability matrices, each consisting + of prob vectors for one speech utterance. + :type probs_split: list + :param vocab_list: List of tokens in the vocabulary, for decoding. + :type vocab_list: list + :return: List of transcription texts. + :rtype: List of str + """ + results = [] + for i, probs in enumerate(probs_split): + output_transcription = ctc_greedy_decoder( + probs_seq=probs, vocabulary=vocab_list) + results.append(output_transcription) + return results + + def _init_ext_scorer(self, beam_alpha, beam_beta, language_model_path, + vocab_list): + """Initialize the external scorer. + :param beam_alpha: Parameter associated with language model. + :type beam_alpha: float + :param beam_beta: Parameter associated with word count. + :type beam_beta: float + :param language_model_path: Filepath for language model. If it is + empty, the external scorer will be set to + None, and the decoding method will be pure + beam search without scorer. + :type language_model_path: str|None + :param vocab_list: List of tokens in the vocabulary, for decoding.
+ :type vocab_list: list + """ + # init once + if self._ext_scorer != None: + return + + if language_model_path != '': + self.logger.info("begin to initialize the external scorer " + "for decoding") + self._ext_scorer = Scorer(beam_alpha, beam_beta, + language_model_path, vocab_list) + lm_char_based = self._ext_scorer.is_character_based() + lm_max_order = self._ext_scorer.get_max_order() + lm_dict_size = self._ext_scorer.get_dict_size() + self.logger.info("language model: " + "is_character_based = %d," % lm_char_based + + " max_order = %d," % lm_max_order + + " dict_size = %d" % lm_dict_size) + self.logger.info("end initializing scorer") + else: + self._ext_scorer = None + self.logger.info("no language model provided, " + "decoding by pure beam search without scorer.") + + def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta, + beam_size, cutoff_prob, cutoff_top_n, + vocab_list, num_processes): + """Decode by beam search for a batch of probs matrix input. + :param probs_split: List of 2-D probability matrix, and each consists + of prob vectors for one speech utterancce. + :param probs_split: List of matrix + :param beam_alpha: Parameter associated with language model. + :type beam_alpha: float + :param beam_beta: Parameter associated with word count. + :type beam_beta: float + :param beam_size: Width for Beam search. + :type beam_size: int + :param cutoff_prob: Cutoff probability in pruning, + default 1.0, no pruning. + :type cutoff_prob: float + :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n + characters with highest probs in vocabulary will be + used in beam search, default 40. + :type cutoff_top_n: int + :param vocab_list: List of tokens in the vocabulary, for decoding. + :type vocab_list: list + :param num_processes: Number of processes (CPU) for decoder. + :type num_processes: int + :return: List of transcription texts. 
+ :rtype: List of str + """ + if self._ext_scorer is not None: + self._ext_scorer.reset_params(beam_alpha, beam_beta) + + # beam search decode + num_processes = min(num_processes, len(probs_split)) + beam_search_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, + vocabulary=vocab_list, + beam_size=beam_size, + num_processes=num_processes, + ext_scoring_func=self._ext_scorer, + cutoff_prob=cutoff_prob, + cutoff_top_n=cutoff_top_n) + + results = [result[0][1] for result in beam_search_results] + return results + + def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, + decoding_method): + if decoding_method == "ctc_beam_search": + self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, + vocab_list) + + def decode_probs(self, probs, logits_lens, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, + cutoff_prob, cutoff_top_n, num_processes): + """ probs: activations after softmax + logits_lens: valid audio output lengths + """ + probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)] + if decoding_method == "ctc_greedy": + result_transcripts = self._decode_batch_greedy( + probs_split=probs_split, vocab_list=vocab_list) + elif decoding_method == "ctc_beam_search": + result_transcripts = self._decode_batch_beam_search( + probs_split=probs_split, + beam_alpha=beam_alpha, + beam_beta=beam_beta, + beam_size=beam_size, + cutoff_prob=cutoff_prob, + cutoff_top_n=cutoff_top_n, + vocab_list=vocab_list, + num_processes=num_processes) + else: + raise ValueError(f"Not supported: {decoding_method}") + return result_transcripts + + @paddle.no_grad() + def decode(self, audio, audio_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes): + _, probs, logits_lens = self.predict(audio, audio_len) + return self.decode_probs(probs.numpy(), logits_lens, vocab_list, + decoding_method, lang_model_path, beam_alpha, + beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes) + + def from_pretrained(self, checkpoint_path): + """Build a model from a pretrained checkpoint. + Parameters + ---------- + checkpoint_path: Path or str + The path of pretrained model checkpoint, without extension name. + + Returns + ------- + DeepSpeech2Model + The model with parameters loaded from the pretrained checkpoint. + """ + checkpoint.load_parameters(self, checkpoint_path=checkpoint_path) + return self diff --git a/deepspeech/models/network.py b/deepspeech/models/network.py deleted file mode 100644 index a3ea771dc..000000000 --- a/deepspeech/models/network.py +++ /dev/null @@ -1,754 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
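Reviewer note: a hedged end-to-end sketch of the inference API that `DeepSpeech2Model` above exposes; the checkpoint path, feature size, and vocabulary are assumptions for illustration, and the call signatures mirror the methods in this diff:

```python
import paddle
from deepspeech.models.deepspeech2 import DeepSpeech2Model

vocab = [chr(c) for c in range(ord('a'), ord('z') + 1)] + [' ', "'"]
model = DeepSpeech2Model(
    feat_size=161, dict_size=len(vocab), num_conv_layers=2,
    num_rnn_layers=3, rnn_size=1024, use_gru=False)
model.from_pretrained('checkpoints/step_final')  # hypothetical checkpoint
model.eval()

audio = paddle.randn([2, 161, 120])                     # [B, D, T]
audio_len = paddle.to_tensor([120, 90], dtype='int64')  # valid frames
texts = model.decode(
    audio, audio_len, vocab_list=vocab,
    decoding_method='ctc_greedy', lang_model_path='',
    beam_alpha=1.9, beam_beta=0.3, beam_size=500,
    cutoff_prob=1.0, cutoff_top_n=40, num_processes=1)
```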
- -import math -import collections -import numpy as np -import logging - -import paddle -from paddle import nn -from paddle.nn import functional as F -from paddle.nn import initializer as I - -from deepspeech.utils import checkpoint -from deepspeech.decoders.swig_wrapper import Scorer -from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder -from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch - -logger = logging.getLogger(__name__) - -__all__ = ['DeepSpeech2', 'DeepSpeech2Loss'] - - -def brelu(x, t_min=0.0, t_max=24.0, name=None): - t_min = paddle.to_tensor(t_min) - t_max = paddle.to_tensor(t_max) - return x.maximum(t_min).minimum(t_max) - - -def sequence_mask(x_len, max_len=None, dtype='float32'): - max_len = max_len or x_len.max() - x_len = paddle.unsqueeze(x_len, -1) - row_vector = paddle.arange(max_len) - #mask = row_vector < x_len - mask = row_vector > x_len # a bug, broadcast 的时候出错了 - mask = paddle.cast(mask, dtype) - return mask - - -class ConvBn(nn.Layer): - """Convolution layer with batch normalization. - - :param kernel_size: The x dimension of a filter kernel. Or input a tuple for - two image dimension. - :type kernel_size: int|tuple|list - :param num_channels_in: Number of input channels. - :type num_channels_in: int - :param num_channels_out: Number of output channels. - :type num_channels_out: int - :param stride: The x dimension of the stride. Or input a tuple for two - image dimension. - :type stride: int|tuple|list - :param padding: The x dimension of the padding. Or input a tuple for two - image dimension. - :type padding: int|tuple|list - :param act: Activation type, relu|brelu - :type act: string - :param masks: Masks data layer to reset padding. - :type masks: Variable - :param name: Name of the layer. - :param name: string - :return: Batch norm layer after convolution layer. - :rtype: Variable - - """ - - def __init__(self, num_channels_in, num_channels_out, kernel_size, stride, - padding, act): - - super().__init__() - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - - self.conv = nn.Conv2D( - num_channels_in, - num_channels_out, - kernel_size=kernel_size, - stride=stride, - padding=padding, - weight_attr=None, - bias_attr=False, - data_format='NCHW') - - self.bn = nn.BatchNorm2D( - num_channels_out, - weight_attr=None, - bias_attr=None, - data_format='NCHW') - self.act = F.relu if act == 'relu' else brelu - - def forward(self, x, x_len): - """ - x(Tensor): audio, shape [B, C, D, T] - """ - x = self.conv(x) - x = self.bn(x) - x = self.act(x) - - x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1] - ) // self.stride[1] + 1 - - # reset padding part to 0 - masks = sequence_mask(x_len) #[B, T] - masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] - x = x.multiply(masks) - - return x, x_len - - -class ConvStack(nn.Layer): - """Convolution group with stacked convolution layers. - - :param feat_size: audio feature dim. - :type feat_size: int - :param num_stacks: Number of stacked convolution layers. 
- :type num_stacks: int - """ - - def __init__(self, feat_size, num_stacks): - super().__init__() - self.feat_size = feat_size # D - self.num_stacks = num_stacks - - self.conv_in = ConvBn( - num_channels_in=1, - num_channels_out=32, - kernel_size=(41, 11), #[D, T] - stride=(2, 3), - padding=(20, 5), - act='brelu') - - out_channel = 32 - self.conv_stack = nn.LayerList([ - ConvBn( - num_channels_in=32, - num_channels_out=out_channel, - kernel_size=(21, 11), - stride=(2, 1), - padding=(10, 5), - act='brelu') for i in range(num_stacks - 1) - ]) - - # conv output feat_dim - output_height = (feat_size - 1) // 2 + 1 - for i in range(self.num_stacks - 1): - output_height = (output_height - 1) // 2 + 1 - self.output_height = out_channel * output_height - - def forward(self, x, x_len): - """ - x: shape [B, C, D, T] - x_len : shape [B] - """ - x, x_len = self.conv_in(x, x_len) - for i, conv in enumerate(self.conv_stack): - x, x_len = conv(x, x_len) - return x, x_len - - -class RNNCell(nn.RNNCellBase): - r""" - Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it - computes the outputs and updates states. - The formula used is as follows: - .. math:: - h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) - y_{t} & = h_{t} - - where :math:`act` is for :attr:`activation`. - """ - - def __init__(self, - hidden_size, - activation="tanh", - weight_ih_attr=None, - weight_hh_attr=None, - bias_ih_attr=None, - bias_hh_attr=None, - name=None): - super().__init__() - std = 1.0 / math.sqrt(hidden_size) - self.weight_hh = self.create_parameter( - (hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std)) - # self.bias_ih = self.create_parameter( - # (hidden_size, ), - # bias_ih_attr, - # is_bias=True, - # default_initializer=I.Uniform(-std, std)) - self.bias_ih = None - self.bias_hh = self.create_parameter( - (hidden_size, ), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - - self.hidden_size = hidden_size - if activation not in ["tanh", "relu", "brelu"]: - raise ValueError( - "activation for SimpleRNNCell should be tanh or relu, " - "but get {}".format(activation)) - self.activation = activation - self._activation_fn = paddle.tanh \ - if activation == "tanh" \ - else F.relu - if activation == 'brelu': - self._activation_fn = brelu - - def forward(self, inputs, states=None): - if states is None: - states = self.get_initial_states(inputs, self.state_shape) - pre_h = states - i2h = inputs - if self.bias_ih is not None: - i2h += self.bias_ih - h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) - if self.bias_hh is not None: - h2h += self.bias_hh - h = self._activation_fn(i2h + h2h) - return h, h - - @property - def state_shape(self): - return (self.hidden_size, ) - - -class GRUCellShare(nn.RNNCellBase): - r""" - Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, - it computes the outputs and updates states. - The formula for GRU used is as follows: - .. math:: - r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr}) - z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz}) - \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) - h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} - y_{t} & = h_{t} - - where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise - multiplication operator. 
- """ - - def __init__(self, - input_size, - hidden_size, - weight_ih_attr=None, - weight_hh_attr=None, - bias_ih_attr=None, - bias_hh_attr=None, - name=None): - super().__init__() - std = 1.0 / math.sqrt(hidden_size) - self.weight_hh = self.create_parameter( - (3 * hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std)) - # self.bias_ih = self.create_parameter( - # (3 * hidden_size, ), - # bias_ih_attr, - # is_bias=True, - # default_initializer=I.Uniform(-std, std)) - self.bias_ih = None - self.bias_hh = self.create_parameter( - (3 * hidden_size, ), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - - self.hidden_size = hidden_size - self.input_size = input_size - self._gate_activation = F.sigmoid - self._activation = paddle.tanh - #self._activation = F.relu - - def forward(self, inputs, states=None): - if states is None: - states = self.get_initial_states(inputs, self.state_shape) - - pre_hidden = states - x_gates = inputs - if self.bias_ih is not None: - x_gates = x_gates + self.bias_ih - h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True) - if self.bias_hh is not None: - h_gates = h_gates + self.bias_hh - - x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1) - h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1) - - r = self._gate_activation(x_r + h_r) - z = self._gate_activation(x_z + h_z) - c = self._activation(x_c + r * h_c) # apply reset gate after mm - h = (pre_hidden - c) * z + c - # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru - - return h, h - - @property - def state_shape(self): - r""" - The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch - size would be automatically inserted into shape). The shape corresponds - to the shape of :math:`h_{t-1}`. - """ - return (self.hidden_size, ) - - -class BiRNNWithBN(nn.Layer): - """Bidirectonal simple rnn layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - - :param name: Name of the layer parameters. - :type name: string - :param size: Dimension of RNN cells. - :type size: int - :param share_weights: Whether to share input-hidden weights between - forward and backward directional RNNs. - :type share_weights: bool - :return: Bidirectional simple rnn layer. - :rtype: Variable - """ - - def __init__(self, i_size, h_size, share_weights): - super().__init__() - self.share_weights = share_weights - if self.share_weights: - #input-hidden weights shared between bi-directional rnn. 
- self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) - # batch norm is only performed on input-state projection - self.fw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - self.bw_fc = self.fw_fc - self.bw_bn = self.fw_bn - else: - self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) - self.fw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False) - self.bw_bn = nn.BatchNorm1D( - h_size, bias_attr=None, data_format='NLC') - - self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu') - self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu') - self.fw_rnn = nn.RNN( - self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] - self.bw_rnn = nn.RNN( - self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] - - def forward(self, x, x_len): - # x, shape [B, T, D] - fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_fc(x)) - fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) - bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) - x = paddle.concat([fw_x, bw_x], axis=-1) - return x, x_len - - -class BiGRUWithBN(nn.Layer): - """Bidirectonal gru layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - - :param name: Name of the layer. - :type name: string - :param input: Input layer. - :type input: Variable - :param size: Dimension of GRU cells. - :type size: int - :param act: Activation type. - :type act: string - :return: Bidirectional GRU layer. - :rtype: Variable - """ - - def __init__(self, i_size, h_size, act): - super().__init__() - hidden_size = h_size * 3 - - self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) - self.fw_bn = nn.BatchNorm1D( - hidden_size, bias_attr=None, data_format='NLC') - self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) - self.bw_bn = nn.BatchNorm1D( - hidden_size, bias_attr=None, data_format='NLC') - - self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size) - self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size) - self.fw_rnn = nn.RNN( - self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] - self.bw_rnn = nn.RNN( - self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] - - def forward(self, x, x_len): - # x, shape [B, T, D] - fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_fc(x)) - fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) - bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) - x = paddle.concat([fw_x, bw_x], axis=-1) - return x, x_len - - -class RNNStack(nn.Layer): - """RNN group with stacked bidirectional simple RNN or GRU layers. - - :param input: Input layer. - :type input: Variable - :param size: Dimension of RNN cells in each layer. - :type size: int - :param num_stacks: Number of stacked rnn layers. - :type num_stacks: int - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward directional RNNs. - It is only available when use_gru=False. - :type share_weights: bool - :return: Output layer of the RNN group. 
- :rtype: Variable - """ - - def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights): - super().__init__() - self.rnn_stacks = nn.LayerList() - for i in range(num_stacks): - if use_gru: - #default:GRU using tanh - self.rnn_stacks.append( - BiGRUWithBN(i_size=i_size, h_size=h_size, act="relu")) - else: - self.rnn_stacks.append( - BiRNNWithBN( - i_size=i_size, - h_size=h_size, - share_weights=share_rnn_weights)) - i_size = h_size * 2 - - def forward(self, x, x_len): - """ - x: shape [B, T, D] - x_len: shpae [B] - """ - for i, rnn in enumerate(self.rnn_stacks): - x, x_len = rnn(x, x_len) - masks = sequence_mask(x_len) #[B, T] - masks = masks.unsqueeze(-1) # [B, T, 1] - x = x.multiply(masks) - return x, x_len - - -class DeepSpeech2(nn.Layer): - """The DeepSpeech2 network structure. - - :param audio_data: Audio spectrogram data layer. - :type audio_data: Variable - :param text_data: Transcription text data layer. - :type text_data: Variable - :param audio_len: Valid sequence length data layer. - :type audio_len: Variable - :param masks: Masks data layer to reset padding. - :type masks: Variable - :param dict_size: Dictionary size for tokenized transcription. - :type dict_size: int - :param num_conv_layers: Number of stacking convolution layers. - :type num_conv_layers: int - :param num_rnn_layers: Number of stacking RNN layers. - :type num_rnn_layers: int - :param rnn_size: RNN layer size (dimension of RNN cells). - :type rnn_size: int - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward direction RNNs. - It is only available when use_gru=False. - :type share_weights: bool - :return: A tuple of an output unnormalized log probability layer ( - before softmax) and a ctc cost layer. 
- :rtype: tuple of LayerOutput - """ - - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True): - super().__init__() - self.feat_size = feat_size # 161 for linear - self.dict_size = dict_size - - self.conv = ConvStack(feat_size, num_conv_layers) - - i_size = self.conv.output_height # H after conv stack - self.rnn = RNNStack( - i_size=i_size, - h_size=rnn_size, - num_stacks=num_rnn_layers, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights) - self.fc = nn.Linear(rnn_size * 2, dict_size + 1) - - self.logger = logging.getLogger(__name__) - self._ext_scorer = None - - def infer(self, audio, audio_len): - # [B, D, T] -> [B, C=1, D, T] - audio = audio.unsqueeze(1) - - # convolution group - x, audio_len = self.conv(audio, audio_len) - #print('conv out', x.shape) - - # convert data from convolution feature map to sequence of vectors - B, C, D, T = paddle.shape(x) - x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] - x = x.reshape([B, T, C * D]) #[B, T, C*D] - #print('rnn input', x.shape) - - # remove padding part - x, audio_len = self.rnn(x, audio_len) #[B, T, D] - #print('rnn output', x.shape) - - logits = self.fc(x) #[B, T, V + 1] - - #ctcdecoder need probs, not log_probs - probs = F.softmax(logits) - - return logits, probs, audio_len - - def forward(self, audio, text, audio_len, text_len): - """ - audio: shape [B, D, T] - text: shape [B, T] - audio_len: shape [B] - text_len: shape [B] - """ - return self.infer(audio, audio_len) - - @paddle.no_grad() - def predict(self, audio, audio_len): - """ Model infer """ - return self.infer(audio, audio_len) - - def _decode_batch_greedy(self, probs_split, vocab_list): - """Decode by best path for a batch of probs matrix input. - :param probs_split: List of 2-D probability matrix, and each consists - of prob vectors for one speech utterancce. - :param probs_split: List of matrix - :param vocab_list: List of tokens in the vocabulary, for decoding. - :type vocab_list: list - :return: List of transcription texts. - :rtype: List of str - """ - results = [] - for i, probs in enumerate(probs_split): - output_transcription = ctc_greedy_decoder( - probs_seq=probs, vocabulary=vocab_list) - results.append(output_transcription) - return results - - def _init_ext_scorer(self, beam_alpha, beam_beta, language_model_path, - vocab_list): - """Initialize the external scorer. - :param beam_alpha: Parameter associated with language model. - :type beam_alpha: float - :param beam_beta: Parameter associated with word count. - :type beam_beta: float - :param language_model_path: Filepath for language model. If it is - empty, the external scorer will be set to - None, and the decoding method will be pure - beam search without scorer. - :type language_model_path: str|None - :param vocab_list: List of tokens in the vocabulary, for decoding. 
- :type vocab_list: list - """ - # init once - if self._ext_scorer != None: - return - - if language_model_path != '': - self.logger.info("begin to initialize the external scorer " - "for decoding") - self._ext_scorer = Scorer(beam_alpha, beam_beta, - language_model_path, vocab_list) - lm_char_based = self._ext_scorer.is_character_based() - lm_max_order = self._ext_scorer.get_max_order() - lm_dict_size = self._ext_scorer.get_dict_size() - self.logger.info("language model: " - "is_character_based = %d," % lm_char_based + - " max_order = %d," % lm_max_order + - " dict_size = %d" % lm_dict_size) - self.logger.info("end initializing scorer") - else: - self._ext_scorer = None - self.logger.info("no language model provided, " - "decoding by pure beam search without scorer.") - - def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta, - beam_size, cutoff_prob, cutoff_top_n, - vocab_list, num_processes): - """Decode by beam search for a batch of probs matrix input. - :param probs_split: List of 2-D probability matrix, and each consists - of prob vectors for one speech utterancce. - :param probs_split: List of matrix - :param beam_alpha: Parameter associated with language model. - :type beam_alpha: float - :param beam_beta: Parameter associated with word count. - :type beam_beta: float - :param beam_size: Width for Beam search. - :type beam_size: int - :param cutoff_prob: Cutoff probability in pruning, - default 1.0, no pruning. - :type cutoff_prob: float - :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n - characters with highest probs in vocabulary will be - used in beam search, default 40. - :type cutoff_top_n: int - :param vocab_list: List of tokens in the vocabulary, for decoding. - :type vocab_list: list - :param num_processes: Number of processes (CPU) for decoder. - :type num_processes: int - :return: List of transcription texts. 
- :rtype: List of str - """ - if self._ext_scorer != None: - self._ext_scorer.reset_params(beam_alpha, beam_beta) - - # beam search decode - num_processes = min(num_processes, len(probs_split)) - beam_search_results = ctc_beam_search_decoder_batch( - probs_split=probs_split, - vocabulary=vocab_list, - beam_size=beam_size, - num_processes=num_processes, - ext_scoring_func=self._ext_scorer, - cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n) - - results = [result[0][1] for result in beam_search_results] - return results - - def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, - decoding_method): - if decoding_method == "ctc_beam_search": - self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, - vocab_list) - - def decode_probs(self, probs, logits_lens, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, - cutoff_prob, cutoff_top_n, num_processes): - """ probs: activation after softmax - logits_len: audio output lens - """ - probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)] - if decoding_method == "ctc_greedy": - result_transcripts = self._decode_batch_greedy( - probs_split=probs_split, vocab_list=vocab_list) - elif decoding_method == "ctc_beam_search": - result_transcripts = self._decode_batch_beam_search( - probs_split=probs_split, - beam_alpha=beam_alpha, - beam_beta=beam_beta, - beam_size=beam_size, - cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n, - vocab_list=vocab_list, - num_processes=num_processes) - else: - raise ValueError(f"Not support: {decoding_method}") - return result_transcripts - - @paddle.no_grad() - def decode(self, audio, audio_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes): - _, probs, logits_lens = self.predict(audio, audio_len) - return self.decode_probs(probs.numpy(), logits_lens, vocab_list, - decoding_method, lang_model_path, beam_alpha, - beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes) - - def from_pretrained(self, checkpoint_path): - """Build a model from a pretrained model. - Parameters - ---------- - model: nn.Layer - Asr Model. - - checkpoint_path: Path or str - The path of pretrained model checkpoint, without extension name. - - Returns - ------- - Model - The model build from pretrined result. 
- """ - checkpoint.load_parameters(self, checkpoint_path=checkpoint_path) - return - - -def ctc_loss(logits, - labels, - input_lengths, - label_lengths, - blank=0, - reduction='mean', - norm_by_times=True): - #logger.info("my ctc loss with norm by times") - ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403 - loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times, - input_lengths, label_lengths) - - loss_out = paddle.fluid.layers.squeeze(loss_out, [-1]) - logger.info(f"warpctc loss: {loss_out}/{loss_out.shape} ") - assert reduction in ['mean', 'sum', 'none'] - if reduction == 'mean': - loss_out = paddle.mean(loss_out / label_lengths) - elif reduction == 'sum': - loss_out = paddle.sum(loss_out) - logger.info(f"ctc loss: {loss_out}") - return loss_out - - -F.ctc_loss = ctc_loss - - -class DeepSpeech2Loss(nn.Layer): - def __init__(self, vocab_size): - super().__init__() - # last token id as blank id - self.loss = nn.CTCLoss(blank=vocab_size, reduction='sum') - - def forward(self, logits, text, logits_len, text_len): - # warp-ctc do softmax on activations - # warp-ctc need activation with shape [T, B, V + 1] - logits = logits.transpose([1, 0, 2]) - - ctc_loss = self.loss(logits, text, logits_len, text_len) - return ctc_loss diff --git a/deepspeech/decoders/swig/_init_paths.py b/deepspeech/modules/activation.py similarity index 63% rename from deepspeech/decoders/swig/_init_paths.py rename to deepspeech/modules/activation.py index c4b28c643..72c2a5e2b 100644 --- a/deepspeech/decoders/swig/_init_paths.py +++ b/deepspeech/modules/activation.py @@ -11,19 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Set up paths for DS2""" -import os.path -import sys +import logging +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I -def add_path(path): - if path not in sys.path: - sys.path.insert(0, path) +logger = logging.getLogger(__name__) +__all__ = ['brelu'] -this_dir = os.path.dirname(__file__) -# Add project path to PYTHONPATH -proj_path = os.path.join(this_dir, '..') -add_path(proj_path) +def brelu(x, t_min=0.0, t_max=24.0, name=None): + t_min = paddle.to_tensor(t_min) + t_max = paddle.to_tensor(t_max) + return x.maximum(t_min).minimum(t_max) diff --git a/deepspeech/modules/conv.py b/deepspeech/modules/conv.py new file mode 100644 index 000000000..7d64c963d --- /dev/null +++ b/deepspeech/modules/conv.py @@ -0,0 +1,147 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import logging
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+from deepspeech.modules.mask import sequence_mask
+from deepspeech.modules.activation import brelu
+
+logger = logging.getLogger(__name__)
+
+__all__ = ['ConvStack']
+
+
+class ConvBn(nn.Layer):
+    """Convolution layer with batch normalization.
+
+    :param kernel_size: The x dimension of a filter kernel, or a tuple for
+                        the two image dimensions.
+    :type kernel_size: int|tuple|list
+    :param num_channels_in: Number of input channels.
+    :type num_channels_in: int
+    :param num_channels_out: Number of output channels.
+    :type num_channels_out: int
+    :param stride: The x dimension of the stride, or a tuple for the two
+                   image dimensions.
+    :type stride: int|tuple|list
+    :param padding: The x dimension of the padding, or a tuple for the two
+                    image dimensions.
+    :type padding: int|tuple|list
+    :param act: Activation type, relu|brelu.
+    :type act: string
+    :return: Batch norm layer after convolution layer.
+    :rtype: Variable
+
+    """
+
+    def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
+                 padding, act):
+
+        super().__init__()
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        assert len(padding) == 2
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+
+        self.conv = nn.Conv2D(
+            num_channels_in,
+            num_channels_out,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            weight_attr=None,
+            bias_attr=False,
+            data_format='NCHW')
+
+        self.bn = nn.BatchNorm2D(
+            num_channels_out,
+            weight_attr=None,
+            bias_attr=None,
+            data_format='NCHW')
+        self.act = F.relu if act == 'relu' else brelu
+
+    def forward(self, x, x_len):
+        """
+        x(Tensor): audio, shape [B, C, D, T]
+        """
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.act(x)
+
+        x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
+                 ) // self.stride[1] + 1
+
+        # reset padding part to 0
+        masks = sequence_mask(x_len)  #[B, T]
+        masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
+        x = x.multiply(masks)
+
+        return x, x_len
+
+
+class ConvStack(nn.Layer):
+    """Convolution group with stacked convolution layers.
+
+    :param feat_size: audio feature dim.
+    :type feat_size: int
+    :param num_stacks: Number of stacked convolution layers.
+    :type num_stacks: int
+    """
+
+    def __init__(self, feat_size, num_stacks):
+        super().__init__()
+        self.feat_size = feat_size  # D
+        self.num_stacks = num_stacks
+
+        self.conv_in = ConvBn(
+            num_channels_in=1,
+            num_channels_out=32,
+            kernel_size=(41, 11),  #[D, T]
+            stride=(2, 3),
+            padding=(20, 5),
+            act='brelu')
+
+        out_channel = 32
+        self.conv_stack = nn.LayerList([
+            ConvBn(
+                num_channels_in=32,
+                num_channels_out=out_channel,
+                kernel_size=(21, 11),
+                stride=(2, 1),
+                padding=(10, 5),
+                act='brelu') for i in range(num_stacks - 1)
+        ])
+
+        # conv output feat_dim
+        output_height = (feat_size - 1) // 2 + 1
+        for i in range(self.num_stacks - 1):
+            output_height = (output_height - 1) // 2 + 1
+        self.output_height = out_channel * output_height
+
+    def forward(self, x, x_len):
+        """
+        x: shape [B, C, D, T]
+        x_len : shape [B]
+        """
+        x, x_len = self.conv_in(x, x_len)
+        for i, conv in enumerate(self.conv_stack):
+            x, x_len = conv(x, x_len)
+        return x, x_len
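The output_height bookkeeping above can be checked by hand. With 161-dim linear spectrogram features (the usual DeepSpeech2 setting; an assumption here) and the default two conv layers:

    # conv_in halves D (stride 2) and each stacked layer halves it again;
    # time is subsampled only by conv_in (stride 3), the stack uses stride 1.
    feat_size, num_stacks, out_channel = 161, 2, 32
    output_height = (feat_size - 1) // 2 + 1          # 81 after conv_in
    for _ in range(num_stacks - 1):
        output_height = (output_height - 1) // 2 + 1  # 41 after the stack
    print(out_channel * output_height)                # 1312 features per frame fed to the RNN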
diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py
new file mode 100644
index 000000000..cb036c141
--- /dev/null
+++ b/deepspeech/modules/mask.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+logger = logging.getLogger(__name__)
+
+__all__ = ['sequence_mask']
+
+
+def sequence_mask(x_len, max_len=None, dtype='float32'):
+    max_len = max_len or x_len.max()
+    x_len = paddle.unsqueeze(x_len, -1)
+    row_vector = paddle.arange(max_len)
+    # 1 for valid positions (t < length); [B, 1] broadcasts against [T] -> [B, T]
+    mask = row_vector < x_len
+    mask = paddle.cast(mask, dtype)
+    return mask
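A quick demo of the mask semantics (valid positions are 1, padding is 0):

    import paddle
    from deepspeech.modules.mask import sequence_mask

    lens = paddle.to_tensor([2, 4])
    print(sequence_mask(lens, max_len=4))
    # [[1., 1., 0., 0.],
    #  [1., 1., 1., 1.]]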
+ """ + + def __init__(self, + hidden_size, + activation="tanh", + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + if activation not in ["tanh", "relu", "brelu"]: + raise ValueError( + "activation for SimpleRNNCell should be tanh or relu, " + "but get {}".format(activation)) + self.activation = activation + self._activation_fn = paddle.tanh \ + if activation == "tanh" \ + else F.relu + if activation == 'brelu': + self._activation_fn = brelu + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + pre_h = states + i2h = inputs + if self.bias_ih is not None: + i2h += self.bias_ih + h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self._activation_fn(i2h + h2h) + return h, h + + @property + def state_shape(self): + return (self.hidden_size, ) + + +class GRUCell(nn.RNNCellBase): + r""" + Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, + it computes the outputs and updates states. + The formula for GRU used is as follows: + .. math:: + r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr}) + z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz}) + \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) + h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} + y_{t} & = h_{t} + + where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise + multiplication operator. 
+ """ + + def __init__(self, + input_size, + hidden_size, + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (3 * hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (3 * hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + self.input_size = input_size + self._gate_activation = F.sigmoid + self._activation = paddle.tanh + #self._activation = F.relu + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + + pre_hidden = states + x_gates = inputs + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h_gates = h_gates + self.bias_hh + + x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1) + h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1) + + r = self._gate_activation(x_r + h_r) + z = self._gate_activation(x_z + h_z) + c = self._activation(x_c + r * h_c) # apply reset gate after mm + h = (pre_hidden - c) * z + c + # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru + + return h, h + + @property + def state_shape(self): + r""" + The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch + size would be automatically inserted into shape). The shape corresponds + to the shape of :math:`h_{t-1}`. + """ + return (self.hidden_size, ) + + +class BiRNNWithBN(nn.Layer): + """Bidirectonal simple rnn layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param name: Name of the layer parameters. + :type name: string + :param size: Dimension of RNN cells. + :type size: int + :param share_weights: Whether to share input-hidden weights between + forward and backward directional RNNs. + :type share_weights: bool + :return: Bidirectional simple rnn layer. + :rtype: Variable + """ + + def __init__(self, i_size, h_size, share_weights): + super().__init__() + self.share_weights = share_weights + if self.share_weights: + #input-hidden weights shared between bi-directional rnn. 
+
+
+class BiRNNWithBN(nn.Layer):
+    """Bidirectional simple rnn layer with sequence-wise batch normalization.
+    The batch normalization is only performed on input-state weights.
+
+    :param i_size: Dimension of the input features.
+    :type i_size: int
+    :param h_size: Dimension of RNN cells.
+    :type h_size: int
+    :param share_weights: Whether to share input-hidden weights between
+                          forward and backward directional RNNs.
+    :type share_weights: bool
+    :return: Bidirectional simple rnn layer.
+    :rtype: Variable
+    """
+
+    def __init__(self, i_size, h_size, share_weights):
+        super().__init__()
+        self.share_weights = share_weights
+        if self.share_weights:
+            #input-hidden weights shared between bi-directional rnn.
+            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
+            # batch norm is only performed on input-state projection
+            self.fw_bn = nn.BatchNorm1D(
+                h_size, bias_attr=None, data_format='NLC')
+            self.bw_fc = self.fw_fc
+            self.bw_bn = self.fw_bn
+        else:
+            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
+            self.fw_bn = nn.BatchNorm1D(
+                h_size, bias_attr=None, data_format='NLC')
+            self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
+            self.bw_bn = nn.BatchNorm1D(
+                h_size, bias_attr=None, data_format='NLC')
+
+        self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
+        self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
+        self.fw_rnn = nn.RNN(
+            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
+        self.bw_rnn = nn.RNN(
+            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]
+
+    def forward(self, x, x_len):
+        # x, shape [B, T, D]
+        fw_x = self.fw_bn(self.fw_fc(x))
+        bw_x = self.bw_bn(self.bw_fc(x))
+        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
+        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
+        x = paddle.concat([fw_x, bw_x], axis=-1)
+        return x, x_len
+
+
+class BiGRUWithBN(nn.Layer):
+    """Bidirectional gru layer with sequence-wise batch normalization.
+    The batch normalization is only performed on input-state weights.
+
+    :param i_size: Dimension of the input features.
+    :type i_size: int
+    :param h_size: Dimension of GRU cells.
+    :type h_size: int
+    :param act: Activation type.
+    :type act: string
+    :return: Bidirectional GRU layer.
+    :rtype: Variable
+    """
+
+    def __init__(self, i_size, h_size, act):
+        super().__init__()
+        hidden_size = h_size * 3
+
+        self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
+        self.fw_bn = nn.BatchNorm1D(
+            hidden_size, bias_attr=None, data_format='NLC')
+        self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
+        self.bw_bn = nn.BatchNorm1D(
+            hidden_size, bias_attr=None, data_format='NLC')
+
+        self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
+        self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
+        self.fw_rnn = nn.RNN(
+            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
+        self.bw_rnn = nn.RNN(
+            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]
+
+    def forward(self, x, x_len):
+        # x, shape [B, T, D]
+        fw_x = self.fw_bn(self.fw_fc(x))
+        bw_x = self.bw_bn(self.bw_fc(x))
+        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
+        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
+        x = paddle.concat([fw_x, bw_x], axis=-1)
+        return x, x_len
+
+
+class RNNStack(nn.Layer):
+    """RNN group with stacked bidirectional simple RNN or GRU layers.
+
+    :param i_size: Dimension of the input features.
+    :type i_size: int
+    :param h_size: Dimension of RNN cells in each layer.
+    :type h_size: int
+    :param num_stacks: Number of stacked rnn layers.
+    :type num_stacks: int
+    :param use_gru: Use gru if set True. Use simple rnn if set False.
+    :type use_gru: bool
+    :param share_rnn_weights: Whether to share input-hidden weights between
+                              forward and backward directional RNNs.
+                              It is only available when use_gru=False.
+    :type share_rnn_weights: bool
+    :return: Output layer of the RNN group.
+    :rtype: Variable
+    """
+
+    def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):
+        super().__init__()
+        self.rnn_stacks = nn.LayerList()
+        for i in range(num_stacks):
+            if use_gru:
+                #default:GRU using tanh
+                self.rnn_stacks.append(
+                    BiGRUWithBN(i_size=i_size, h_size=h_size, act="relu"))
+            else:
+                self.rnn_stacks.append(
+                    BiRNNWithBN(
+                        i_size=i_size,
+                        h_size=h_size,
+                        share_weights=share_rnn_weights))
+            i_size = h_size * 2
+
+    def forward(self, x, x_len):
+        """
+        x: shape [B, T, D]
+        x_len: shape [B]
+        """
+        for i, rnn in enumerate(self.rnn_stacks):
+            x, x_len = rnn(x, x_len)
+            masks = sequence_mask(x_len)  #[B, T]
+            masks = masks.unsqueeze(-1)  # [B, T, 1]
+            x = x.multiply(masks)
+        return x, x_len
diff --git a/deepspeech/training/gradclip.py b/deepspeech/training/gradclip.py
new file mode 100644
index 000000000..7292d7a21
--- /dev/null
+++ b/deepspeech/training/gradclip.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import paddle
+from paddle.fluid.dygraph import base as imperative_base
+from paddle.fluid import layers
+from paddle.fluid import core
+
+logger = logging.getLogger(__name__)
+
+
+class MyClipGradByGlobalNorm(paddle.nn.ClipGradByGlobalNorm):
+    def __init__(self, clip_norm):
+        super().__init__(clip_norm)
+
+    @imperative_base.no_grad
+    def _dygraph_clip(self, params_grads):
+        params_and_grads = []
+        sum_square_list = []
+        for p, g in params_grads:
+            if g is None:
+                continue
+            if getattr(p, 'need_clip', True) is False:
+                continue
+            merge_grad = g
+            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
+                merge_grad = layers.merge_selected_rows(g)
+                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
+            square = layers.square(merge_grad)
+            sum_square = layers.reduce_sum(square)
+            logger.info(
+                f"Grad Before Clip: {p.name}: {float(layers.sqrt(layers.reduce_sum(layers.square(merge_grad))))}"
+            )
+            sum_square_list.append(sum_square)
+
+        # all parameters have been filtered out
+        if len(sum_square_list) == 0:
+            return params_grads
+
+        global_norm_var = layers.concat(sum_square_list)
+        global_norm_var = layers.reduce_sum(global_norm_var)
+        global_norm_var = layers.sqrt(global_norm_var)
+        logger.info(f"Grad Global Norm: {float(global_norm_var)}")
+        max_global_norm = layers.fill_constant(
+            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
+        clip_var = layers.elementwise_div(
+            x=max_global_norm,
+            y=layers.elementwise_max(x=global_norm_var, y=max_global_norm))
+        for p, g in params_grads:
+            if g is None:
+                continue
+            if getattr(p, 'need_clip', True) is False:
+                params_and_grads.append((p, g))
+                continue
+            new_grad = layers.elementwise_mul(x=g, y=clip_var)
+            logger.info(
+                f"Grad After Clip: {p.name}: {float(layers.sqrt(layers.reduce_sum(layers.square(new_grad))))}"
+            )
+            params_and_grads.append((p, new_grad))
+
+        return params_and_grads
diff --git a/deepspeech/training/loss.py b/deepspeech/training/loss.py
new file mode 100644
index
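A minimal sketch of wiring the clipper into training; the model, learning rate, and clip_norm value are placeholders, not values taken from this change:

    import paddle
    from deepspeech.training.gradclip import MyClipGradByGlobalNorm

    model = paddle.nn.Linear(16, 4)  # stand-in for the real network
    optimizer = paddle.optimizer.Adam(
        learning_rate=5e-4,
        parameters=model.parameters(),
        grad_clip=MyClipGradByGlobalNorm(clip_norm=400.0))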
000000000..b0e021a59
--- /dev/null
+++ b/deepspeech/training/loss.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+logger = logging.getLogger(__name__)
+
+__all__ = ['CTCLoss']
+
+
+def ctc_loss(logits,
+             labels,
+             input_lengths,
+             label_lengths,
+             blank=0,
+             reduction='mean',
+             norm_by_times=True):
+    #logger.info("my ctc loss with norm by times")
+    ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
+    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
+                                           input_lengths, label_lengths)
+
+    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
+    logger.info(f"warpctc loss: {loss_out}/{loss_out.shape}")
+    assert reduction in ['mean', 'sum', 'none']
+    if reduction == 'mean':
+        loss_out = paddle.mean(loss_out / label_lengths)
+    elif reduction == 'sum':
+        loss_out = paddle.sum(loss_out)
+    logger.info(f"ctc loss: {loss_out}")
+    return loss_out
+
+
+# monkey-patch paddle's F.ctc_loss with the warpctc-based version above
+F.ctc_loss = ctc_loss
+
+
+class CTCLoss(nn.Layer):
+    def __init__(self, blank_id):
+        super().__init__()
+        # by convention, the blank id is the last token id (vocab_size)
+        self.loss = nn.CTCLoss(blank=blank_id, reduction='sum')
+
+    def forward(self, logits, text, logits_len, text_len):
+        # warp-ctc does softmax on the activations itself and
+        # needs activations with shape [T, B, V + 1]
+        logits = logits.transpose([1, 0, 2])
+
+        ctc_loss = self.loss(logits, text, logits_len, text_len)
+        return ctc_loss
diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py
index cd7166593..28be4db03 100644
--- a/deepspeech/utils/utility.py
+++ b/deepspeech/utils/utility.py
@@ -15,6 +15,10 @@
 
 import distutils.util
 
+import numpy as np
+
+__all__ = ['print_arguments', 'add_arguments', 'print_grads', 'print_params']
+
 
 def print_arguments(args):
     """Print argparse's arguments.
@@ -55,3 +59,21 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
         type=type,
         help=help + ' Default: %(default)s.',
         **kwargs)
+
+
+def print_grads(model, logger=None):
+    for n, p in model.named_parameters():
+        msg = f"param grad: {n}: shape: {p.shape} grad: {p.grad}"
+        if logger:
+            logger.info(msg)
+
+
+def print_params(model, logger=None):
+    total = 0.0
+    for n, p in model.named_parameters():
+        msg = f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}"
+        total += np.prod(p.shape)
+        if logger:
+            logger.info(msg)
+    if logger:
+        logger.info(f"Total parameters: {total}")
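A hypothetical quick check of the new helpers:

    import logging
    import paddle
    from deepspeech.utils.utility import print_params

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    model = paddle.nn.Linear(161, 1024)  # stand-in for DeepSpeech2
    print_params(model, logger)  # logs each parameter's shape plus the total count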