test with decoding

pull/522/head
Hui Zhang 5 years ago
parent f6eafe85f1
commit 314886d4c5

@ -39,6 +39,8 @@ from decoders.swig_wrapper import Scorer
from decoders.swig_wrapper import ctc_greedy_decoder from decoders.swig_wrapper import ctc_greedy_decoder
from decoders.swig_wrapper import ctc_beam_search_decoder_batch from decoders.swig_wrapper import ctc_beam_search_decoder_batch
from utils.error_rate import char_errors, word_errors
class DeepSpeech2Trainer(Trainer): class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args): def __init__(self, config, args):
@ -263,24 +265,83 @@ class DeepSpeech2Tester(Trainer):
loss = self.criterion(logits, texts, logits_len, texts_len) loss = self.criterion(logits, texts, logits_len, texts_len)
return loss return loss
def id2token(self, texts, texts_len, vocab_list):
    """Convert a batch of padded token-id sequences into transcript strings.

    Args:
        texts: batch of token-id sequences (padded to a common length).
        texts_len: per-utterance valid lengths; each entry is a tensor
            read via ``.numpy().item()``.
        vocab_list: lookup table mapping token id (index) to its symbol.

    Returns:
        np.ndarray of decoded transcript strings, one per utterance.
    """
    transcripts = []
    for ids, length in zip(texts, texts_len):
        valid_ids = ids[:length.numpy().item()]
        transcripts.append(''.join(vocab_list[tok] for tok in valid_ids))
    return np.array(transcripts)
def compute_metrics(self, inputs, outputs):
    """Compute error-rate statistics (CER or WER) for one batch.

    Args:
        inputs: (audio, texts, audio_len, texts_len) batch as fed to the model.
        outputs: (logits, probs, logits_len) triple from self.model.predict;
            only the softmax probabilities are used here.

    Returns:
        dict with 'errors_sum', 'len_refs', 'num_ins' (accumulable across
        batches), the batch-level 'error_rate', and 'error_rate_type'
        ('cer' or anything else meaning word-level).
    """
    cfg = self.config.decoding
    _, texts, _, texts_len = inputs
    # Only probs is needed for decoding; logits/logits_len are unused.
    _, probs, _ = outputs

    errors_sum, len_refs, num_ins = 0.0, 0, 0
    # 'cer' counts character edits; any other setting falls back to word edits.
    errors_func = char_errors if cfg.error_rate_type == 'cer' else word_errors
    vocab_list = self.test_loader.dataset.vocab_list

    target_transcripts = self.id2token(texts, texts_len, vocab_list)
    result_transcripts = self.model.decode_probs(
        probs.numpy(),
        vocab_list,
        decoding_method=cfg.decoding_method,
        lang_model_path=cfg.lang_model_path,
        beam_alpha=cfg.alpha,
        beam_beta=cfg.beta,
        beam_size=cfg.beam_size,
        cutoff_prob=cfg.cutoff_prob,
        cutoff_top_n=cfg.cutoff_top_n,
        num_processes=cfg.num_proc_bsearch)

    for target, result in zip(target_transcripts, result_transcripts):
        errors, len_ref = errors_func(target, result)
        errors_sum += errors
        len_refs += len_ref
        num_ins += 1

    # Guard against empty references — previously this divided by zero.
    error_rate = errors_sum / len_refs if len_refs > 0 else float('inf')
    return dict(
        errors_sum=errors_sum,
        len_refs=len_refs,
        num_ins=num_ins,
        error_rate=error_rate,
        error_rate_type=cfg.error_rate_type)
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
@paddle.no_grad() @paddle.no_grad()
def test(self): def test(self):
self.model.eval() self.model.eval()
losses = defaultdict(list) losses = defaultdict(list)
cfg = self.config
# decoders only accept string encoded in utf-8
vocab_list = self.test_loader.dataset.vocab_list
self.model.init_decode(
beam_alpha=cfg.decoding.alpha,
beam_beta=cfg.decoding.beta,
lang_model_path=cfg.decoding.lang_model_path,
vocab_list=vocab_list,
decoding_method=cfg.decoding.decoding_method)
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
audio, text, audio_len, text_len = batch audio, text, audio_len, text_len = batch
outputs = self.model.predict(audio, audio_len) outputs = self.model.predict(audio, audio_len)
loss = self.compute_losses(batch, outputs) loss = self.compute_losses(batch, outputs)
metrics = self.compute_metrics(batch, outputs)
losses['test_loss'].append(float(loss)) losses['test_loss'].append(float(loss))
metrics = self.compute_metrics(batch, outputs)
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
self.logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs))
# write visual log # write visual log
losses = {k: np.mean(v) for k, v in losses.items()} losses = {k: np.mean(v) for k, v in losses.items()}
@ -289,6 +350,8 @@ class DeepSpeech2Tester(Trainer):
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items()) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items())
msg += ", Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
self.logger.info(msg) self.logger.info(msg)
def setup(self): def setup(self):
@ -359,7 +422,7 @@ class DeepSpeech2Tester(Trainer):
collate_fn = SpeechCollator() collate_fn = SpeechCollator()
self.test_loader = DataLoader( self.test_loader = DataLoader(
test_dataset, test_dataset,
batch_size=config.data.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=collate_fn) collate_fn=collate_fn)

@ -646,11 +646,10 @@ class DeepSpeech2(nn.Layer):
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
vocab_list) vocab_list)
@paddle.no_grad() def decode_probs(self, probs, vocab_list, decoding_method, lang_model_path,
def decode(self, audio, audio_len, vocab_list, decoding_method, beam_alpha, beam_beta, beam_size, cutoff_prob,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes):
cutoff_top_n, num_processes): """ probs: activation after softmax """
_, probs, _ = self.predict(audio, audio_len)
if decoding_method == "ctc_greedy": if decoding_method == "ctc_greedy":
result_transcripts = self._decode_batch_greedy( result_transcripts = self._decode_batch_greedy(
probs_split=probs, vocab_list=vocab_list) probs_split=probs, vocab_list=vocab_list)
@ -668,6 +667,15 @@ class DeepSpeech2(nn.Layer):
raise ValueError(f"Not support: {decoding_method}") raise ValueError(f"Not support: {decoding_method}")
return result_transcripts return result_transcripts
@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
           lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
           cutoff_top_n, num_processes):
    """Run the acoustic model on audio and decode to transcripts.

    Thin convenience wrapper: obtains per-frame softmax probabilities
    from self.predict, then delegates all decoding work (greedy or
    beam search, with optional language model) to decode_probs.
    """
    predictions = self.predict(audio, audio_len)
    probs = predictions[1]
    return self.decode_probs(
        probs,
        vocab_list,
        decoding_method,
        lang_model_path,
        beam_alpha,
        beam_beta,
        beam_size,
        cutoff_prob,
        cutoff_top_n,
        num_processes)
class DeepSpeech2Loss(nn.Layer): class DeepSpeech2Loss(nn.Layer):
def __init__(self, vocab_size): def __init__(self, vocab_size):

Loading…
Cancel
Save