Merge pull request #802 from PaddlePaddle/fix_ds2_bw_bug

fix the bug of sharing the cell in BiGRU and BiRNN
Jackwaterveg committed 38174c7055 via GitHub (branch: pull/808/head)

@@ -18,7 +18,7 @@
 All tested under:
 * Ubuntu 16.04
 * python>=3.7
-* paddlepaddle>=2.1.2
+* paddlepaddle>=2.2.0rc

 Please see [install](doc/src/install.md).

@@ -20,7 +20,7 @@
 * Ubuntu 16.04
 * python>=3.7
-* paddlepaddle>=2.1.2
+* paddlepaddle>=2.2.0rc

 See [install](doc/src/install.md).

@@ -29,13 +29,13 @@ __all__ = ['RNNStack']
 class RNNCell(nn.RNNCellBase):
     r"""
     Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
     computes the outputs and updates states.
     The formula used is as follows:
     .. math::
         h_{t} & = act(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
         y_{t} & = h_{t}
     where :math:`act` is for :attr:`activation`.
     """
@@ -92,7 +92,7 @@ class RNNCell(nn.RNNCellBase):
 class GRUCell(nn.RNNCellBase):
     r"""
     Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
     it computes the outputs and updates states.
     The formula for GRU used is as follows:
     .. math::
@@ -101,8 +101,8 @@ class GRUCell(nn.RNNCellBase):
         \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
         h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
         y_{t} & = h_{t}
     where :math:`\sigma` is the sigmoid function, and * is the elementwise
     multiplication operator.
     """
@@ -202,7 +202,7 @@ class BiRNNWithBN(nn.Layer):
         self.fw_rnn = nn.RNN(
             self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
         self.bw_rnn = nn.RNN(
-            self.fw_cell, is_reverse=True, time_major=False)  #[B, T, D]
+            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]

     def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
         # x, shape [B, T, D]
@@ -246,7 +246,7 @@ class BiGRUWithBN(nn.Layer):
         self.fw_rnn = nn.RNN(
             self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
         self.bw_rnn = nn.RNN(
-            self.fw_cell, is_reverse=True, time_major=False)  #[B, T, D]
+            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]

     def forward(self, x, x_len):
         # x, shape [B, T, D]
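These two hunks are the substance of the PR: in both BiRNNWithBN and BiGRUWithBN, bw_rnn previously wrapped self.fw_cell, so the backward direction silently reused the forward cell's weights and the stack was not truly bidirectional. A minimal sketch of the failure mode, using paddle's built-in GRUCell as a stand-in for the repo's custom cells (sizes and variable names are illustrative):

    import paddle.nn as nn

    # Two independent cells, as the fixed code intends.
    fw_cell = nn.GRUCell(input_size=16, hidden_size=16)
    bw_cell = nn.GRUCell(input_size=16, hidden_size=16)

    fw_rnn = nn.RNN(fw_cell, is_reverse=False, time_major=False)
    buggy_bw = nn.RNN(fw_cell, is_reverse=True, time_major=False)  # old code: same cell
    fixed_bw = nn.RNN(bw_cell, is_reverse=True, time_major=False)  # the fix: its own cell

    fw_ids = {id(p) for p in fw_rnn.parameters()}
    print(bool(fw_ids & {id(p) for p in buggy_bw.parameters()}))  # True: weights shared
    print(bool(fw_ids & {id(p) for p in fixed_bw.parameters()}))  # False: independent

nn.RNN is only a driver that runs a cell over the time axis, so two wrappers around the same cell train a single set of recurrent weights.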

@@ -22,6 +22,13 @@ from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()

+try:
+    from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch  # noqa: F401
+    from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder  # noqa: F401
+    from deepspeech.decoders.swig_wrapper import Scorer  # noqa: F401
+except Exception as e:
+    logger.info("ctcdecoder not installed!")
+
 __all__ = ['CTCDecoder']
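Hoisting the swig_wrapper imports to module scope behind try/except makes the compiled decoders an optional dependency: the module now imports cleanly when they are absent and just logs a notice. A sketch of the same guard pattern with an explicit availability flag; HAS_CTCDECODER and require_ctcdecoder below are illustrative additions, not part of this PR:

    import logging

    logger = logging.getLogger(__name__)

    try:
        from deepspeech.decoders.swig_wrapper import Scorer  # noqa: F401
        HAS_CTCDECODER = True
    except ImportError:
        HAS_CTCDECODER = False
        logger.info("ctcdecoder not installed!")

    def require_ctcdecoder():
        # Fail with a clear message at decode time instead of a NameError later.
        if not HAS_CTCDECODER:
            raise RuntimeError(
                "ctcdecoder not installed; build the swig decoders to "
                "enable ctc_beam_search decoding.")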
@@ -216,9 +223,6 @@ class CTCDecoder(nn.Layer):
     def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
                     decoding_method):
-        from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch  # noqa: F401
-        from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder  # noqa: F401
-        from deepspeech.decoders.swig_wrapper import Scorer  # noqa: F401
         if decoding_method == "ctc_beam_search":
             self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,

@@ -10,8 +10,11 @@
 | Model | Params | Release | Config | Test set | Loss | CER |
 | --- | --- | --- | --- | --- | --- | --- |
-| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 |
+| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 6.016139030456543 | 0.066549 |
+| --- | --- | --- | --- | --- | --- | --- |
+| DeepSpeech2 | 58.4M | 7181e427 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 |
 | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
 | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
 | DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
+| --- | --- | --- | --- | --- | --- | --- |
 | DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 |

@@ -5,8 +5,8 @@ if [ $# != 3 ]; then
     exit -1
 fi

-ckpt_dir=${1}
-avg_mode=${2} # best,latest
+avg_mode=${1} # best,latest
+ckpt_dir=${2}
 average_num=${3}
 decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
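With this change the script takes the averaging mode (best or latest, per the inline comment) first and the checkpoint directory second, e.g. `avg.sh best <ckpt_dir> <average_num>`, still writing the result to ${ckpt_dir}/avg_${average_num}.pdparams; callers written against the old argument order need their first two arguments swapped.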
