use unk when training

compute_loss needs text ids;
ord() ids are used in test mode, which computes WER/CER
pull/522/head
Hui Zhang 5 years ago
parent 5fe1b40630
commit 4292e50622

@@ -428,7 +428,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
 class SpeechCollator():
-    def __init__(self, padding_to=-1, is_training=False):
+    def __init__(self, padding_to=-1, is_training=True):
         """
         Padding audio features with zeros to make them have the same shape (or
         a user-defined shape) within one batch.
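
For context, the docstring describes zero-padding features to a common shape within one batch. A minimal sketch of that idea, assuming each feature is a (freq, time) float array and padding_to=-1 means "pad to the batch maximum"; shapes and names are illustrative, not the real SpeechCollator body:

import numpy as np

def pad_batch(features, padding_to=-1):
    # pad every (freq, time) feature with zeros up to the longest
    # time length in the batch (or a user-defined length)
    max_len = (max(f.shape[1] for f in features)
               if padding_to == -1 else padding_to)
    padded = np.zeros((len(features), features[0].shape[0], max_len),
                      dtype='float32')
    for i, f in enumerate(features):
        padded[i, :, :f.shape[1]] = f
    return padded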

@@ -30,6 +30,7 @@ class TextFeaturizer(object):
     """
     def __init__(self, vocab_filepath):
+        self.unk = '<unk>'
         self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
             vocab_filepath)
@@ -43,7 +44,11 @@ class TextFeaturizer(object):
         :rtype: list
         """
         tokens = self._char_tokenize(text)
-        return [self._vocab_dict[token] for token in tokens]
+        ids = []
+        for token in tokens:
+            token = token if token in self._vocab_dict else self.unk
+            ids.append(self._vocab_dict[token])
+        return ids

     @property
     def vocab_size(self):
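
With this change, out-of-vocabulary tokens fall back to the <unk> id instead of raising a KeyError. A small usage sketch, assuming the method shown is TextFeaturizer.featurize, ids follow vocab-file line order, and vocab.txt lists <unk>, a, b, c (one per line):

featurizer = TextFeaturizer(vocab_filepath='vocab.txt')
ids = featurizer.featurize('abz')  # 'z' is not in the vocab
print(ids)                         # [1, 2, 0]: 'z' maps to the <unk> id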

@@ -317,7 +317,7 @@ class DeepSpeech2Trainer(Trainer):
             use_dB_normalization=config.data.use_dB_normalization,
             target_dB=config.data.target_dB,
             random_seed=config.data.random_seed,
-            keep_transcription_text=True)
+            keep_transcription_text=False)
         if self.parallel:
             batch_sampler = DeepSpeech2DistributedBatchSampler(
@@ -342,14 +342,14 @@ class DeepSpeech2Trainer(Trainer):
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
-            collate_fn=SpeechCollator(is_training=True),
+            collate_fn=collate_fn,
             num_workers=config.data.num_workers, )
         self.valid_loader = DataLoader(
             dev_dataset,
             batch_size=config.data.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=SpeechCollator(is_training=True))
+            collate_fn=collate_fn)
         self.logger.info("Setup train/valid Dataloader!")
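
Both loaders now share one collate_fn instead of constructing separate SpeechCollator objects inline. Its definition falls outside the shown hunk; a hypothetical sketch of the intent (the collate_fn assignment below is an assumption, not part of the diff):

collate_fn = SpeechCollator(is_training=True)  # assumed: built once, above the hunk
self.train_loader = DataLoader(
    train_dataset,
    batch_sampler=batch_sampler,
    collate_fn=collate_fn,
    num_workers=config.data.num_workers, )
self.valid_loader = DataLoader(
    dev_dataset,
    batch_size=config.data.batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn)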
@@ -415,7 +415,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.logger.info(
             f"Test Total Examples: {len(self.test_loader.dataset)}")
         self.model.eval()
-        losses = defaultdict(list)
         cfg = self.config
         # decoders only accept string encoded in utf-8
         vocab_list = self.test_loader.dataset.vocab_list
@@ -432,10 +432,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         for i, batch in enumerate(self.test_loader):
             audio, text, audio_len, text_len = batch
             outputs = self.model.predict(audio, audio_len)
-            loss = self.compute_losses(batch, outputs)
-            losses['test_loss'].append(float(loss))
             metrics = self.compute_metrics(batch, outputs)
             errors_sum += metrics['errors_sum']
             len_refs += metrics['len_refs']
             num_ins += metrics['num_ins']
@@ -443,14 +441,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             self.logger.info("Error rate [%s] (%d/?) = %f" %
                              (error_rate_type, num_ins, errors_sum / len_refs))
-        # write visual log
-        losses = {k: np.mean(v) for k, v in losses.items()}
         # logging
         msg = "Test: "
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items())
         msg += ", Final error rate [%s] (%d/%d) = %f" % (
             error_rate_type, num_ins, num_ins, errors_sum / len_refs)
         self.logger.info(msg)
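
The tester no longer calls compute_losses: per the commit message, test mode carries transcripts as ord() ids for WER/CER scoring rather than the vocab ids compute_loss needs. For reference, a self-contained sketch of how the errors_sum / len_refs accumulators turn into an error rate (edit_distance and the toy pairs are illustrative; the real code gets these counts from compute_metrics):

def edit_distance(ref, hyp):
    # classic dynamic-programming Levenshtein distance
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1,
                                     prev + (r != h))
    return dp[-1]

errors_sum, len_refs = 0, 0
for ref, hyp in [('abc', 'abd'), ('hello', 'hallo')]:  # toy pairs
    errors_sum += edit_distance(ref, hyp)
    len_refs += len(ref)
print('CER = %f' % (errors_sum / len_refs))  # 2 errors / 8 chars = 0.25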

@@ -59,6 +59,7 @@ def main():
     count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
     with codecs.open(args.vocab_path, 'w', 'utf-8') as fout:
+        fout.write('<unk>' + '\n')
         for char, count in count_sorted:
             if count < args.count_threshold: break
             fout.write(char + '\n')
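
Writing '<unk>' before the counted characters puts it on the first line of the vocab file, so the featurizer's fallback token is always present (and, assuming ids follow line order, gets id 0). An illustrative run with toy data:

import codecs
from collections import Counter

counter = Counter('aaaaabbbc')  # toy corpus counts (illustrative)
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
with codecs.open('vocab.txt', 'w', 'utf-8') as fout:
    fout.write('<unk>' + '\n')  # <unk> is written first, so it gets id 0
    for char, count in count_sorted:
        if count < 2: break     # count_threshold = 2
        fout.write(char + '\n')
# vocab.txt now contains: <unk>, a, b  ->  ids 0, 1, 2 ('c' falls below threshold)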
