Use <unk> when training

compute_loss needs text ids;
ord ids are used in test mode, which computes WER/CER
pull/522/head
Hui Zhang 5 years ago
parent 5fe1b40630
commit 4292e50622

@@ -428,7 +428,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
 class SpeechCollator():
-    def __init__(self, padding_to=-1, is_training=False):
+    def __init__(self, padding_to=-1, is_training=True):
         """
         Padding audio features with zeros to make them have the same shape (or
         a user-defined shape) within one batch.
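
The default flip matters because the collated text field changes type with the mode. A minimal sketch of the intended behavior, inferred from the commit message rather than shown in this diff (training batches carry vocab token ids for compute_loss; test batches carry character ordinals so WER/CER can be computed on real text):

collator = SpeechCollator(is_training=True)
# is_training=True  -> text holds vocab token ids (what compute_loss consumes)
# is_training=False -> text holds [ord(ch) for ch in transcript], decoded back for WER/CER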

@@ -30,6 +30,7 @@ class TextFeaturizer(object):
     """
     def __init__(self, vocab_filepath):
+        self.unk = '<unk>'
         self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
             vocab_filepath)
@@ -43,7 +44,11 @@ class TextFeaturizer(object):
         :rtype: list
         """
         tokens = self._char_tokenize(text)
-        return [self._vocab_dict[token] for token in tokens]
+        ids = []
+        for token in tokens:
+            token = token if token in self._vocab_dict else self.unk
+            ids.append(self._vocab_dict[token])
+        return ids
     @property
     def vocab_size(self):
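
With the <unk> fallback, featurizing no longer raises KeyError on characters missing from the vocab. A usage sketch, assuming the patched method is the usual featurize(text) entry point (the file name and input text are illustrative):

featurizer = TextFeaturizer('vocab.txt')  # hypothetical vocab path; file starts with <unk>
ids = featurizer.featurize('héllo')       # if 'é' is out-of-vocab, it maps to the '<unk>' id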

@@ -317,7 +317,7 @@ class DeepSpeech2Trainer(Trainer):
             use_dB_normalization=config.data.use_dB_normalization,
             target_dB=config.data.target_dB,
             random_seed=config.data.random_seed,
-            keep_transcription_text=True)
+            keep_transcription_text=False)
         if self.parallel:
             batch_sampler = DeepSpeech2DistributedBatchSampler(
@@ -342,14 +342,14 @@ class DeepSpeech2Trainer(Trainer):
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
-            collate_fn=SpeechCollator(is_training=True),
+            collate_fn=collate_fn,
             num_workers=config.data.num_workers, )
         self.valid_loader = DataLoader(
             dev_dataset,
             batch_size=config.data.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=SpeechCollator(is_training=True))
+            collate_fn=collate_fn)
         self.logger.info("Setup train/valid Dataloader!")
@@ -415,7 +415,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.logger.info(
             f"Test Total Examples: {len(self.test_loader.dataset)}")
         self.model.eval()
-        losses = defaultdict(list)
+        cfg = self.config
         # decoders only accept string encoded in utf-8
         vocab_list = self.test_loader.dataset.vocab_list
@@ -432,10 +432,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         for i, batch in enumerate(self.test_loader):
             audio, text, audio_len, text_len = batch
             outputs = self.model.predict(audio, audio_len)
-            loss = self.compute_losses(batch, outputs)
-            losses['test_loss'].append(float(loss))
             metrics = self.compute_metrics(batch, outputs)
             errors_sum += metrics['errors_sum']
             len_refs += metrics['len_refs']
             num_ins += metrics['num_ins']
@@ -443,14 +441,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.logger.info("Error rate [%s] (%d/?) = %f" %
                          (error_rate_type, num_ins, errors_sum / len_refs))
-        # write visual log
-        losses = {k: np.mean(v) for k, v in losses.items()}
         # logging
         msg = "Test: "
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items())
         msg += ", Final error rate [%s] (%d/%d) = %f" % (
             error_rate_type, num_ins, num_ins, errors_sum / len_refs)
         self.logger.info(msg)
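
With the loss bookkeeping gone, test() only aggregates error-rate statistics. Per the commit message, references travel through the batch as ord ids in test mode; a minimal sketch of that round trip (names here are illustrative, not from this diff):

ref_ords = [ord(ch) for ch in 'hello world']   # collator side: text -> ord ids (tensor-friendly ints)
ref_text = ''.join(chr(o) for o in ref_ords)   # metrics side: ord ids -> text for WER/CER
assert ref_text == 'hello world'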

@@ -59,6 +59,7 @@ def main():
     count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
     with codecs.open(args.vocab_path, 'w', 'utf-8') as fout:
+        fout.write('<unk>' + '\n')
         for char, count in count_sorted:
             if count < args.count_threshold: break
             fout.write(char + '\n')
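
After this change the generated vocab file always starts with <unk>, so the token the featurizer falls back to is guaranteed an id (here, 0). A small sketch of the resulting lookup, assuming the plain one-token-per-line format this script writes (sample entries are illustrative):

vocab_list = ['<unk>', 'e', 't', 'a']               # first line is the one written above
vocab_dict = {tok: i for i, tok in enumerate(vocab_list)}
tok = 'q' if 'q' in vocab_dict else '<unk>'         # OOV character falls back to <unk>
assert vocab_dict[tok] == 0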
