|
|
@ -39,6 +39,8 @@ from decoders.swig_wrapper import Scorer
|
|
|
|
from decoders.swig_wrapper import ctc_greedy_decoder
|
|
|
|
from decoders.swig_wrapper import ctc_greedy_decoder
|
|
|
|
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
|
|
|
|
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from utils.error_rate import char_errors, word_errors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeepSpeech2Trainer(Trainer):
|
|
|
|
class DeepSpeech2Trainer(Trainer):
|
|
|
|
def __init__(self, config, args):
|
|
|
|
def __init__(self, config, args):
|
|
|
@ -263,24 +265,83 @@ class DeepSpeech2Tester(Trainer):
|
|
|
|
loss = self.criterion(logits, texts, logits_len, texts_len)
|
|
|
|
loss = self.criterion(logits, texts, logits_len, texts_len)
|
|
|
|
return loss
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def id2token(self, texts, texts_len, vocab_list):
|
|
|
|
|
|
|
|
trans = []
|
|
|
|
|
|
|
|
for text, n in zip(texts, texts_len):
|
|
|
|
|
|
|
|
n = n.numpy().item()
|
|
|
|
|
|
|
|
ids = text[:n]
|
|
|
|
|
|
|
|
trans.append(''.join([vocab_list[i] for i in ids]))
|
|
|
|
|
|
|
|
return np.array(trans)
|
|
|
|
|
|
|
|
|
|
|
|
def compute_metrics(self, inputs, outputs):
|
|
|
|
def compute_metrics(self, inputs, outputs):
|
|
|
|
|
|
|
|
cfg = self.config.decoding
|
|
|
|
|
|
|
|
|
|
|
|
_, texts, _, texts_len = inputs
|
|
|
|
_, texts, _, texts_len = inputs
|
|
|
|
logits, _, logits_len = outputs
|
|
|
|
logits, probs, logits_len = outputs
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
|
|
|
|
errors_func = char_errors if cfg.error_rate_type == 'cer' else word_errors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vocab_list = self.test_loader.dataset.vocab_list
|
|
|
|
|
|
|
|
target_transcripts = self.id2token(texts, texts_len, vocab_list)
|
|
|
|
|
|
|
|
result_transcripts = self.model.decode_probs(
|
|
|
|
|
|
|
|
probs.numpy(),
|
|
|
|
|
|
|
|
vocab_list,
|
|
|
|
|
|
|
|
decoding_method=cfg.decoding_method,
|
|
|
|
|
|
|
|
lang_model_path=cfg.lang_model_path,
|
|
|
|
|
|
|
|
beam_alpha=cfg.alpha,
|
|
|
|
|
|
|
|
beam_beta=cfg.beta,
|
|
|
|
|
|
|
|
beam_size=cfg.beam_size,
|
|
|
|
|
|
|
|
cutoff_prob=cfg.cutoff_prob,
|
|
|
|
|
|
|
|
cutoff_top_n=cfg.cutoff_top_n,
|
|
|
|
|
|
|
|
num_processes=cfg.num_proc_bsearch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for target, result in zip(target_transcripts, result_transcripts):
|
|
|
|
|
|
|
|
errors, len_ref = errors_func(target, result)
|
|
|
|
|
|
|
|
errors_sum += errors
|
|
|
|
|
|
|
|
len_refs += len_ref
|
|
|
|
|
|
|
|
num_ins += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return dict(
|
|
|
|
|
|
|
|
errors_sum=errors_sum,
|
|
|
|
|
|
|
|
len_refs=len_refs,
|
|
|
|
|
|
|
|
num_ins=num_ins,
|
|
|
|
|
|
|
|
error_rate=errors_sum / len_refs,
|
|
|
|
|
|
|
|
error_rate_type=cfg.error_rate_type)
|
|
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
@paddle.no_grad()
|
|
|
|
@paddle.no_grad()
|
|
|
|
def test(self):
|
|
|
|
def test(self):
|
|
|
|
self.model.eval()
|
|
|
|
self.model.eval()
|
|
|
|
losses = defaultdict(list)
|
|
|
|
losses = defaultdict(list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cfg = self.config
|
|
|
|
|
|
|
|
# decoders only accept string encoded in utf-8
|
|
|
|
|
|
|
|
vocab_list = self.test_loader.dataset.vocab_list
|
|
|
|
|
|
|
|
self.model.init_decode(
|
|
|
|
|
|
|
|
beam_alpha=cfg.decoding.alpha,
|
|
|
|
|
|
|
|
beam_beta=cfg.decoding.beta,
|
|
|
|
|
|
|
|
lang_model_path=cfg.decoding.lang_model_path,
|
|
|
|
|
|
|
|
vocab_list=vocab_list,
|
|
|
|
|
|
|
|
decoding_method=cfg.decoding.decoding_method)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
error_rate_type = None
|
|
|
|
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
|
|
|
|
|
|
|
|
for i, batch in enumerate(self.test_loader):
|
|
|
|
for i, batch in enumerate(self.test_loader):
|
|
|
|
audio, text, audio_len, text_len = batch
|
|
|
|
audio, text, audio_len, text_len = batch
|
|
|
|
outputs = self.model.predict(audio, audio_len)
|
|
|
|
outputs = self.model.predict(audio, audio_len)
|
|
|
|
loss = self.compute_losses(batch, outputs)
|
|
|
|
loss = self.compute_losses(batch, outputs)
|
|
|
|
metrics = self.compute_metrics(batch, outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
losses['test_loss'].append(float(loss))
|
|
|
|
losses['test_loss'].append(float(loss))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metrics = self.compute_metrics(batch, outputs)
|
|
|
|
|
|
|
|
errors_sum += metrics['errors_sum']
|
|
|
|
|
|
|
|
len_refs += metrics['len_refs']
|
|
|
|
|
|
|
|
num_ins += metrics['num_ins']
|
|
|
|
|
|
|
|
error_rate_type = metrics['error_rate_type']
|
|
|
|
|
|
|
|
self.logger.info("Error rate [%s] (%d/?) = %f" %
|
|
|
|
|
|
|
|
(error_rate_type, num_ins, errors_sum / len_refs))
|
|
|
|
|
|
|
|
|
|
|
|
# write visual log
|
|
|
|
# write visual log
|
|
|
|
losses = {k: np.mean(v) for k, v in losses.items()}
|
|
|
|
losses = {k: np.mean(v) for k, v in losses.items()}
|
|
|
|
|
|
|
|
|
|
|
@ -289,6 +350,8 @@ class DeepSpeech2Tester(Trainer):
|
|
|
|
msg += "epoch: {}, ".format(self.epoch)
|
|
|
|
msg += "epoch: {}, ".format(self.epoch)
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items())
|
|
|
|
msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items())
|
|
|
|
|
|
|
|
msg += ", Final error rate [%s] (%d/%d) = %f" % (
|
|
|
|
|
|
|
|
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
|
|
|
|
self.logger.info(msg)
|
|
|
|
self.logger.info(msg)
|
|
|
|
|
|
|
|
|
|
|
|
def setup(self):
|
|
|
|
def setup(self):
|
|
|
@ -359,7 +422,7 @@ class DeepSpeech2Tester(Trainer):
|
|
|
|
collate_fn = SpeechCollator()
|
|
|
|
collate_fn = SpeechCollator()
|
|
|
|
self.test_loader = DataLoader(
|
|
|
|
self.test_loader = DataLoader(
|
|
|
|
test_dataset,
|
|
|
|
test_dataset,
|
|
|
|
batch_size=config.data.batch_size,
|
|
|
|
batch_size=config.decoding.batch_size,
|
|
|
|
shuffle=False,
|
|
|
|
shuffle=False,
|
|
|
|
drop_last=False,
|
|
|
|
drop_last=False,
|
|
|
|
collate_fn=collate_fn)
|
|
|
|
collate_fn=collate_fn)
|
|
|
|