pull/578/head
Hui Zhang 4 years ago
parent b15b6c6a26
commit 2aed275233

@@ -357,9 +357,9 @@ if not hasattr(paddle.Tensor, 'tolist'):
 ########### hcak paddle.nn.functional #############
-def glu(x: paddle.Tensor, dim=-1) -> paddle.Tensor:
+def glu(x: paddle.Tensor, axis=-1) -> paddle.Tensor:
     """The gated linear unit (GLU) activation."""
-    a, b = x.split(2, axis=dim)
+    a, b = x.split(2, axis=axis)
     act_b = F.sigmoid(b)
     return a * act_b
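The hunk above only renames the `dim` keyword to `axis`, so the helper matches Paddle's own argument naming. A minimal usage sketch, assuming the function is attached to paddle.nn.functional the same way the other helpers in this patch file are (that registration line is not shown in the diff):

import paddle
import paddle.nn.functional as F

def glu(x: paddle.Tensor, axis=-1) -> paddle.Tensor:
    """The gated linear unit (GLU) activation."""
    # Split the input into two equal halves along `axis` and gate one half with the other.
    a, b = x.split(2, axis=axis)
    act_b = F.sigmoid(b)
    return a * act_b

# Assumed registration, mirroring the surrounding monkey-patch block.
if not hasattr(F, 'glu'):
    F.glu = glu

x = paddle.randn([4, 8])   # the split axis must have an even size
y = F.glu(x, axis=-1)      # -> shape [4, 4]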

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
"""Evaluation for U2 model.""" """Evaluation for U2 model."""
import cProfile import cProfile
import os
from deepspeech.exps.u2.config import get_cfg_defaults from deepspeech.exps.u2.config import get_cfg_defaults
from deepspeech.exps.u2.model import U2Tester as Tester from deepspeech.exps.u2.model import U2Tester as Tester
@ -53,4 +52,4 @@ if __name__ == "__main__":
# Setting for profiling # Setting for profiling
pr = cProfile.Profile() pr = cProfile.Profile()
pr.runcall(main, config, args) pr.runcall(main, config, args)
pr.dump_stats(os.path.join(args.output, 'train.profile')) pr.dump_stats('test.profile')
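With this change the test script drops the now-unused `os` import and writes the cProfile dump to a fixed `test.profile` in the working directory instead of under `args.output`. A short sketch, using only the standard library, of how that dump can be inspected afterwards:

import pstats

# Load the file written by pr.dump_stats('test.profile').
stats = pstats.Stats('test.profile')
# Sort by cumulative time and show the 20 most expensive call sites.
stats.sort_stats('cumulative').print_stats(20)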

@@ -91,7 +91,7 @@ training:
 decoding:
   batch_size: 128
   error_rate_type: cer
   decoding_method: attention   # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
   lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
   alpha: 2.5
   beta: 0.3
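The decoding block selects one of the four listed decoding strategies plus the external KenLM language model and its alpha/beta weights. A hedged sketch of how such a YAML section is typically consumed through get_cfg_defaults (the merge_from_file/freeze calls follow the usual yacs pattern, the attribute paths and config file name are assumptions; only the import itself appears in the diff above):

from deepspeech.exps.u2.config import get_cfg_defaults

config = get_cfg_defaults()
config.merge_from_file('conf/transformer.yaml')   # hypothetical path to the YAML above
config.freeze()

# The four strategies listed in the comment are the valid choices.
assert config.decoding.decoding_method in (
    'attention', 'ctc_greedy_search',
    'ctc_prefix_beam_search', 'attention_rescoring')
print(config.decoding.batch_size, config.decoding.error_rate_type)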
