|
|
@ -294,7 +294,7 @@ class U2Trainer(Trainer):
|
|
|
|
json_file=config.test_manifest,
|
|
|
|
json_file=config.test_manifest,
|
|
|
|
train_mode=False,
|
|
|
|
train_mode=False,
|
|
|
|
sortagrad=False,
|
|
|
|
sortagrad=False,
|
|
|
|
batch_size=config.decoding.decode_batch_size,
|
|
|
|
batch_size=config.decode.decode_batch_size,
|
|
|
|
maxlen_in=float('inf'),
|
|
|
|
maxlen_in=float('inf'),
|
|
|
|
maxlen_out=float('inf'),
|
|
|
|
maxlen_out=float('inf'),
|
|
|
|
minibatches=0,
|
|
|
|
minibatches=0,
|
|
|
@ -313,7 +313,7 @@ class U2Trainer(Trainer):
|
|
|
|
json_file=config.test_manifest,
|
|
|
|
json_file=config.test_manifest,
|
|
|
|
train_mode=False,
|
|
|
|
train_mode=False,
|
|
|
|
sortagrad=False,
|
|
|
|
sortagrad=False,
|
|
|
|
batch_size=config.decoding.decode_batch_size,
|
|
|
|
batch_size=config.decode.decode_batch_size,
|
|
|
|
maxlen_in=float('inf'),
|
|
|
|
maxlen_in=float('inf'),
|
|
|
|
maxlen_out=float('inf'),
|
|
|
|
maxlen_out=float('inf'),
|
|
|
|
minibatches=0,
|
|
|
|
minibatches=0,
|
|
|
@ -452,7 +452,7 @@ class U2Tester(U2Trainer):
|
|
|
|
texts,
|
|
|
|
texts,
|
|
|
|
texts_len,
|
|
|
|
texts_len,
|
|
|
|
fout=None):
|
|
|
|
fout=None):
|
|
|
|
decode_config = self.config.decoding
|
|
|
|
decode_config = self.config.decode
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
errors_func = error_rate.char_errors if decode_config.error_rate_type == 'cer' else error_rate.word_errors
|
|
|
|
errors_func = error_rate.char_errors if decode_config.error_rate_type == 'cer' else error_rate.word_errors
|
|
|
|
error_rate_func = error_rate.cer if decode_config.error_rate_type == 'cer' else error_rate.wer
|
|
|
|
error_rate_func = error_rate.cer if decode_config.error_rate_type == 'cer' else error_rate.wer
|
|
|
@ -564,7 +564,7 @@ class U2Tester(U2Trainer):
|
|
|
|
@paddle.no_grad()
|
|
|
|
@paddle.no_grad()
|
|
|
|
def align(self):
|
|
|
|
def align(self):
|
|
|
|
ctc_utils.ctc_align(self.config, self.model, self.align_loader,
|
|
|
|
ctc_utils.ctc_align(self.config, self.model, self.align_loader,
|
|
|
|
self.config.decoding.decode_batch_size,
|
|
|
|
self.config.decode.decode_batch_size,
|
|
|
|
self.config.stride_ms, self.vocab_list,
|
|
|
|
self.config.stride_ms, self.vocab_list,
|
|
|
|
self.args.result_file)
|
|
|
|
self.args.result_file)
|
|
|
|
|
|
|
|
|
|
|
|