Merge pull request #802 from PaddlePaddle/fix_ds2_bw_bug

fix the bug of sharing cell in BiGRU and BIRNN
pull/808/head
Jackwaterveg 3 years ago committed by GitHub
commit 38174c7055
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,7 +18,7 @@
All tested under:
* Ubuntu 16.04
* python>=3.7
* paddlepaddle>=2.1.2
* paddlepaddle>=2.2.0rc
Please see [install](doc/src/install.md).

@ -20,7 +20,7 @@
* Ubuntu 16.04
* python>=3.7
* paddlepaddle>=2.1.2
* paddlepaddle>=2.2.0rc
参看 [安装](doc/src/install.md)。

@ -202,7 +202,7 @@ class BiRNNWithBN(nn.Layer):
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
# x, shape [B, T, D]
@ -246,7 +246,7 @@ class BiGRUWithBN(nn.Layer):
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x, x_len):
# x, shape [B, T, D]

@ -22,6 +22,13 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
try:
from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401
from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401
from deepspeech.decoders.swig_wrapper import Scorer # noqa: F401
except Exception as e:
logger.info("ctcdecoder not installed!")
__all__ = ['CTCDecoder']
@ -216,9 +223,6 @@ class CTCDecoder(nn.Layer):
def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
decoding_method):
from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401
from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401
from deepspeech.decoders.swig_wrapper import Scorer # noqa: F401
if decoding_method == "ctc_beam_search":
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,

@ -10,8 +10,11 @@
| Model | Params | Release | Config | Test set | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 |
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 6.016139030456543 | 0.066549 |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 7181e427 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 |

@ -5,8 +5,8 @@ if [ $# != 3 ]; then
exit -1
fi
ckpt_dir=${1}
avg_mode=${2} # best,latest
avg_mode=${1} # best,latest
ckpt_dir=${2}
average_num=${3}
decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams

Loading…
Cancel
Save