Give an option to disable converting transcription text to ids.

pull/2/head
yangyaming 7 years ago
parent f23282dc36
commit 39dbcb4dfb

@@ -55,6 +55,10 @@ class DataGenerator(object):
     :type num_threads: int
     :param random_seed: Random seed.
     :type random_seed: int
+    :param keep_transcription_text: If set to True, transcription text will
+                                    be passed forward directly without
+                                    converting to index sequence.
+    :type keep_transcription_text: bool
     """

     def __init__(self,
@@ -69,7 +73,8 @@ class DataGenerator(object):
                  specgram_type='linear',
                  use_dB_normalization=True,
                  num_threads=multiprocessing.cpu_count() // 2,
-                 random_seed=0):
+                 random_seed=0,
+                 keep_transcription_text=False):
         self._max_duration = max_duration
         self._min_duration = min_duration
         self._normalizer = FeatureNormalizer(mean_std_filepath)
@@ -84,6 +89,7 @@ class DataGenerator(object):
             use_dB_normalization=use_dB_normalization)
         self._num_threads = num_threads
         self._rng = random.Random(random_seed)
+        self._keep_transcription_text = keep_transcription_text
         self._epoch = 0
         # for caching tar files info
         self._local_data = local()
@@ -107,9 +113,10 @@ class DataGenerator(object):
         else:
             speech_segment = SpeechSegment.from_file(filename, transcript)
         self._augmentation_pipeline.transform_audio(speech_segment)
-        specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
+        specgram, transcript_part = self._speech_featurizer.featurize(
+            speech_segment, self._keep_transcription_text)
         specgram = self._normalizer.apply(specgram)
-        return specgram, text_ids
+        return specgram, transcript_part

     def batch_reader_creator(self,
                              manifest_path,
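For orientation before the downstream call sites below, here is a minimal sketch of how a caller opts into raw-text transcripts. The module path and file paths are assumptions for illustration, and the constructor may take further arguments not visible in this diff:

    from data_utils.data import DataGenerator  # module path assumed

    data_generator = DataGenerator(
        vocab_filepath='data/vocab.txt',        # placeholder path
        mean_std_filepath='data/mean_std.npz',  # placeholder path
        augmentation_config='{}',
        specgram_type='linear',
        num_threads=1,
        keep_transcription_text=True)  # keep raw transcript strings

    # With the flag set, each processed utterance becomes
    # (specgram, transcript_text) instead of (specgram, token_id_list).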

@@ -60,12 +60,12 @@ class SpeechFeaturizer(object):
             target_dB=target_dB)
         self._text_featurizer = TextFeaturizer(vocab_filepath)

-    def featurize(self, speech_segment):
+    def featurize(self, speech_segment, keep_transcription_text):
         """Extract features for speech segment.

         1. For audio parts, extract the audio features.
-        2. For transcript parts, convert text string to a list of token indices
-           in char-level.
+        2. For transcript parts, keep the original text or convert text string
+           to a list of token indices in char-level.

         :param audio_segment: Speech segment to extract features from.
         :type audio_segment: SpeechSegment
@@ -74,6 +74,8 @@ class SpeechFeaturizer(object):
         :rtype: tuple
         """
         audio_feature = self._audio_featurizer.featurize(speech_segment)
+        if keep_transcription_text:
+            return audio_feature, speech_segment.transcript
         text_ids = self._text_featurizer.featurize(speech_segment.transcript)
         return audio_feature, text_ids
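To make the two return modes concrete, a small hedged usage sketch (the featurizer and segment objects are assumed to be constructed elsewhere; the example values are illustrative only):

    # Default path: the transcript is converted to char-level token indices.
    specgram, text_ids = featurizer.featurize(segment, False)
    # text_ids -> e.g. [7, 4, 11, 11, 14]

    # New path: the raw transcript string is passed through untouched.
    specgram, transcript = featurizer.featurize(segment, True)
    # transcript -> e.g. "hello"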

@@ -146,7 +146,8 @@ def start_server():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=1)
+        num_threads=1,
+        keep_transcription_text=True)
     # prepare ASR model
     ds2_model = DeepSpeech2Model(
         vocab_size=data_generator.vocab_size,

@@ -68,7 +68,8 @@ def infer():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=1)
+        num_threads=1,
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.infer_manifest,
         batch_size=args.num_samples,
@@ -103,8 +104,7 @@ def infer():
     error_rate_func = cer if args.error_rate_type == 'cer' else wer
     target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in infer_data
+        transcript for _, transcript in infer_data
     ]
     for target, result in zip(target_transcripts, result_transcripts):
         print("\nTarget Transcription: %s\nOutput Transcription: %s" %

@@ -69,7 +69,8 @@ def evaluate():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data)
+        num_threads=args.num_proc_data,
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.test_manifest,
         batch_size=args.batch_size,
@@ -104,8 +105,7 @@ def evaluate():
         language_model_path=args.lang_model_path,
         num_processes=args.num_proc_bsearch)
     target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in infer_data
+        transcript for _, transcript in infer_data
     ]
     for target, result in zip(target_transcripts, result_transcripts):
         error_sum += error_rate_func(target, result)

@@ -87,7 +87,8 @@ def tune():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data)
+        num_threads=args.num_proc_data,
+        keep_transcription_text=True)

     audio_data = paddle.layer.data(
         name="audio_spectrogram",
@@ -164,8 +165,7 @@ def tune():
     ]
     target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in infer_data
+        transcript for _, transcript in infer_data
     ]
     num_ins += len(target_transcripts)
