From 4292e5062255acb1daa9f59e38b05269e4a540d2 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Mon, 22 Feb 2021 12:36:59 +0000
Subject: [PATCH] using unk when training

compute_loss needs text ids.
ord ids are used in test mode, which computes wer/cer.
---
 data_utils/dataset.py                     |  2 +-
 data_utils/featurizer/text_featurizer.py  |  7 ++++++-
 model_utils/model.py                      | 16 +++++-----------
 tools/build_vocab.py                      |  1 +
 4 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/data_utils/dataset.py b/data_utils/dataset.py
index 6b9b9aecc..6be0c0455 100644
--- a/data_utils/dataset.py
+++ b/data_utils/dataset.py
@@ -428,7 +428,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
 
 
 class SpeechCollator():
-    def __init__(self, padding_to=-1, is_training=False):
+    def __init__(self, padding_to=-1, is_training=True):
         """
         Padding audio features with zeros to make them have the same shape (or
         a user-defined shape) within one bach.
diff --git a/data_utils/featurizer/text_featurizer.py b/data_utils/featurizer/text_featurizer.py
index 70aa10ead..a1e8cdbb1 100644
--- a/data_utils/featurizer/text_featurizer.py
+++ b/data_utils/featurizer/text_featurizer.py
@@ -30,6 +30,7 @@ class TextFeaturizer(object):
     """
 
     def __init__(self, vocab_filepath):
+        self.unk = '<unk>'
         self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
             vocab_filepath)
 
@@ -43,7 +44,11 @@ class TextFeaturizer(object):
         :rtype: list
         """
         tokens = self._char_tokenize(text)
-        return [self._vocab_dict[token] for token in tokens]
+        ids = []
+        for token in tokens:
+            token = token if token in self._vocab_dict else self.unk
+            ids.append(self._vocab_dict[token])
+        return ids
 
     @property
     def vocab_size(self):
diff --git a/model_utils/model.py b/model_utils/model.py
index 9edae9da6..b60e87883 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -317,7 +317,7 @@ class DeepSpeech2Trainer(Trainer):
             use_dB_normalization=config.data.use_dB_normalization,
             target_dB=config.data.target_dB,
             random_seed=config.data.random_seed,
-            keep_transcription_text=True)
+            keep_transcription_text=False)
 
         if self.parallel:
             batch_sampler = DeepSpeech2DistributedBatchSampler(
@@ -342,14 +342,14 @@ class DeepSpeech2Trainer(Trainer):
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
-            collate_fn=SpeechCollator(is_training=True),
+            collate_fn=collate_fn,
             num_workers=config.data.num_workers, )
         self.valid_loader = DataLoader(
             dev_dataset,
             batch_size=config.data.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=SpeechCollator(is_training=True))
+            collate_fn=collate_fn)
         self.logger.info("Setup train/valid Dataloader!")
 
 
@@ -415,7 +415,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.logger.info(
             f"Test Total Examples: {len(self.test_loader.dataset)}")
         self.model.eval()
-        losses = defaultdict(list)
+
         cfg = self.config
         # decoders only accept string encoded in utf-8
         vocab_list = self.test_loader.dataset.vocab_list
@@ -432,10 +432,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         for i, batch in enumerate(self.test_loader):
             audio, text, audio_len, text_len = batch
             outputs = self.model.predict(audio, audio_len)
-            loss = self.compute_losses(batch, outputs)
-            losses['test_loss'].append(float(loss))
-
             metrics = self.compute_metrics(batch, outputs)
+
             errors_sum += metrics['errors_sum']
             len_refs += metrics['len_refs']
             num_ins += metrics['num_ins']
@@ -443,14 +441,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             self.logger.info("Error rate [%s] (%d/?) = %f" %
                              (error_rate_type, num_ins, errors_sum / len_refs))
 
-        # write visual log
-        losses = {k: np.mean(v) for k, v in losses.items()}
-
         # logging
         msg = "Test: "
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items())
         msg += ", Final error rate [%s] (%d/%d) = %f" % (
             error_rate_type, num_ins, num_ins, errors_sum / len_refs)
         self.logger.info(msg)
diff --git a/tools/build_vocab.py b/tools/build_vocab.py
index 77fd1fb63..2e47e84e5 100644
--- a/tools/build_vocab.py
+++ b/tools/build_vocab.py
@@ -59,6 +59,7 @@ def main():
 
     count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
     with codecs.open(args.vocab_path, 'w', 'utf-8') as fout:
+        fout.write('<unk>' + '\n')
        for char, count in count_sorted:
            if count < args.count_threshold: break
            fout.write(char + '\n')