pull/578/head
Hui Zhang 4 years ago
parent 9ad706e29d
commit 8c5b8e355b

@ -362,8 +362,8 @@ class U2Tester(U2Trainer):
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
text_feature = self.test_loader.dataset.text_feature text_feature = self.test_loader.dataset.text_feature
target_transcripts = self.ordid2token(texts, texts_len) target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
audio, audio,
@ -381,7 +381,8 @@ class U2Tester(U2Trainer):
decoding_chunk_size=cfg.decoding_chunk_size, decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks, num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming) simulate_streaming=cfg.simulate_streaming)
# Elapsed decoding time in seconds for this batch, measured from `start_time`.
# It must be a duration, not an absolute timestamp: the caller sums this value
# over batches and divides by the audio duration to compute the RTF.
decode_time = time.time() - start_time
for target, result in zip(target_transcripts, result_transcripts): for target, result in zip(target_transcripts, result_transcripts):
errors, len_ref = errors_func(target, result) errors, len_ref = errors_func(target, result)
errors_sum += errors errors_sum += errors
@ -397,9 +398,11 @@ class U2Tester(U2Trainer):
return dict( return dict(
errors_sum=errors_sum, errors_sum=errors_sum,
len_refs=len_refs, len_refs=len_refs,
num_ins=num_ins, num_ins=num_ins, # num examples
error_rate=errors_sum / len_refs, error_rate=errors_sum / len_refs,
error_rate_type=cfg.error_rate_type) error_rate_type=cfg.error_rate_type,
num_frames=audio_len.sum().numpy().item(),
decode_time=decode_time)
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
@paddle.no_grad() @paddle.no_grad()
@ -410,10 +413,13 @@ class U2Tester(U2Trainer):
error_rate_type = None error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
num_time = 0.0
with open(self.args.result_file, 'w') as fout: with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout) metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']
num_time += metrics["decode_time"]
errors_sum += metrics['errors_sum'] errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs'] len_refs += metrics['len_refs']
num_ins += metrics['num_ins'] num_ins += metrics['num_ins']
@ -421,11 +427,13 @@ class U2Tester(U2Trainer):
logger.info("Error rate [%s] (%d/?) = %f" % logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs)) (error_rate_type, num_ins, errors_sum / len_refs))
rtf = num_time / (num_frames * self.test_loader.dataset.stride_ms / 1000.0)
# logging # logging
msg = "Test: " msg = "Test: "
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
msg += ", Final error rate [%s] (%d/%d) = %f" % ( msg += "RTF: {}, ".format(rtf)
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs) error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg) logger.info(msg)

@ -105,6 +105,10 @@ class AudioFeaturizer(object):
# extract spectrogram # extract spectrogram
return self._compute_specgram(audio_segment) return self._compute_specgram(audio_segment)
@property
def stride_ms(self):
    """Return the frame stride stored on this featurizer.

    Returns:
        The value of ``self._stride_ms``; the unit is presumably
        milliseconds per frame, as the name suggests.
    """
    return self._stride_ms
@property @property
def feature_size(self): def feature_size(self):
"""audio feature size""" """audio feature size"""

@ -63,7 +63,8 @@ class SpeechFeaturizer(object):
max_freq=None, max_freq=None,
target_sample_rate=16000, target_sample_rate=16000,
use_dB_normalization=True, use_dB_normalization=True,
target_dB=-20): target_dB=-20,
dither=1.0):
self._audio_featurizer = AudioFeaturizer( self._audio_featurizer = AudioFeaturizer(
specgram_type=specgram_type, specgram_type=specgram_type,
feat_dim=feat_dim, feat_dim=feat_dim,
@ -74,7 +75,8 @@ class SpeechFeaturizer(object):
max_freq=max_freq, max_freq=max_freq,
target_sample_rate=target_sample_rate, target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization, use_dB_normalization=use_dB_normalization,
target_dB=target_dB) target_dB=target_dB,
dither=dither)
self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath, self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
spm_model_prefix) spm_model_prefix)
@ -138,6 +140,15 @@ class SpeechFeaturizer(object):
""" """
return self._audio_featurizer.feature_size return self._audio_featurizer.feature_size
@property
def stride_ms(self):
    """Return the time length of one feature frame, in milliseconds.

    The value is read from the wrapped audio featurizer.

    Returns:
        float: time (ms) per frame.
    """
    return self._audio_featurizer.stride_ms
@property @property
def text_feature(self): def text_feature(self):
"""Return the text feature object. """Return the text feature object.

@ -63,6 +63,7 @@ class ManifestDataset(Dataset):
specgram_type='linear', # 'linear', 'mfcc', 'fbank' specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank' feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank' delta_delta=False, # 'mfcc', 'fbank'
dither=1.0, # feature dither
target_sample_rate=16000, # target sample rate target_sample_rate=16000, # target sample rate
use_dB_normalization=True, use_dB_normalization=True,
target_dB=-20, target_dB=-20,
@ -123,6 +124,7 @@ class ManifestDataset(Dataset):
specgram_type=config.data.specgram_type, specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim, feat_dim=config.data.feat_dim,
delta_delta=config.data.delta_delta, delta_delta=config.data.delta_delta,
dither=config.data.dither,
use_dB_normalization=config.data.use_dB_normalization, use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB, target_dB=config.data.target_dB,
random_seed=config.data.random_seed, random_seed=config.data.random_seed,
@ -150,6 +152,7 @@ class ManifestDataset(Dataset):
specgram_type='linear', specgram_type='linear',
feat_dim=None, feat_dim=None,
delta_delta=False, delta_delta=False,
dither=1.0,
use_dB_normalization=True, use_dB_normalization=True,
target_dB=-20, target_dB=-20,
random_seed=0, random_seed=0,
@ -183,13 +186,10 @@ class ManifestDataset(Dataset):
keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
""" """
super().__init__() super().__init__()
self._max_input_len = max_input_len, self._stride_ms = stride_ms
self._min_input_len = min_input_len, self._target_sample_rate = target_sample_rate
self._max_output_len = max_output_len,
self._min_output_len = min_output_len,
self._max_output_input_ratio = max_output_input_ratio,
self._min_output_input_ratio = min_output_input_ratio,
self._normalizer = FeatureNormalizer( self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None mean_std_filepath) if mean_std_filepath else None
self._augmentation_pipeline = AugmentationPipeline( self._augmentation_pipeline = AugmentationPipeline(
@ -207,7 +207,8 @@ class ManifestDataset(Dataset):
max_freq=max_freq, max_freq=max_freq,
target_sample_rate=target_sample_rate, target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization, use_dB_normalization=use_dB_normalization,
target_dB=target_dB) target_dB=target_dB,
dither=dither)
self._rng = np.random.RandomState(random_seed) self._rng = np.random.RandomState(random_seed)
self._keep_transcription_text = keep_transcription_text self._keep_transcription_text = keep_transcription_text
@ -250,6 +251,10 @@ class ManifestDataset(Dataset):
@property @property
def feature_size(self): def feature_size(self):
return self._speech_featurizer.feature_size return self._speech_featurizer.feature_size
@property
def stride_ms(self):
    """Return the frame stride reported by the dataset's speech featurizer.

    Returns:
        The ``stride_ms`` value of ``self._speech_featurizer``; presumably
        milliseconds per frame.
    """
    return self._speech_featurizer.stride_ms
def _parse_tar(self, file): def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object """Parse a tar file to get a tarfile object

@ -18,6 +18,7 @@ data:
specgram_type: fbank #linear, mfcc, fbank specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80 feat_dim: 80
delta_delta: False delta_delta: False
dither: 1.0
target_sample_rate: 16000 target_sample_rate: 16000
max_freq: None max_freq: None
n_fft: None n_fft: None

@ -19,6 +19,7 @@ data:
specgram_type: fbank #linear, mfcc, fbank specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80 feat_dim: 80
delta_delta: False delta_delta: False
dither: 1.0
target_sample_rate: 16000 target_sample_rate: 16000
max_freq: None max_freq: None
n_fft: None n_fft: None

Loading…
Cancel
Save