rename decoding to decode

pull/1225/head
huangyuxin 3 years ago
parent 960658f669
commit 2c5902d7c5

@ -31,6 +31,7 @@ U2Trainer.params(_C)
_C.decoding = U2Tester.params()
def get_cfg_defaults():
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered

@ -294,7 +294,7 @@ class U2Trainer(Trainer):
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.decoding.decode_batch_size,
batch_size=config.decode.decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
@ -313,7 +313,7 @@ class U2Trainer(Trainer):
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.decoding.decode_batch_size,
batch_size=config.decode.decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
@ -452,7 +452,7 @@ class U2Tester(U2Trainer):
texts,
texts_len,
fout=None):
decode_config = self.config.decoding
decode_config = self.config.decode
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if decode_config.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if decode_config.error_rate_type == 'cer' else error_rate.wer
@ -564,7 +564,7 @@ class U2Tester(U2Trainer):
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.config.decoding.decode_batch_size,
self.config.decode.decode_batch_size,
self.config.stride_ms, self.vocab_list,
self.args.result_file)

Loading…
Cancel
Save