diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index b58260749..ffe678a69 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -24,17 +24,14 @@ from paddle import nn from paddle.nn import functional as F from paddle.nn import initializer as I -from deepspeech.modules.conv import ConvStack -from deepspeech.modules.rnn import RNNStack from deepspeech.modules.mask import sequence_mask from deepspeech.modules.activation import brelu +from deepspeech.modules.conv import ConvStack +from deepspeech.modules.rnn import RNNStack +from deepspeech.modules.ctc import CTCDecoder + from deepspeech.utils import checkpoint from deepspeech.utils import layer_tools -from deepspeech.decoders.swig_wrapper import Scorer -from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder -from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch - -from deepspeech.modules.loss import CTCLoss logger = logging.getLogger(__name__) @@ -105,178 +102,6 @@ class CRNNEncoder(nn.Layer): return x, x_lens -class CTCDecoder(nn.Layer): - def __init__(self, enc_n_units, vocab_size): - super().__init__() - self.blank_id = vocab_size - self.output = nn.Linear(enc_n_units, - vocab_size + 1) # blank id is last id - self.criterion = CTCLoss(self.blank_id) - - self._ext_scorer = None - - def forward(self, eout, eout_lens, texts, texts_len): - """Compute CTC Loss - - Args: - eout (Tensor): - eout_lens (Tensor): - texts (Tenosr): - texts_len (Tensor): - Returns: - loss (Tenosr): [1] - """ - logits = self.output(eout) - loss = self.criterion(logits, texts, eout_lens, texts_len) - return loss - - def probs(self, eouts, temperature=1.): - """Get CTC probabilities. - Args: - eouts (FloatTensor): `[B, T, enc_units]` - Returns: - probs (FloatTensor): `[B, T, vocab]` - """ - return F.softmax(self.output(eouts) / temperature, axis=-1) - - def scores(self, eouts, temperature=1.): - """Get log-scale CTC probabilities. 
- Args: - eouts (FloatTensor): `[B, T, enc_units]` - Returns: - log_probs (FloatTensor): `[B, T, vocab]` - """ - return F.log_softmax(self.output(eouts) / temperature, axis=-1) - - def _decode_batch_greedy(self, probs_split, vocab_list): - """Decode by best path for a batch of probs matrix input. - :param probs_split: List of 2-D probability matrix, and each consists - of prob vectors for one speech utterancce. - :param probs_split: List of matrix - :param vocab_list: List of tokens in the vocabulary, for decoding. - :type vocab_list: list - :return: List of transcription texts. - :rtype: List of str - """ - results = [] - for i, probs in enumerate(probs_split): - output_transcription = ctc_greedy_decoder( - probs_seq=probs, vocabulary=vocab_list) - results.append(output_transcription) - return results - - def _init_ext_scorer(self, beam_alpha, beam_beta, language_model_path, - vocab_list): - """Initialize the external scorer. - :param beam_alpha: Parameter associated with language model. - :type beam_alpha: float - :param beam_beta: Parameter associated with word count. - :type beam_beta: float - :param language_model_path: Filepath for language model. If it is - empty, the external scorer will be set to - None, and the decoding method will be pure - beam search without scorer. - :type language_model_path: str|None - :param vocab_list: List of tokens in the vocabulary, for decoding. 
- :type vocab_list: list - """ - # init once - if self._ext_scorer != None: - return - - if language_model_path != '': - logger.info("begin to initialize the external scorer " - "for decoding") - self._ext_scorer = Scorer(beam_alpha, beam_beta, - language_model_path, vocab_list) - lm_char_based = self._ext_scorer.is_character_based() - lm_max_order = self._ext_scorer.get_max_order() - lm_dict_size = self._ext_scorer.get_dict_size() - logger.info("language model: " - "is_character_based = %d," % lm_char_based + - " max_order = %d," % lm_max_order + " dict_size = %d" % - lm_dict_size) - logger.info("end initializing scorer") - else: - self._ext_scorer = None - logger.info("no language model provided, " - "decoding by pure beam search without scorer.") - - def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta, - beam_size, cutoff_prob, cutoff_top_n, - vocab_list, num_processes): - """Decode by beam search for a batch of probs matrix input. - :param probs_split: List of 2-D probability matrix, and each consists - of prob vectors for one speech utterancce. - :param probs_split: List of matrix - :param beam_alpha: Parameter associated with language model. - :type beam_alpha: float - :param beam_beta: Parameter associated with word count. - :type beam_beta: float - :param beam_size: Width for Beam search. - :type beam_size: int - :param cutoff_prob: Cutoff probability in pruning, - default 1.0, no pruning. - :type cutoff_prob: float - :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n - characters with highest probs in vocabulary will be - used in beam search, default 40. - :type cutoff_top_n: int - :param vocab_list: List of tokens in the vocabulary, for decoding. - :type vocab_list: list - :param num_processes: Number of processes (CPU) for decoder. - :type num_processes: int - :return: List of transcription texts. 
- :rtype: List of str - """ - if self._ext_scorer != None: - self._ext_scorer.reset_params(beam_alpha, beam_beta) - - # beam search decode - num_processes = min(num_processes, len(probs_split)) - beam_search_results = ctc_beam_search_decoder_batch( - probs_split=probs_split, - vocabulary=vocab_list, - beam_size=beam_size, - num_processes=num_processes, - ext_scoring_func=self._ext_scorer, - cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n) - - results = [result[0][1] for result in beam_search_results] - return results - - def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, - decoding_method): - if decoding_method == "ctc_beam_search": - self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, - vocab_list) - - def decode_probs(self, probs, logits_lens, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, - cutoff_prob, cutoff_top_n, num_processes): - """ probs: activation after softmax - logits_len: audio output lens - """ - probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)] - if decoding_method == "ctc_greedy": - result_transcripts = self._decode_batch_greedy( - probs_split=probs_split, vocab_list=vocab_list) - elif decoding_method == "ctc_beam_search": - result_transcripts = self._decode_batch_beam_search( - probs_split=probs_split, - beam_alpha=beam_alpha, - beam_beta=beam_beta, - beam_size=beam_size, - cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n, - vocab_list=vocab_list, - num_processes=num_processes) - else: - raise ValueError(f"Not support: {decoding_method}") - return result_transcripts - - class DeepSpeech2Model(nn.Layer): """The DeepSpeech2 network structure. 
@@ -339,8 +164,13 @@ class DeepSpeech2Model(nn.Layer):
             use_gru=use_gru,
             share_rnn_weights=share_rnn_weights)
         assert (self.encoder.output_size == rnn_size * 2)
+
         self.decoder = CTCDecoder(
-            enc_n_units=self.encoder.output_size, vocab_size=dict_size)
+            enc_n_units=self.encoder.output_size,
+            odim=dict_size + 1,  # <blank> is appended after the vocab
+            blank_id=dict_size,  # so the last token id is <blank>
+            dropout_rate=0.0,
+            reduction=True)
 
     def forward(self, audio, text, audio_len, text_len):
         """Compute Model loss
diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py
index 14861fcf7..a42bd1e74 100644
--- a/deepspeech/modules/activation.py
+++ b/deepspeech/modules/activation.py
@@ -14,6 +14,7 @@
 import logging
 
 import numpy as np
+import math
 
 import paddle
 from paddle import nn
@@ -22,7 +23,7 @@ from paddle.nn import initializer as I
 
 logger = logging.getLogger(__name__)
 
-__all__ = ['brelu']
+__all__ = ['brelu', 'softplus', 'gelu_accurate', 'gelu', 'Swish']
 
 
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
@@ -30,3 +31,38 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
     t_min = paddle.full(shape=[1], fill_value=t_min, dtype='float32')
     t_max = paddle.full(shape=[1], fill_value=t_max, dtype='float32')
     return x.maximum(t_min).minimum(t_max)
+
+
+def softplus(x):
+    """Softplus function; delegates to paddle.nn.functional.softplus."""
+    if hasattr(paddle.nn.functional, 'softplus'):
+        # NOTE: the input dtype is used as-is (no float32 round-trip).
+        return paddle.nn.functional.softplus(x)
+    else:
+        raise NotImplementedError
+
+
+def gelu_accurate(x):
+    """Gaussian Error Linear Units (GELU) activation, tanh approximation."""
+    # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
+    if not hasattr(gelu_accurate, "_a"):
+        gelu_accurate._a = math.sqrt(2 / math.pi)
+    return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
+                                      (x + 0.044715 * paddle.pow(x, 3))))
+
+
+def gelu(x):
+    """Gaussian Error Linear Units (GELU) activation."""
+    if hasattr(paddle.nn.functional, 'gelu'):
+        # Prefer the framework kernel; the erf form below is the fallback.
+        return paddle.nn.functional.gelu(x)
+    else:
+        return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
+
+
+class Swish(nn.Layer):
+    """Construct a Swish object."""
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        """Return Swish activation function, x * sigmoid(x)."""
+        return x * F.sigmoid(x)
diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py
new file mode 100644
index 000000000..66737f599
--- /dev/null
+++ b/deepspeech/modules/ctc.py
@@ -0,0 +1,238 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typeguard import check_argument_types
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+from deepspeech.decoders.swig_wrapper import Scorer
+from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
+from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
+from deepspeech.modules.loss import CTCLoss
+
+logger = logging.getLogger(__name__)
+
+__all__ = ['CTCDecoder']
+
+
+class CTCDecoder(nn.Layer):
+    def __init__(self,
+                 enc_n_units,
+                 odim,
+                 blank_id=0,
+                 dropout_rate: float=0.0,
+                 reduction: bool=True):
+        """CTC decoder
+
+        Args:
+            enc_n_units ([int]): encoder output dimension
+            odim ([int]): output size of the CTC projection layer
+            blank_id ([int]): label id used as the CTC blank. Defaults to 0.
+            dropout_rate (float): dropout rate (0.0 ~ 1.0)
+            reduction (bool): sum the CTC loss into a scalar ("sum" vs "none")
+        """
+        assert check_argument_types()
+        super().__init__()
+
+        self.blank_id = blank_id
+        self.odim = odim
+        self.dropout_rate = dropout_rate
+        self.ctc_lo = nn.Linear(enc_n_units, self.odim)
+        reduction_type = "sum" if reduction else "none"
+        self.criterion = CTCLoss(blank=self.blank_id, reduction=reduction_type)
+
+        # CTCDecoder LM Score handle
+        self._ext_scorer = None
+
+    def forward(self, hs_pad, hlens, ys_pad, ys_lens):
+        """Calculate CTC loss.
+
+        Args:
+            hs_pad (Tensor): batch of padded hidden state sequences (B, Tmax, D)
+            hlens (Tensor): batch of lengths of hidden state sequences (B)
+            ys_pad (Tensor): batch of padded character id sequence tensor (B, Lmax)
+            ys_lens (Tensor): batch of lengths of character sequence (B)
+        Returns:
+            loss (Tensor): scalar.
+        """
+        logits = self.ctc_lo(
+            F.dropout(hs_pad, p=self.dropout_rate, training=self.training))
+        return self.criterion(logits, ys_pad, hlens, ys_lens)
+
+    def probs(self, eouts: paddle.Tensor, temperature: float=1.0):
+        """Get CTC probabilities.
+        Args:
+            eouts (FloatTensor): `[B, T, enc_units]`
+        Returns:
+            probs (FloatTensor): `[B, T, odim]`
+        """
+        return F.softmax(self.ctc_lo(eouts) / temperature, axis=-1)
+
+    def scores(self, eouts: paddle.Tensor, temperature: float=1.0):
+        """Get log-scale CTC probabilities.
+        Args:
+            eouts (FloatTensor): `[B, T, enc_units]`
+        Returns:
+            log_probs (FloatTensor): `[B, T, odim]`
+        """
+        return F.log_softmax(self.ctc_lo(eouts) / temperature, axis=-1)
+
+    def log_softmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor:
+        """log_softmax of frame activations
+        Args:
+            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
+        Returns:
+            paddle.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
+        """
+        return self.scores(hs_pad)
+
+    def argmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor:
+        """argmax of frame activations
+        Args:
+            paddle.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
+        Returns:
+            paddle.Tensor: argmax applied 2d tensor (B, Tmax)
+        """
+        return paddle.argmax(self.ctc_lo(hs_pad), axis=2)
+
+    def _decode_batch_greedy(self, probs_split, vocab_list):
+        """Decode by best path for a batch of probs matrix input.
+        :param probs_split: List of 2-D probability matrix, and each consists
+                            of prob vectors for one speech utterance.
+        :type probs_split: List of matrix
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        :return: List of transcription texts.
+        :rtype: List of str
+        """
+        return [
+            ctc_greedy_decoder(probs_seq=probs, vocabulary=vocab_list)
+            for probs in probs_split
+        ]
+
+    def _init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
+                         vocab_list):
+        """Initialize the external scorer.
+        :param beam_alpha: Parameter associated with language model.
+        :type beam_alpha: float
+        :param beam_beta: Parameter associated with word count.
+        :type beam_beta: float
+        :param language_model_path: Filepath for language model. If it is
+                                    empty, the external scorer will be set to
+                                    None, and the decoding method will be pure
+                                    beam search without scorer.
+        :type language_model_path: str|None
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        """
+        # init once
+        if self._ext_scorer is not None:
+            return
+
+        if language_model_path:  # both '' and None mean "no language model"
+            logger.info("begin to initialize the external scorer "
+                        "for decoding")
+            self._ext_scorer = Scorer(beam_alpha, beam_beta,
+                                      language_model_path, vocab_list)
+            lm_char_based = self._ext_scorer.is_character_based()
+            lm_max_order = self._ext_scorer.get_max_order()
+            lm_dict_size = self._ext_scorer.get_dict_size()
+            logger.info("language model: "
+                        "is_character_based = %d," % lm_char_based +
+                        " max_order = %d," % lm_max_order + " dict_size = %d" %
+                        lm_dict_size)
+            logger.info("end initializing scorer")
+        else:
+            self._ext_scorer = None
+            logger.info("no language model provided, "
+                        "decoding by pure beam search without scorer.")
+
+    def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
+                                  beam_size, cutoff_prob, cutoff_top_n,
+                                  vocab_list, num_processes):
+        """Decode by beam search for a batch of probs matrix input.
+        :param probs_split: List of 2-D probability matrix, and each consists
+                            of prob vectors for one speech utterance.
+        :type probs_split: List of matrix
+        :param beam_alpha: Parameter associated with language model.
+        :type beam_alpha: float
+        :param beam_beta: Parameter associated with word count.
+        :type beam_beta: float
+        :param beam_size: Width for Beam search.
+        :type beam_size: int
+        :param cutoff_prob: Cutoff probability in pruning,
+                            default 1.0, no pruning.
+        :type cutoff_prob: float
+        :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
+                        characters with highest probs in vocabulary will be
+                        used in beam search, default 40.
+        :type cutoff_top_n: int
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        :param num_processes: Number of processes (CPU) for decoder.
+        :type num_processes: int
+        :return: List of transcription texts.
+        :rtype: List of str
+        """
+        if self._ext_scorer is not None:
+            self._ext_scorer.reset_params(beam_alpha, beam_beta)
+
+        # beam search decode
+        num_processes = min(num_processes, len(probs_split))
+        beam_search_results = ctc_beam_search_decoder_batch(
+            probs_split=probs_split,
+            vocabulary=vocab_list,
+            beam_size=beam_size,
+            num_processes=num_processes,
+            ext_scoring_func=self._ext_scorer,
+            cutoff_prob=cutoff_prob,
+            cutoff_top_n=cutoff_top_n)
+
+        results = [result[0][1] for result in beam_search_results]
+        return results
+
+    def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
+                    decoding_method):
+        """Create the external LM scorer; only done for "ctc_beam_search"."""
+        if decoding_method == "ctc_beam_search":
+            self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
+                                  vocab_list)
+
+    def decode_probs(self, probs, logits_lens, vocab_list, decoding_method,
+                     lang_model_path, beam_alpha, beam_beta, beam_size,
+                     cutoff_prob, cutoff_top_n, num_processes):
+        """ probs: activation after softmax
+        logits_lens: audio output lens, used to strip the padded frames
+        """
+        probs_split = [probs[i, :n, :] for i, n in enumerate(logits_lens)]
+        if decoding_method == "ctc_greedy":
+            result_transcripts = self._decode_batch_greedy(
+                probs_split=probs_split, vocab_list=vocab_list)
+        elif decoding_method == "ctc_beam_search":
+            result_transcripts = self._decode_batch_beam_search(
+                probs_split=probs_split,
+                beam_alpha=beam_alpha,
+                beam_beta=beam_beta,
+                beam_size=beam_size,
+                cutoff_prob=cutoff_prob,
+                cutoff_top_n=cutoff_top_n,
+                vocab_list=vocab_list,
+                num_processes=num_processes)
+        else:
+            raise ValueError(f"Not support: {decoding_method}")
+        return result_transcripts
diff --git a/deepspeech/modules/embedding.py b/deepspeech/modules/embedding.py
new file mode 100644
index 000000000..0f3ddef6c
--- /dev/null
+++ b/deepspeech/modules/embedding.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Positional Encoding Module."""
+
+import math
+import logging
+import numpy as np
+from typing import Tuple
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["PositionalEncoding", "RelPositionalEncoding"]
+
+# TODO(Hui Zhang): remove this hack
+paddle.float32 = 'float32'
+
+
+class PositionalEncoding(nn.Layer):
+    def __init__(self,
+                 d_model: int,
+                 dropout_rate: float,
+                 max_len: int=5000,
+                 reverse: bool=False):
+        """Positional encoding.
+            PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
+            PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
+        Args:
+            d_model (int): embedding dim.
+            dropout_rate (float): dropout rate.
+            max_len (int, optional): maximum input length. Defaults to 5000.
+            reverse (bool, optional): Not used. Defaults to False.
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.max_len = max_len
+        self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
+        self.dropout = nn.Dropout(p=dropout_rate)
+        self.pe = paddle.zeros([self.max_len, self.d_model])  #[T,D]
+
+        position = paddle.arange(
+            0, self.max_len, dtype=paddle.float32).unsqueeze(1)
+        div_term = paddle.exp(
+            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
+            -(math.log(10000.0) / self.d_model))
+
+        self.pe[:, 0::2] = paddle.sin(position * div_term)
+        self.pe[:, 1::2] = paddle.cos(position * div_term)
+        self.pe = self.pe.unsqueeze(0)  #[1, T, D]
+
+    def forward(self, x: paddle.Tensor,
+                offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Add positional encoding.
+        Args:
+            x (paddle.Tensor): Input. Its shape is (batch, time, ...)
+            offset (int): position offset
+        Returns:
+            paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+            paddle.Tensor: for compatibility to RelPositionalEncoding
+        """
+        T = paddle.shape(x)[1]
+        assert offset + T < self.max_len
+        # `self.pe` is the [1, max_len, D] table built in __init__.
+        # Slice out the T positions starting at `offset`; the result
+        # broadcasts over the batch axis when added to the scaled input.
+        pos_emb = self.pe[:, offset:offset + T]
+        x = x * self.xscale + pos_emb
+        return self.dropout(x), self.dropout(pos_emb)
+
+    def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
+        """ For getting encoding in a streaming fashion
+        Attention!!!!!
+        we apply dropout only once at the whole utterance level in a none
+        streaming way, but will call this function several times with
+        increasing input size in a streaming scenario, so the dropout will
+        be applied several times.
+        Args:
+            offset (int): start offset
+            size (int): required size of position encoding
+        Returns:
+            paddle.Tensor: Corresponding encoding
+        """
+        assert offset + size < self.max_len
+        return self.dropout(self.pe[:, offset:offset + size])
+
+
+class RelPositionalEncoding(PositionalEncoding):
+    """Relative positional encoding module.
+    See : Appendix B in https://arxiv.org/abs/1901.02860
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
+        """
+        Args:
+            d_model (int): Embedding dimension.
+            dropout_rate (float): Dropout rate.
+            max_len (int, optional): Maximum input length. Defaults to 5000.
+        """
+        super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+    def forward(self, x: paddle.Tensor,
+                offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute positional encoding.
+        Args:
+            x (paddle.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            paddle.Tensor: Encoded tensor (batch, time, `*`).
+            paddle.Tensor: Positional embedding tensor (1, time, `*`).
+        """
+        T = paddle.shape(x)[1]
+        assert offset + T < self.max_len
+        # Unlike PositionalEncoding.forward, the table is NOT added to `x`:
+        # the input is only scaled and the slice is returned separately.
+        x = x * self.xscale
+        # positions offset .. offset + T - 1 of the [1, max_len, D] table
+        pos_emb = self.pe[:, offset:offset + T]
+        return self.dropout(x), self.dropout(pos_emb)
diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py
index b0e021a59..7163baf2f 100644
--- a/deepspeech/modules/loss.py
+++ b/deepspeech/modules/loss.py
@@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
 __all__ = ['CTCLoss']
 
 
+# TODO(Hui Zhang): remove this hack, when `norm_by_times=True` is added
 def ctc_loss(logits,
              labels,
              input_lengths,
@@ -47,19 +48,35 @@ def ctc_loss(logits,
     return loss_out
 
 
+# TODO(Hui Zhang): remove this hack
 F.ctc_loss = ctc_loss
 
 
 class CTCLoss(nn.Layer):
-    def __init__(self, blank_id):
+    def __init__(self, blank=0, reduction='sum'):
         super().__init__()
         # last token id as blank id
-        self.loss = nn.CTCLoss(blank=blank_id, reduction='sum')
+        self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
 
-    def forward(self, logits, text, logits_len, text_len):
+    def forward(self, logits, ys_pad, hlens, ys_lens):
+        """Compute CTC loss.
+
+        Args:
+            logits ([paddle.Tensor]): pre-softmax activations, (B, L, D)
+            ys_pad ([paddle.Tensor]): padded label id sequences
+            hlens ([paddle.Tensor]): lengths of the logits sequences
+            ys_lens ([paddle.Tensor]): lengths of the label sequences
+
+        Returns:
+            [paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = batch size.
+        """
         # warp-ctc do softmax on activations
         # warp-ctc need activation with shape [T, B, V + 1]
+        # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
+        loss = self.loss(logits, ys_pad, hlens, ys_lens)
 
-        ctc_loss = self.loss(logits, text, logits_len, text_len)
-        return ctc_loss
+        # wenet averages the loss over the batch; deepspeech2 does not,
+        # so the batch-size average is intentionally left disabled:
+        # loss = loss / paddle.shape(logits)[1]
+        return loss
diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py
index cb036c141..0f136403f 100644
--- a/deepspeech/modules/mask.py
+++ b/deepspeech/modules/mask.py
@@ -28,6 +28,7 @@ def sequence_mask(x_len, max_len=None, dtype='float32'):
     max_len = max_len or x_len.max()
     x_len = paddle.unsqueeze(x_len, -1)
     row_vector = paddle.arange(max_len)
+    # TODO(Hui Zhang): fix this bug (`<` reportedly broke when broadcasting)
     #mask = row_vector < x_len
     mask = row_vector > x_len  # a bug, broadcast 的时候出错了
     mask = paddle.cast(mask, dtype)
diff --git a/requirements.txt b/requirements.txt
index 8ab09f626..14d7c0325 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ SoundFile==0.9.0.post1
 python_speech_features
 tensorboardX
 yacs
+typeguard